JAX RMSNorm and Attention Benchmark
마지막 수정:
PyTorch profiling path에서는 eager RMSNorm, custom CUDA extension, SDPA를 비교했다. JAX에서는 먼저 같은 연산을 jit boundary 안에 넣고 steady-state time을 본다.
JAX Timing Model First call compiles. Later calls execute cached XLA programs asynchronously.
jit(fn)(static shapes) call 1 trace Python lower to XLA compile execute
call 2..N cache lookup enqueue execute block_until_ready
Benchmark rule
output.block_until_ready() Without this, you often measure enqueue time. Compile rule
shape / dtype / static args Changing these can trigger a new compiled program. RMSNorm benchmark는 작지만 좋은 시작점이다.
python3 labs/jax-transformer/bench_rmsnorm.py
Attention benchmark는 materialized score matrix를 쓰는 manual causal attention이다.
python3 labs/jax-transformer/bench_attention.py
이 둘의 목적은 “JAX가 자동으로 빠르다”를 증명하는 것이 아니다. 목적은 다음 질문을 분리하는 것이다.
이 연산은 XLA fusion으로 충분한가?
shape가 바뀔 때 compile cache가 깨지는가?
attention score matrix materialization이 병목인가?
Pallas/custom kernel 후보가 될 만큼 문제가 좁혀졌는가?
JAX에서는 custom kernel로 바로 내려가기 전에 XLA가 만든 compiled program을 먼저 확인해야 한다. 많은 elementwise chain은 XLA fusion으로 이미 사라질 수 있다.