DeepSeek MLA

transformerattentioninferencedeepseekkv-cache

MLA는 Multi-head Latent Attention의 약자다.

한 문장으로 말하면:

head별 K/V 전체를 cache하지 않고,
작은 latent vector를 cache한 뒤 필요할 때 K/V처럼 펼쳐 쓴다.

일반 multi-head attention에서는 decode 중 과거 token마다 K와 V를 저장한다.

token t의 cache:
K_t: [N, H]
V_t: [N, H]

여기서 N은 head 수, H는 head 하나의 차원이다. context가 길어지고 batch가 커지면 이 KV cache가 GPU memory와 memory bandwidth를 크게 잡아먹는다.

MLA는 이 지점에서 질문을 바꾼다.

정말 과거 token마다 모든 head의 K/V를 그대로 저장해야 할까?

핵심 아이디어

MLA는 K와 V를 바로 만들고 저장하는 대신, 먼저 더 작은 latent vector를 만든다.

hidden state
  -> down projection
  -> compressed KV latent

cache에는 이 압축된 latent를 저장한다.

저장한다:
c_t^KV

바로 저장하지 않는다:
head별 전체 K_t, V_t

attention 계산이 필요할 때는 이 latent를 다시 projection해서 각 head가 사용할 K/V 표현을 만든다.

c_t^KV
  -> up projection for K
  -> up projection for V

그래서 MLA는 memory를 아끼는 대신 projection 계산을 더 한다.

덜 읽는다: KV cache bytes
더 계산한다: latent -> K/V projection

decode는 자주 memory-bound가 되기 때문에, 이 trade-off가 유리할 수 있다. GPU가 기다리던 HBM read를 줄이고, 남는 compute를 더 쓰는 방향으로 병목을 옮기는 셈이다.

MHA와 비교

일반 MHA의 cache는 token마다 대략 이렇게 커진다.

MHA cache per token ~= K + V
                    ~= 2 * N * H

MLA의 cache는 head별 K/V 전체가 아니라 compact latent 중심이다.

MLA cache per token ~= compressed KV latent + positional part

DeepSeek-V2 논문은 MLA를 통해 KV cache를 크게 줄였다고 보고한다. 숫자 자체보다 중요한 것은 방향이다.

MHA: 과거 token의 K/V를 크게 저장한다.
MLA: 과거 token을 작은 latent로 저장하고, 필요할 때 펼친다.

RoPE 때문에 생기는 문제

여기서 한 가지 문제가 생긴다.

Transformer attention은 보통 RoPE 같은 positional encoding을 Q/K에 적용한다. 그런데 K를 latent로 압축해서 저장하면, 위치 정보가 섞인 K를 나중에 깔끔하게 복원하기 어렵다.

DeepSeek의 MLA는 이를 위해 content 부분과 positional 부분을 분리한다.

content part:
compressed KV latent로 다룬다.

positional part:
RoPE가 필요한 작은 key 조각으로 따로 둔다.

attention score는 개념적으로 두 항을 합친다.

content score    = query content dot key content
positional score = query position dot key position

attention score = content score + positional score

즉 MLA는 단순히 “KV cache를 압축한다”에서 끝나지 않는다. RoPE와 함께 쓰기 위해 attention score를 content 경로와 positional 경로로 나눠 생각한다.

왜 추론에서 중요한가

LLM serving에서 decode는 token을 하나씩 생성한다.

이때 매 step마다 GPU는:

1. model weight를 읽는다.
2. 과거 token의 KV cache를 읽는다.
3. 새 token의 logits를 만든다.

context가 길거나 batch가 커질수록 KV cache read가 점점 중요해진다. MLA는 이 read 자체를 줄이는 구조적 방법이다.

PagedAttention이 “KV cache를 어떻게 배치하고 관리할 것인가”의 문제라면, MLA는 “KV cache에 무엇을 저장할 것인가”의 문제다.

PagedAttention:
cache layout / memory manager 문제

MLA:
cache representation / attention architecture 문제

둘은 경쟁 관계가 아니라 서로 다른 층의 최적화다.

MQA/GQA와 어떻게 다른가

MLA는 MQA나 GQA와 비슷해 보일 수 있다. 모두 decode에서 KV cache를 줄이려는 구조이기 때문이다.

MQA/GQA는 여러 query head가 더 적은 수의 KV head를 공유하게 해서 cache를 줄인다.

MQA/GQA:
KV head 수를 줄인다.

MLA는 K/V를 작은 latent로 압축해 저장하고, projection으로 head별 표현을 만든다.

MLA:
KV cache 표현 자체를 latent로 바꾼다.

둘 다 KV cache를 줄이지만, 줄이는 방식이 다르다.

연결

참고

확인

  • MLA는 일반 MHA와 달리 cache에 무엇을 저장하는가?
  • MLA가 memory bandwidth를 줄이는 대신 더 쓰는 자원은 무엇인가?
  • RoPE 때문에 MLA가 content part와 positional part를 나눠야 하는 이유는 무엇인가?
  • PagedAttention과 MLA는 각각 KV cache 문제의 어느 층을 다루는가?