Mamba Prefill과 Decode
마지막 수정:
Transformer에서도 prefill과 decode는 다르다. Mamba/SSM에서도 다르다.
다만 달라지는 지점이 KV cache가 아니라 state다.
Prefill
prefill은 prompt 전체를 읽는 구간이다.
Transformer에서는 prompt token들의 KV cache를 채운다.
Mamba에서는 prompt를 통과시키며 request의 state를 만든다.
prompt tokens
-> recurrent / scan computation
-> final conv state + SSM state
이 state가 decode의 시작점이 된다.
Decode
decode는 새 token을 하나씩 만들며 state를 갱신하는 구간이다.
state_{t-1} + new_token_t -> state_t -> logits
Transformer decode가 past KV를 계속 읽는다면, Mamba decode는 현재 state를 읽고 갱신한다.
왜 scan이 중요해지는가
recurrent하게 보면 SSM은 token을 순서대로 처리한다.
하지만 GPU에서는 긴 sequence를 순수 for-loop처럼 처리하면 느리다. 그래서 prefill에서는 scan/parallel scan 계열의 계산이 중요해진다.
모델 관점:
state를 순서대로 갱신한다.
GPU 관점:
가능한 한 병렬화된 scan으로 prompt를 처리해야 한다.
이 지점에서 Mamba는 모델 구조와 커널 구현이 강하게 연결된다.
연결
- prefill-vs-decode: Transformer serving의 prefill/decode 구분
- ssm-state-vs-kv-cache: KV cache와 SSM state cache 비교
- mamba-selective-state-update-kernel: decode state update를 빠르게 하는 커널
확인
- Mamba prefill의 산출물은 무엇인가?
- Mamba decode에서 매 step 갱신되는 것은 무엇인가?
- SSM prefill에서 scan/parallel scan이 중요해지는 이유는 무엇인가?