JAX Attention Prefill vs Decode

마지막 수정:

jaxattentionprofilinginference

Attention benchmark는 seq_len 하나로 끝나지 않는다.

LLM serving에서는 attention workload가 둘로 갈라진다.

prefill: q_len = T, k_len = T
decode:  q_len = 1, k_len = T

JAX에서는 이 차이가 성능뿐 아니라 compile cache에도 영향을 준다. shape가 달라지면 다른 executable이 필요할 수 있기 때문이다.

python3 labs/jax-transformer/bench_attention.py --workload prefill --seq-len 256
python3 labs/jax-transformer/bench_attention.py --workload decode --seq-len 256

이 실습은 nano-vLLM과도 연결된다. prefill은 prompt 전체를 처리하는 단계이고, decode는 KV cache를 읽으며 한 token씩 생성하는 단계다.

확인

  • prefill과 decode에서 Q/K/V shape는 어떻게 달라지는가?
  • JAX에서 shape 변화가 compile cache miss로 이어질 수 있는 이유는 무엇인가?
  • 이 attention benchmark가 vLLM의 KV cache 이해와 연결되는 지점은 어디인가?