JAX jit Compile Boundary
마지막 수정:
JAX profiling의 첫 함정은 “첫 호출이 느리다”는 것이다. 첫 호출은 계산만 하는 것이 아니라 tracing과 compilation을 포함한다.
JAX Timing Model First call compiles. Later calls execute cached XLA programs asynchronously.
jit(fn)(static shapes) call 1 trace Python lower to XLA compile execute
call 2..N cache lookup enqueue execute block_until_ready
Benchmark rule
output.block_until_ready() Without this, you often measure enqueue time. Compile rule
shape / dtype / static args Changing these can trigger a new compiled program. train_step = jax.jit(train_step)
params, opt_state, loss = train_step(params, opt_state, batch_key)
첫 호출에서는 Python function을 tracing하고 XLA executable을 만든다. 이후 같은 shape/dtype/static argument로 호출하면 compiled executable을 재사용한다.
그래서 JAX benchmark는 다음을 분리해야 한다.
compile time
first execution time
steady-state execution time
또 JAX 실행은 asynchronous dispatch가 많다. 정확히 시간을 재려면 결과가 device에서 준비될 때까지 기다려야 한다.
loss.block_until_ready()
실습 위치
python3 labs/jax-transformer/bench_rmsnorm.py
python3 labs/jax-transformer/bench_attention.py
benchmark 함수에서 block_until_ready()를 빼면 실제 kernel execution이 아니라 enqueue 비용만 측정할 수 있다.