JAX Profiler로 Baseline 측정하기

마지막 수정:

jaxprofilingprofilerbenchmark

JAX profiling에서 첫 실수는 Python timer만 보고 결론을 내리는 것이다.

JAX 실행에는 최소 세 층이 있다.

Python tracing / dispatch
XLA compile
device execution

그래서 baseline은 두 가지를 함께 남긴다.

steady-state tokens/sec
profiler trace

lab의 profiling script는 jax.profiler.trace로 trace directory를 만든다.

python3 labs/jax-transformer/profile_transformer.py --steps 20 --trace-dir runs/jax-profiler

결과를 볼 때는 첫 step과 이후 step을 구분한다. 첫 step에는 compile 비용이 섞일 수 있고, 이후 step이 실제 반복 학습의 기준선이다.

확인

  • JAX profiling에서 compile time과 execution time을 분리해야 하는 이유는 무엇인가?
  • asynchronous dispatch 때문에 Python timer가 과소평가할 수 있는 것은 무엇인가?
  • profiler trace와 microbenchmark를 함께 봐야 하는 이유는 무엇인가?