Softmax와 Value Mixing

cudaattentionsoftmaxvalue

score row 하나는 query token 하나가 모든 key token을 얼마나 볼지 나타낸다.

두 번째 matmul

P = softmax(S)
O = P V

P[row, col]은 row query가 col key/value를 보는 가중치다. output O[row]는 모든 V row의 weighted sum이다.

CUDA 관점

attention은 단순히 GEMM 두 번이 아니다. 가운데 softmax가 row 전체 reduction을 요구하고, 안정적인 softmax를 위해 max와 sum을 관리해야 한다.

확인

  • attention에서 softmax는 어느 축으로 적용되는가?
  • O[row]를 만들 때 어떤 V row들이 필요한가?
  • naive attention은 왜 SP를 memory에 저장하기 쉬운가?