Multi-Head Attention
Single-head attention은 token 간 참조 관계를 하나의 방식으로 계산한다. Multi-head attention은 이 과정을 여러 head로 나눠 병렬로 수행한다.
head 1: 한 종류의 관계를 본다.
head 2: 다른 종류의 관계를 본다.
head 3: 또 다른 종류의 관계를 본다.
...
모델 차원 D는 여러 head로 나뉜다.
[B, T, D] [B, T, H] syntax [B, T, H] reference [B, T, H] position [B, T, H] topic [B, T, D] D = N * H
여기서:
N = query head 수
H = head 하나의 차원
예를 들어 D = 4096, N = 32이면 head 하나의 차원은 H = 128이다.
왜 여러 head를 쓰나
문장 안에는 여러 종류의 관계가 있다.
주어와 동사의 관계
대명사가 가리키는 대상
이전 문장의 핵심 단어
코드에서 열린 괄호와 닫힌 괄호
하나의 attention 패턴만으로 모든 관계를 잘 잡기는 어렵다. 여러 head를 쓰면 head마다 다른 참조 패턴을 학습할 수 있다.
다시 합치기
각 head의 attention 결과는 다시 합쳐져 모델 차원 D로 돌아온다.
head outputs
-> concatenate
-> output projection
-> [B, T, D]
그래서 block 바깥에서 보면 multi-head attention도 여전히 [B, T, D]를 입력받아 [B, T, D]를 출력한다.
확인
D = N * H에서N과H는 각각 무엇인가?- multi-head attention은 왜 single-head attention보다 표현력이 좋은가?
- attention 결과는 block 밖으로 나갈 때 어떤 shape으로 돌아오는가?