Attention Score Matrix
attention의 첫 단계는 query와 key의 dot product다.
S = Q K^T
score matrix가 크다
sequence length가 N이면 score matrix는 [N, N]이다. token 수가 4096이면 score 원소는 약 1,677만 개다. head마다, batch마다 이 matrix가 생긴다.
naive 구현
naive attention은 보통 다음 순서로 생각한다.
QK^T 계산 -> score matrix 저장 -> softmax -> V와 곱함
이 방식은 이해하기 쉽지만 score matrix를 HBM에 쓰고 다시 읽는다.
확인
Q [N, d]와K [N, d]가 만나면 왜S [N, N]이 되는가?- sequence length가 2배가 되면 score matrix는 몇 배가 되는가?
- FlashAttention은 이 score matrix 저장을 어떻게 피하려고 하는가?