Multi-Query Attention

transformerattentioninferencememory

Multi-query attention, 줄여서 MQA는 multi-head attention에서 query head는 여러 개 유지하고, key/value head는 하나만 쓰는 구조다.

MHA: query head N개, KV head N개
MQA: query head N개, KV head 1개

attention이 여러 관점으로 질문하는 능력은 query head 쪽에 남겨두고, cache에 저장해야 하는 K/V는 크게 줄인다.

MHA

query head마다 자기 K/V head를 가진다.

GQA

여러 query head가 group 단위로 K/V를 공유한다.

MQA

모든 query head가 하나의 K/V head를 공유한다.

MHA에서 MQA로 갈수록 query head 수는 유지하고 KV head 수만 줄어든다. GQA는 그 중간 지점이다.

왜 K/V head를 줄이나

decode에서는 새 token 하나의 Q가 과거 모든 token의 K/V cache를 읽는다.

매 decode step:
새 Q 1개
-> 과거 token들의 K/V cache 전체 참조

그래서 긴 context나 큰 batch에서는 KV cache 크기와 읽기 bandwidth가 중요해진다. MQA는 layer마다 저장하는 K/V head 수를 N개에서 1개로 줄인다.

예를 들어 query head가 N = 32, head 차원이 H = 128이라고 하자.

MHA의 K cache 폭: 32 * 128
MQA의 K cache 폭: 1 * 128

V도 똑같이 줄어든다. 그래서 KV cache 저장량과 decode에서 읽는 K/V 양이 크게 줄어든다.

무엇을 포기하나

MQA는 모든 query head가 같은 K/V head를 공유한다.

Q1, Q2, Q3, ... QN
  모두 같은 K/V를 본다.

이 구조는 cache와 bandwidth에는 좋지만, head마다 서로 다른 K/V 표현을 갖는 자유도는 줄어든다. 즉 MHA보다 메모리는 작고 빠를 수 있지만, 표현력 측면에서는 더 강한 제약을 건다.

언제 중요한가

MQA는 특히 decode 성능을 볼 때 중요하다.

prefill에서는 많은 token을 한 번에 처리하므로 큰 GEMM이 중심이다. decode에서는 token을 하나씩 만들고, 매 step마다 KV cache를 계속 읽는다.

decode 병목:
model weight 읽기
+ KV cache 읽기
+ 새 K/V cache 쓰기

MQA는 이 중 KV cache 쪽 비용을 줄이는 구조다. 그래서 MQA는 attention 수식만의 문제가 아니라 serving throughput, memory capacity, 긴 context 처리와 연결된다.

GQA와의 관계

MQA는 GQA의 극단적인 경우로 볼 수 있다.

MHA: K = N
GQA: 1 < K < N
MQA: K = 1

여기서 N은 query head 수, K는 key/value head 수다. MQA는 KV head를 가장 적게 쓰는 선택이다.

연결

확인

  • MQA에서 query head 수와 KV head 수는 각각 어떻게 되는가?
  • MQA가 KV cache 저장량을 줄이는 이유는 무엇인가?
  • MQA가 MHA보다 강하게 거는 제약은 무엇인가?
  • MQA가 prefill보다 decode에서 특히 중요해지는 이유는 무엇인가?