Mamba Selective State Update Kernel

마지막 수정:

mambassmkerneltritoninference

Mamba decode의 핵심 연산 중 하나는 selective state update다.

한 step을 단순화하면 이렇게 볼 수 있다.

현재 token에서 update 파라미터를 만든다.
이전 SSM state를 읽는다.
state를 갱신한다.
출력을 만든다.

이 과정은 매 decode step마다, 모든 active request에 대해 반복된다.

왜 커널 문제가 되는가

Transformer serving에서 attention kernel이 중요하듯, Mamba serving에서는 state update kernel이 중요해진다.

decode에서는 batch 안의 request들이 서로 다른 길이와 state를 가진다. 커널은 각 request의 state 위치를 찾아 읽고, 업데이트하고, 다시 써야 한다.

read state
compute update
write state

이 연산은 단순한 수식보다 memory layout, dtype, batching, GPU occupancy와 강하게 연결된다.

vLLM에서 볼 지점

vLLM은 Mamba selective state update backend를 설정할 수 있다.

triton
flashinfer

이것은 Mamba 지원이 단순히 Python model class 추가가 아니라, serving kernel과 cache layout까지 포함한다는 신호다.

연결

확인

  • selective state update는 decode에서 어떤 일을 하는가?
  • 이 연산이 batching과 memory layout의 영향을 받는 이유는 무엇인가?
  • vLLM이 Mamba backend를 따로 두는 것이 의미하는 바는 무엇인가?