Online Softmax
FlashAttention의 핵심 재료는 online softmax다.
m
l
O
왜 online인가
일반 softmax는 row 전체 score를 보고 max와 sum을 계산한다. 하지만 attention score row가 너무 크면 전체를 저장하는 비용이 커진다.
online softmax는 tile을 하나씩 보면서 다음 상태를 갱신한다.
m = running max
l = running sum(exp(score - m))
O = running output
새 tile에서 더 큰 max가 나오면 이전 l과 O를 새 max 기준으로 rescale한다.
확인
- online softmax가 저장하지 않는 것은 무엇인가?
- 새 tile의 max가 이전 max보다 크면 왜 rescale이 필요한가?
- FlashAttention에서
m,l,O는 어떤 역할을 하는가?