JAX Profiler and Kernel Optimization Map
마지막 수정:
JAX kernel optimization은 PyTorch custom CUDA extension과 출발점이 다르다. JAX에서는 먼저 XLA가 어느 정도까지 fusion/optimization을 했는지 확인해야 한다.
JAX Timing Model First call compiles. Later calls execute cached XLA programs asynchronously.
jit(fn)(static shapes) call 1 trace Python lower to XLA compile execute
call 2..N cache lookup enqueue execute block_until_ready
Benchmark rule
output.block_until_ready() Without this, you often measure enqueue time. Compile rule
shape / dtype / static args Changing these can trigger a new compiled program. 권장 순서는 다음이다.
1. Python code를 pure function으로 정리한다
2. jit로 compile boundary를 만든다
3. warmup 이후 block_until_ready로 steady-state time을 잰다
4. profiler trace로 device 실행 시간을 본다
5. HLO / lowered program에서 fusion 여부를 확인한다
6. 그래도 좁은 병목이면 Pallas/custom call 후보로 본다
PyTorch에서는 custom extension이 “Python overhead 제거”의 의미도 가질 수 있다. JAX에서는 jit가 Python loop를 compile boundary 밖으로 밀어내므로, custom kernel의 기준이 더 높다.
JAX custom kernel 후보:
XLA fusion이 부족한 특수 memory layout
attention/KV-cache류의 irregular access
collective와 compute overlap이 필요한 경로
기존 primitive 조합이 너무 많은 temporary를 만드는 경로
따라서 JAX profiling report에는 compile time, steady-state time, shape/dtype, lowered program 관찰이 같이 들어가야 한다.