PyTorch Causal Self-Attention 구현

마지막 수정:

pytorchtransformerattentioncausal-mask

Causal self-attention은 각 token이 자기 자신과 이전 token만 보게 만든다.

큰 흐름은 다음과 같다.

x -> Q, K, V
QK^T -> causal mask -> softmax
softmax @ V -> output projection

구현할 때 핵심 shape는 head를 분리한 뒤의 모양이다.

x: [B, T, D]
q, k, v: [B, H, T, Dh]
scores: [B, H, T, T]
out: [B, T, D]

처음 구현은 빠를 필요가 없다. 올바른 shape와 causal mask가 더 중요하다.

확인

  • Causal mask가 없으면 decoder-only Transformer에서 어떤 문제가 생기는가?
  • attention score matrix의 shape는 무엇인가?
  • head dimension Dh와 model dimension D는 어떤 관계인가?