JAX Attention Prefill vs Decode
마지막 수정:
Attention benchmark는 seq_len 하나로 끝나지 않는다.
LLM serving에서는 attention workload가 둘로 갈라진다.
prefill: q_len = T, k_len = T
decode: q_len = 1, k_len = T
JAX에서는 이 차이가 성능뿐 아니라 compile cache에도 영향을 준다. shape가 달라지면 다른 executable이 필요할 수 있기 때문이다.
python3 labs/jax-transformer/bench_attention.py --workload prefill --seq-len 256
python3 labs/jax-transformer/bench_attention.py --workload decode --seq-len 256
이 실습은 nano-vLLM과도 연결된다. prefill은 prompt 전체를 처리하는 단계이고, decode는 KV cache를 읽으며 한 token씩 생성하는 단계다.
확인
- prefill과 decode에서 Q/K/V shape는 어떻게 달라지는가?
- JAX에서 shape 변화가 compile cache miss로 이어질 수 있는 이유는 무엇인가?
- 이 attention benchmark가 vLLM의 KV cache 이해와 연결되는 지점은 어디인가?