Grouped-Query Attention

transformerattentioninferencememory

Grouped-query attention, 줄여서 GQA는 MHA와 MQA 사이의 절충안이다.

한 문장으로 말하면:

여러 query head가 group을 이루고,
같은 group 안에서는 하나의 K/V head를 공유한다.

MHA

query head마다 자기 K/V head를 가진다.

GQA

여러 query head가 group 단위로 K/V를 공유한다.

MQA

모든 query head가 하나의 K/V head를 공유한다.

MHA에서 MQA로 갈수록 query head 수는 유지하고 KV head 수만 줄어든다. GQA는 그 중간 지점이다.

세 구조를 나란히 보기

N을 query head 수, K를 key/value head 수라고 하자.

MHA: K = N
GQA: 1 < K < N
MQA: K = 1

예를 들어 N = 32인 모델에서 K = 8로 잡으면, KV head 하나가 query head 4개를 담당한다.

N = 32
K = 8
G = N / K = 4

KV head 1개당 query head 4개가 묶인다.

여기서 G는 group 크기다.

왜 MHA와 MQA 사이가 필요한가

MHA는 query head마다 K/V head를 따로 둔다.

장점: head별 표현 자유도가 크다.
단점: KV cache가 크다.

MQA는 모든 query head가 하나의 K/V head를 공유한다.

장점: KV cache가 작다.
단점: K/V 표현 공유가 너무 강하다.

GQA는 이 둘 사이에서 KV cache를 줄이되, K/V head를 하나만 남기는 극단까지는 가지 않는다.

MHA보다 KV cache가 작다.
MQA보다 K/V 표현 자유도가 크다.

Tensor shape으로 보기

Transformer 표기법에서 Q, K, V는 이렇게 볼 수 있다.

Q: [B, T, K, G, H]
K: [B, S, K, H]
V: [B, S, K, H]

K는 KV head 수이고, G는 KV head 하나가 담당하는 query head 수다. 전체 query head 수는 다음과 같다.

N = K * G

MHA에서는 G = 1이라 query head마다 KV head가 하나씩 있다. MQA에서는 K = 1이라 모든 query head가 하나의 KV head를 공유한다.

추론 시스템 관점

GQA는 모델 구조이지만, 효과는 serving 시스템에서 크게 드러난다.

KV cache는 보통 다음 크기에 비례한다.

layers * sequence length * KV head 수 * head dimension

GQA는 여기서 KV head 수를 줄인다. 같은 GPU memory에서 더 긴 context를 담거나, 더 많은 요청을 batch에 올릴 여지가 생긴다.

그래서 GQA는 단순히 “attention head를 다르게 묶는 기법”이 아니라, decode memory bandwidth와 serving capacity를 직접 건드리는 구조다.

연결

확인

  • GQA에서 N, K, G는 각각 무엇인가?
  • N = 32, K = 8이면 KV head 하나당 query head는 몇 개인가?
  • GQA가 MHA보다 KV cache를 줄이는 이유는 무엇인가?
  • GQA가 MQA보다 덜 극단적인 절충안인 이유는 무엇인가?