Causal Mask와 Tiled Attention
LLM decoder는 현재 token이 미래 token을 보면 안 된다. 그래서 causal mask가 필요하다.
Q block
K tile
scores in SRAM
V tile
O update
naive attention의 mask
naive 구현에서는 score matrix를 만든 뒤 upper triangle에 -inf를 넣고 softmax를 적용한다.
S[i, j] = -inf if j > i
tiled attention의 mask
FlashAttention에서는 전체 score matrix가 없으므로 tile을 계산할 때 현재 Q row와 K column의 위치를 보고 mask를 적용한다.
if (key_col > query_row) {
score = -INFINITY;
}
확인
- causal mask는 어느 위치의 score를 막는가?
- score matrix를 저장하지 않아도 mask를 적용할 수 있는 이유는 무엇인가?
- prefill과 decode에서 attention shape는 어떻게 달라지는가?