Softmax와 Value Mixing
score row 하나는 query token 하나가 모든 key token을 얼마나 볼지 나타낸다.
두 번째 matmul
P = softmax(S)
O = P V
P[row, col]은 row query가 col key/value를 보는 가중치다. output O[row]는 모든 V row의 weighted sum이다.
CUDA 관점
attention은 단순히 GEMM 두 번이 아니다. 가운데 softmax가 row 전체 reduction을 요구하고, 안정적인 softmax를 위해 max와 sum을 관리해야 한다.
확인
- attention에서 softmax는 어느 축으로 적용되는가?
O[row]를 만들 때 어떤 V row들이 필요한가?- naive attention은 왜
S와P를 memory에 저장하기 쉬운가?