JAX Profiler and Kernel Optimization Map

마지막 수정:

jaxprofilerxlapallaskernel-optimization

JAX kernel optimization은 PyTorch custom CUDA extension과 출발점이 다르다. JAX에서는 먼저 XLA가 어느 정도까지 fusion/optimization을 했는지 확인해야 한다.

JAX Timing Model First call compiles. Later calls execute cached XLA programs asynchronously.
jit(fn)(static shapes)
call 1 trace Python lower to XLA compile execute
call 2..N cache lookup enqueue execute block_until_ready
Benchmark rule output.block_until_ready() Without this, you often measure enqueue time.
Compile rule shape / dtype / static args Changing these can trigger a new compiled program.
JAX profiling starts by separating compile time, enqueue time, and real device execution time.

권장 순서는 다음이다.

1. Python code를 pure function으로 정리한다
2. jit로 compile boundary를 만든다
3. warmup 이후 block_until_ready로 steady-state time을 잰다
4. profiler trace로 device 실행 시간을 본다
5. HLO / lowered program에서 fusion 여부를 확인한다
6. 그래도 좁은 병목이면 Pallas/custom call 후보로 본다

PyTorch에서는 custom extension이 “Python overhead 제거”의 의미도 가질 수 있다. JAX에서는 jit가 Python loop를 compile boundary 밖으로 밀어내므로, custom kernel의 기준이 더 높다.

JAX custom kernel 후보:
XLA fusion이 부족한 특수 memory layout
attention/KV-cache류의 irregular access
collective와 compute overlap이 필요한 경로
기존 primitive 조합이 너무 많은 temporary를 만드는 경로

따라서 JAX profiling report에는 compile time, steady-state time, shape/dtype, lowered program 관찰이 같이 들어가야 한다.