FlashAttention IO Analysis

cudaattentionflashattentionmemory-io

FlashAttention은 계산량을 마법처럼 없애는 기법이 아니다. 핵심은 HBM traffic을 줄이는 것이다.

Q block
K tile
scores in SRAM
V tile
O update

naive attention의 memory traffic

write S = QK^T
read S for softmax
write P = softmax(S)
read P for PV

[N, N] matrix를 쓰고 읽는 비용이 크다.

FlashAttention의 차이

FlashAttention은 score tile을 SRAM/register 근처에서 처리하고, HBM에는 최종 output과 필요한 작은 상태만 쓴다.

큰 중간 matrix write/read를 줄인다.

확인

  • FlashAttention의 이득은 FLOP 감소인가, memory IO 감소인가?
  • sequence length가 길어질수록 왜 IO 문제가 커지는가?
  • SRAM/shared memory가 작은데도 tiling으로 도움이 되는 이유는 무엇인가?