JAX Transformer Benchmark Report 작성하기

마지막 수정:

jaxbenchmarkprofilingreport

JAX path의 마지막 산출물은 재현 가능한 benchmark report다.

형식은 PyTorch report와 맞춘다.

question
setup
command
compile behavior
baseline result
bottleneck
change
new result
lesson

JAX report에는 PyTorch report에 없는 항목이 하나 더 들어간다.

compile behavior

예를 들면 다음 질문을 기록한다.

첫 호출과 두 번째 호출의 시간 차이는 얼마인가?
shape를 바꾸면 다시 compile되는가?
RMSNorm은 HLO에서 fusion되는가?
prefill/decode attention은 서로 다른 executable인가?

최종 목표는 “JAX가 빠르다/느리다”가 아니다. 같은 Transformer workload를 PyTorch와 JAX가 어떻게 다르게 실행 모델로 바꾸는지 설명하는 것이다.

확인

  • JAX benchmark report에 compile behavior가 필요한 이유는 무엇인가?
  • PyTorch와 같은 표로 비교해야 하는 항목은 무엇인가?
  • 숫자가 다르게 나왔을 때 framework 차이와 shape/dtype 차이를 어떻게 분리할 수 있는가?