FlashAttention IO Analysis
FlashAttention은 계산량을 마법처럼 없애는 기법이 아니다. 핵심은 HBM traffic을 줄이는 것이다.
Q block
K tile
scores in SRAM
V tile
O update
naive attention의 memory traffic
write S = QK^T
read S for softmax
write P = softmax(S)
read P for PV
큰 [N, N] matrix를 쓰고 읽는 비용이 크다.
FlashAttention의 차이
FlashAttention은 score tile을 SRAM/register 근처에서 처리하고, HBM에는 최종 output과 필요한 작은 상태만 쓴다.
큰 중간 matrix write/read를 줄인다.
확인
- FlashAttention의 이득은 FLOP 감소인가, memory IO 감소인가?
- sequence length가 길어질수록 왜 IO 문제가 커지는가?
- SRAM/shared memory가 작은데도 tiling으로 도움이 되는 이유는 무엇인가?