Multi-Query Attention
Multi-query attention, 줄여서 MQA는 multi-head attention에서 query head는 여러 개 유지하고, key/value head는 하나만 쓰는 구조다.
MHA: query head N개, KV head N개
MQA: query head N개, KV head 1개
attention이 여러 관점으로 질문하는 능력은 query head 쪽에 남겨두고, cache에 저장해야 하는 K/V는 크게 줄인다.
MHA
query head마다 자기 K/V head를 가진다.
GQA
여러 query head가 group 단위로 K/V를 공유한다.
MQA
모든 query head가 하나의 K/V head를 공유한다.
왜 K/V head를 줄이나
decode에서는 새 token 하나의 Q가 과거 모든 token의 K/V cache를 읽는다.
매 decode step:
새 Q 1개
-> 과거 token들의 K/V cache 전체 참조
그래서 긴 context나 큰 batch에서는 KV cache 크기와 읽기 bandwidth가 중요해진다. MQA는 layer마다 저장하는 K/V head 수를 N개에서 1개로 줄인다.
예를 들어 query head가 N = 32, head 차원이 H = 128이라고 하자.
MHA의 K cache 폭: 32 * 128
MQA의 K cache 폭: 1 * 128
V도 똑같이 줄어든다. 그래서 KV cache 저장량과 decode에서 읽는 K/V 양이 크게 줄어든다.
무엇을 포기하나
MQA는 모든 query head가 같은 K/V head를 공유한다.
Q1, Q2, Q3, ... QN
모두 같은 K/V를 본다.
이 구조는 cache와 bandwidth에는 좋지만, head마다 서로 다른 K/V 표현을 갖는 자유도는 줄어든다. 즉 MHA보다 메모리는 작고 빠를 수 있지만, 표현력 측면에서는 더 강한 제약을 건다.
언제 중요한가
MQA는 특히 decode 성능을 볼 때 중요하다.
prefill에서는 많은 token을 한 번에 처리하므로 큰 GEMM이 중심이다. decode에서는 token을 하나씩 만들고, 매 step마다 KV cache를 계속 읽는다.
decode 병목:
model weight 읽기
+ KV cache 읽기
+ 새 K/V cache 쓰기
MQA는 이 중 KV cache 쪽 비용을 줄이는 구조다. 그래서 MQA는 attention 수식만의 문제가 아니라 serving throughput, memory capacity, 긴 context 처리와 연결된다.
GQA와의 관계
MQA는 GQA의 극단적인 경우로 볼 수 있다.
MHA: K = N
GQA: 1 < K < N
MQA: K = 1
여기서 N은 query head 수, K는 key/value head 수다. MQA는 KV head를 가장 적게 쓰는 선택이다.
연결
- multi-head-attention: MQA가 줄이는 대상
- kv-cache: MQA가 decode memory를 줄이는 이유
- grouped-query-attention: MQA와 MHA 사이의 절충안
확인
- MQA에서 query head 수와 KV head 수는 각각 어떻게 되는가?
- MQA가 KV cache 저장량을 줄이는 이유는 무엇인가?
- MQA가 MHA보다 강하게 거는 제약은 무엇인가?
- MQA가 prefill보다 decode에서 특히 중요해지는 이유는 무엇인가?