Online Softmax

cudaattentionsoftmaxflashattention

FlashAttention의 핵심 재료는 online softmax다.

m
l
O

왜 online인가

일반 softmax는 row 전체 score를 보고 max와 sum을 계산한다. 하지만 attention score row가 너무 크면 전체를 저장하는 비용이 커진다.

online softmax는 tile을 하나씩 보면서 다음 상태를 갱신한다.

m = running max
l = running sum(exp(score - m))
O = running output

새 tile에서 더 큰 max가 나오면 이전 lO를 새 max 기준으로 rescale한다.

확인

  • online softmax가 저장하지 않는 것은 무엇인가?
  • 새 tile의 max가 이전 max보다 크면 왜 rescale이 필요한가?
  • FlashAttention에서 m, l, O는 어떤 역할을 하는가?