FlashAttention Forward
FlashAttention forward는 attention 수식을 바꾸는 것이 아니라 memory access 방식을 바꾼다.
Q block
K tile
scores in SRAM
V tile
O update
수식은 그대로다
O = softmax(QK^T / sqrt(d)) V
하지만 구현은 전체 QK^T를 만들고 저장하지 않는다. Q block을 잡고, K/V tile을 하나씩 가져오며, online softmax 상태와 output을 갱신한다.
핵심 상태
m_i: row별 running max
l_i: row별 running normalizer
O_i: row별 partial output
이 상태들이 있으므로 tile을 지나간 뒤에도 전체 softmax와 같은 결과를 만들 수 있다.
확인
- FlashAttention은 근사 attention인가, 정확 attention인가?
- score matrix를 만들지 않는다는 말과 score tile을 계산하지 않는다는 말은 어떻게 다른가?
- forward에서 저장해야 하는 상태는 backward와 어떤 관련이 있는가?