Attention Score Matrix

cudaattentiontransformermatmul

attention의 첫 단계는 query와 key의 dot product다.

S = Q K^T
A [4 x 3]
x
B [3 x 4]
=
C [4 x 4]

score matrix가 크다

sequence length가 N이면 score matrix는 [N, N]이다. token 수가 4096이면 score 원소는 약 1,677만 개다. head마다, batch마다 이 matrix가 생긴다.

naive 구현

naive attention은 보통 다음 순서로 생각한다.

QK^T 계산 -> score matrix 저장 -> softmax -> V와 곱함

이 방식은 이해하기 쉽지만 score matrix를 HBM에 쓰고 다시 읽는다.

확인

  • Q [N, d]K [N, d]가 만나면 왜 S [N, N]이 되는가?
  • sequence length가 2배가 되면 score matrix는 몇 배가 되는가?
  • FlashAttention은 이 score matrix 저장을 어떻게 피하려고 하는가?