Mamba Selective State Update Kernel
마지막 수정:
Mamba decode의 핵심 연산 중 하나는 selective state update다.
한 step을 단순화하면 이렇게 볼 수 있다.
현재 token에서 update 파라미터를 만든다.
이전 SSM state를 읽는다.
state를 갱신한다.
출력을 만든다.
이 과정은 매 decode step마다, 모든 active request에 대해 반복된다.
왜 커널 문제가 되는가
Transformer serving에서 attention kernel이 중요하듯, Mamba serving에서는 state update kernel이 중요해진다.
decode에서는 batch 안의 request들이 서로 다른 길이와 state를 가진다. 커널은 각 request의 state 위치를 찾아 읽고, 업데이트하고, 다시 써야 한다.
read state
compute update
write state
이 연산은 단순한 수식보다 memory layout, dtype, batching, GPU occupancy와 강하게 연결된다.
vLLM에서 볼 지점
vLLM은 Mamba selective state update backend를 설정할 수 있다.
triton
flashinfer
이것은 Mamba 지원이 단순히 Python model class 추가가 아니라, serving kernel과 cache layout까지 포함한다는 신호다.
연결
- mamba-prefill-and-decode: decode에서 state update가 반복되는 이유
- vllm-mamba-cache-spec: update할 state가 cache manager 안에서 어떻게 표현되는지
- cuda-transformer-kernel-map: Transformer 쪽 kernel map과 비교하기
확인
- selective state update는 decode에서 어떤 일을 하는가?
- 이 연산이 batching과 memory layout의 영향을 받는 이유는 무엇인가?
- vLLM이 Mamba backend를 따로 두는 것이 의미하는 바는 무엇인가?