JAX XLA/HLO Inspection

마지막 수정:

jaxxlahloprofiling

JAX 최적화는 “소스 코드 줄 수”가 아니라 XLA가 만든 program을 봐야 한다.

jit 함수는 lowering을 통해 compiler IR로 내려갈 수 있다.

lowered = jax.jit(fn).lower(*example_args)
print(lowered.as_text())

HLO를 읽을 때는 처음부터 모든 줄을 이해하려고 하지 않는다. 먼저 세 가지를 본다.

같은 elementwise 연산이 fusion되었는가?
large matmul / dot_general이 어디 있는가?
shape가 기대한 대로 specialized 되었는가?

RMSNorm 같은 연산은 source에서는 여러 op로 보이지만 XLA가 하나의 fused computation으로 만들 수 있다. 반대로 attention은 matmul, softmax, mask, value matmul 사이의 memory traffic이 핵심 병목이 될 수 있다.

실습

python3 labs/jax-transformer/inspect_hlo.py --target rmsnorm
python3 labs/jax-transformer/inspect_hlo.py --target attention

확인

  • JAX source code와 compiled program이 다를 수 있는 이유는 무엇인가?
  • RMSNorm에서 fusion 여부를 확인해야 하는 이유는 무엇인가?
  • attention에서 HLO를 볼 때 matmul과 softmax 사이의 경계를 확인하는 이유는 무엇인가?