Grouped-Query Attention
Grouped-query attention, 줄여서 GQA는 MHA와 MQA 사이의 절충안이다.
한 문장으로 말하면:
여러 query head가 group을 이루고,
같은 group 안에서는 하나의 K/V head를 공유한다.
MHA
query head마다 자기 K/V head를 가진다.
GQA
여러 query head가 group 단위로 K/V를 공유한다.
MQA
모든 query head가 하나의 K/V head를 공유한다.
세 구조를 나란히 보기
N을 query head 수, K를 key/value head 수라고 하자.
MHA: K = N
GQA: 1 < K < N
MQA: K = 1
예를 들어 N = 32인 모델에서 K = 8로 잡으면, KV head 하나가 query head 4개를 담당한다.
N = 32
K = 8
G = N / K = 4
KV head 1개당 query head 4개가 묶인다.
여기서 G는 group 크기다.
왜 MHA와 MQA 사이가 필요한가
MHA는 query head마다 K/V head를 따로 둔다.
장점: head별 표현 자유도가 크다.
단점: KV cache가 크다.
MQA는 모든 query head가 하나의 K/V head를 공유한다.
장점: KV cache가 작다.
단점: K/V 표현 공유가 너무 강하다.
GQA는 이 둘 사이에서 KV cache를 줄이되, K/V head를 하나만 남기는 극단까지는 가지 않는다.
MHA보다 KV cache가 작다.
MQA보다 K/V 표현 자유도가 크다.
Tensor shape으로 보기
Transformer 표기법에서 Q, K, V는 이렇게 볼 수 있다.
Q: [B, T, K, G, H]
K: [B, S, K, H]
V: [B, S, K, H]
K는 KV head 수이고, G는 KV head 하나가 담당하는 query head 수다. 전체 query head 수는 다음과 같다.
N = K * G
MHA에서는 G = 1이라 query head마다 KV head가 하나씩 있다. MQA에서는 K = 1이라 모든 query head가 하나의 KV head를 공유한다.
추론 시스템 관점
GQA는 모델 구조이지만, 효과는 serving 시스템에서 크게 드러난다.
KV cache는 보통 다음 크기에 비례한다.
layers * sequence length * KV head 수 * head dimension
GQA는 여기서 KV head 수를 줄인다. 같은 GPU memory에서 더 긴 context를 담거나, 더 많은 요청을 batch에 올릴 여지가 생긴다.
그래서 GQA는 단순히 “attention head를 다르게 묶는 기법”이 아니라, decode memory bandwidth와 serving capacity를 직접 건드리는 구조다.
연결
- transformer-notation:
N,K,G,H표기 - multi-head-attention: GQA가 절충하는 기준점
- multi-query-attention: GQA의 한쪽 극단
- kv-cache: GQA가 줄이는 메모리 대상
확인
- GQA에서
N,K,G는 각각 무엇인가? N = 32,K = 8이면 KV head 하나당 query head는 몇 개인가?- GQA가 MHA보다 KV cache를 줄이는 이유는 무엇인가?
- GQA가 MQA보다 덜 극단적인 절충안인 이유는 무엇인가?