Causal Mask와 Tiled Attention

cudaattentioncausal-maskflashattention

LLM decoder는 현재 token이 미래 token을 보면 안 된다. 그래서 causal mask가 필요하다.

Q block
K tile
scores in SRAM
V tile
O update

naive attention의 mask

naive 구현에서는 score matrix를 만든 뒤 upper triangle에 -inf를 넣고 softmax를 적용한다.

S[i, j] = -inf if j > i

tiled attention의 mask

FlashAttention에서는 전체 score matrix가 없으므로 tile을 계산할 때 현재 Q row와 K column의 위치를 보고 mask를 적용한다.

if (key_col > query_row) {
    score = -INFINITY;
}

확인

  • causal mask는 어느 위치의 score를 막는가?
  • score matrix를 저장하지 않아도 mask를 적용할 수 있는 이유는 무엇인가?
  • prefill과 decode에서 attention shape는 어떻게 달라지는가?