PyTorch Causal Self-Attention 구현
마지막 수정:
Causal self-attention은 각 token이 자기 자신과 이전 token만 보게 만든다.
큰 흐름은 다음과 같다.
x -> Q, K, V
QK^T -> causal mask -> softmax
softmax @ V -> output projection
구현할 때 핵심 shape는 head를 분리한 뒤의 모양이다.
x: [B, T, D]
q, k, v: [B, H, T, Dh]
scores: [B, H, T, T]
out: [B, T, D]
처음 구현은 빠를 필요가 없다. 올바른 shape와 causal mask가 더 중요하다.
확인
- Causal mask가 없으면 decoder-only Transformer에서 어떤 문제가 생기는가?
- attention score matrix의 shape는 무엇인가?
- head dimension
Dh와 model dimensionD는 어떤 관계인가?