FlashAttention Forward

cudaattentionflashattentionoptimization

FlashAttention forward는 attention 수식을 바꾸는 것이 아니라 memory access 방식을 바꾼다.

Q block
K tile
scores in SRAM
V tile
O update

수식은 그대로다

O = softmax(QK^T / sqrt(d)) V

하지만 구현은 전체 QK^T를 만들고 저장하지 않는다. Q block을 잡고, K/V tile을 하나씩 가져오며, online softmax 상태와 output을 갱신한다.

핵심 상태

m_i: row별 running max
l_i: row별 running normalizer
O_i: row별 partial output

이 상태들이 있으므로 tile을 지나간 뒤에도 전체 softmax와 같은 결과를 만들 수 있다.

확인

  • FlashAttention은 근사 attention인가, 정확 attention인가?
  • score matrix를 만들지 않는다는 말과 score tile을 계산하지 않는다는 말은 어떻게 다른가?
  • forward에서 저장해야 하는 상태는 backward와 어떤 관련이 있는가?