JAX XLA/HLO Inspection
마지막 수정:
JAX 최적화는 “소스 코드 줄 수”가 아니라 XLA가 만든 program을 봐야 한다.
jit 함수는 lowering을 통해 compiler IR로 내려갈 수 있다.
lowered = jax.jit(fn).lower(*example_args)
print(lowered.as_text())
HLO를 읽을 때는 처음부터 모든 줄을 이해하려고 하지 않는다. 먼저 세 가지를 본다.
같은 elementwise 연산이 fusion되었는가?
large matmul / dot_general이 어디 있는가?
shape가 기대한 대로 specialized 되었는가?
RMSNorm 같은 연산은 source에서는 여러 op로 보이지만 XLA가 하나의 fused computation으로 만들 수 있다. 반대로 attention은 matmul, softmax, mask, value matmul 사이의 memory traffic이 핵심 병목이 될 수 있다.
실습
python3 labs/jax-transformer/inspect_hlo.py --target rmsnorm
python3 labs/jax-transformer/inspect_hlo.py --target attention
확인
- JAX source code와 compiled program이 다를 수 있는 이유는 무엇인가?
- RMSNorm에서 fusion 여부를 확인해야 하는 이유는 무엇인가?
- attention에서 HLO를 볼 때 matmul과 softmax 사이의 경계를 확인하는 이유는 무엇인가?