Mamba Prefill과 Decode

마지막 수정:

ssmmambainferenceprefilldecode

Transformer에서도 prefill과 decode는 다르다. Mamba/SSM에서도 다르다.

다만 달라지는 지점이 KV cache가 아니라 state다.

Prefill

prefill은 prompt 전체를 읽는 구간이다.

Transformer에서는 prompt token들의 KV cache를 채운다.

Mamba에서는 prompt를 통과시키며 request의 state를 만든다.

prompt tokens
  -> recurrent / scan computation
  -> final conv state + SSM state

이 state가 decode의 시작점이 된다.

Decode

decode는 새 token을 하나씩 만들며 state를 갱신하는 구간이다.

state_{t-1} + new_token_t -> state_t -> logits

Transformer decode가 past KV를 계속 읽는다면, Mamba decode는 현재 state를 읽고 갱신한다.

왜 scan이 중요해지는가

recurrent하게 보면 SSM은 token을 순서대로 처리한다.

하지만 GPU에서는 긴 sequence를 순수 for-loop처럼 처리하면 느리다. 그래서 prefill에서는 scan/parallel scan 계열의 계산이 중요해진다.

모델 관점:
state를 순서대로 갱신한다.

GPU 관점:
가능한 한 병렬화된 scan으로 prompt를 처리해야 한다.

이 지점에서 Mamba는 모델 구조와 커널 구현이 강하게 연결된다.

연결

확인

  • Mamba prefill의 산출물은 무엇인가?
  • Mamba decode에서 매 step 갱신되는 것은 무엇인가?
  • SSM prefill에서 scan/parallel scan이 중요해지는 이유는 무엇인가?