JAX RMSNorm and Attention Benchmark

마지막 수정:

jaxprofilingrmsnormattentionbenchmark

PyTorch profiling path에서는 eager RMSNorm, custom CUDA extension, SDPA를 비교했다. JAX에서는 먼저 같은 연산을 jit boundary 안에 넣고 steady-state time을 본다.

JAX Timing Model First call compiles. Later calls execute cached XLA programs asynchronously.
jit(fn)(static shapes)
call 1 trace Python lower to XLA compile execute
call 2..N cache lookup enqueue execute block_until_ready
Benchmark rule output.block_until_ready() Without this, you often measure enqueue time.
Compile rule shape / dtype / static args Changing these can trigger a new compiled program.
JAX profiling starts by separating compile time, enqueue time, and real device execution time.

RMSNorm benchmark는 작지만 좋은 시작점이다.

python3 labs/jax-transformer/bench_rmsnorm.py

Attention benchmark는 materialized score matrix를 쓰는 manual causal attention이다.

python3 labs/jax-transformer/bench_attention.py

이 둘의 목적은 “JAX가 자동으로 빠르다”를 증명하는 것이 아니다. 목적은 다음 질문을 분리하는 것이다.

이 연산은 XLA fusion으로 충분한가?
shape가 바뀔 때 compile cache가 깨지는가?
attention score matrix materialization이 병목인가?
Pallas/custom kernel 후보가 될 만큼 문제가 좁혀졌는가?

JAX에서는 custom kernel로 바로 내려가기 전에 XLA가 만든 compiled program을 먼저 확인해야 한다. 많은 elementwise chain은 XLA fusion으로 이미 사라질 수 있다.