JAX Profiling과 Kernel Optimization
이 경로는 JAX 학습 코드를 측정하고 최적화 후보를 찾는 profiling 경로입니다.
PyTorch profiling에서는 eager op와 CUDA kernel timeline을 먼저 봅니다. JAX에서는 먼저 compile boundary를 분리해야 합니다.
first call:
trace + compile + execute
later calls:
cached executable enqueue + device execute
따라서 JAX benchmark는 항상 warmup과 block_until_ready()가 기본입니다.
최적화 순서는 다음입니다.
baseline timing
-> profiler trace
-> RMSNorm / attention microbenchmark
-> prefill vs decode shape split
-> HLO / lowered program inspection
-> Pallas or custom kernel 후보 판단
-> benchmark report
- JAX jit Compile Boundary — JAX에서 첫 호출 compile time과 이후 execution time을 분리하고, static shape와 block_until_ready가 왜 중요한지 이해한다.
- JAX Profiler로 Baseline 측정하기 — JAX training step의 trace를 남기고 compile time, host dispatch, device execution을 구분한다.
- JAX RMSNorm and Attention Benchmark — PyTorch profiling 실습의 RMSNorm/attention 비교를 JAX에서는 jit-compiled function과 static shape benchmark로 옮긴다.
- JAX XLA/HLO Inspection — jit-compiled JAX 함수의 lowered program을 확인해 fusion과 shape-specialization을 읽는다.
- JAX Attention Prefill vs Decode — JAX attention benchmark를 prefill-like와 decode-like workload로 나눠 shape와 compile cache 차이를 확인한다.
- JAX Pallas Kernel Ladder — JAX에서 XLA fusion 이후에도 custom kernel이 필요한 경우 Pallas로 내려가는 판단 기준을 세운다.
- JAX Profiler and Kernel Optimization Map — JAX profiling에서 trace capture, HLO/compiled program 확인, Pallas/custom call 후보를 어떤 순서로 볼지 정리한다.
- JAX Transformer Benchmark Report 작성하기 — JAX compile, profiling, RMSNorm, attention 결과를 PyTorch 결과와 비교 가능한 리포트로 정리한다.