Multi-Head Attention

transformerattentionheads

Single-head attention은 token 간 참조 관계를 하나의 방식으로 계산한다. Multi-head attention은 이 과정을 여러 head로 나눠 병렬로 수행한다.

head 1: 한 종류의 관계를 본다.
head 2: 다른 종류의 관계를 본다.
head 3: 또 다른 종류의 관계를 본다.
...

모델 차원 D는 여러 head로 나뉜다.

X [B, T, D]
head 1 [B, T, H] syntax
head 2 [B, T, H] reference
head 3 [B, T, H] position
head 4 [B, T, H] topic
Output [B, T, D]
Multi-head attention은 모델 차원 D를 여러 head 차원 H로 나눠 attention을 병렬 수행한 뒤 다시 D차원으로 합친다.
D = N * H

여기서:

N = query head 수
H = head 하나의 차원

예를 들어 D = 4096, N = 32이면 head 하나의 차원은 H = 128이다.

왜 여러 head를 쓰나

문장 안에는 여러 종류의 관계가 있다.

주어와 동사의 관계
대명사가 가리키는 대상
이전 문장의 핵심 단어
코드에서 열린 괄호와 닫힌 괄호

하나의 attention 패턴만으로 모든 관계를 잘 잡기는 어렵다. 여러 head를 쓰면 head마다 다른 참조 패턴을 학습할 수 있다.

다시 합치기

각 head의 attention 결과는 다시 합쳐져 모델 차원 D로 돌아온다.

head outputs
  -> concatenate
  -> output projection
  -> [B, T, D]

그래서 block 바깥에서 보면 multi-head attention도 여전히 [B, T, D]를 입력받아 [B, T, D]를 출력한다.

확인

  • D = N * H에서 NH는 각각 무엇인가?
  • multi-head attention은 왜 single-head attention보다 표현력이 좋은가?
  • attention 결과는 block 밖으로 나갈 때 어떤 shape으로 돌아오는가?