JAX Profiler로 Baseline 측정하기
마지막 수정:
JAX profiling에서 첫 실수는 Python timer만 보고 결론을 내리는 것이다.
JAX 실행에는 최소 세 층이 있다.
Python tracing / dispatch
XLA compile
device execution
그래서 baseline은 두 가지를 함께 남긴다.
steady-state tokens/sec
profiler trace
lab의 profiling script는 jax.profiler.trace로 trace directory를 만든다.
python3 labs/jax-transformer/profile_transformer.py --steps 20 --trace-dir runs/jax-profiler
결과를 볼 때는 첫 step과 이후 step을 구분한다. 첫 step에는 compile 비용이 섞일 수 있고, 이후 step이 실제 반복 학습의 기준선이다.
확인
- JAX profiling에서 compile time과 execution time을 분리해야 하는 이유는 무엇인가?
- asynchronous dispatch 때문에 Python timer가 과소평가할 수 있는 것은 무엇인가?
- profiler trace와 microbenchmark를 함께 봐야 하는 이유는 무엇인가?