JAX Pallas Kernel Ladder

마지막 수정:

jaxpallaskerneloptimization

JAX에서 custom kernel은 첫 선택지가 아니다.

권장 순서는 다음이다.

plain JAX
  -> jax.jit
  -> shape/dtype/layout 정리
  -> HLO inspection
  -> Pallas kernel 후보
  -> CUDA/Triton/Pallas 비교

Pallas는 JAX 생태계에서 kernel-level programming으로 내려가는 길이다. 다만 모든 연산을 Pallas로 바꾸는 것이 목표가 아니다. XLA가 이미 잘 fusion하는 elementwise chain이라면 custom kernel의 이득이 작을 수 있다.

Pallas 후보는 보통 다음 조건을 만족한다.

memory traffic이 병목이다
XLA fusion만으로 layout을 제어하기 어렵다
특정 shape에 특화된 kernel이 필요하다
benchmark에서 병목이 재현된다

이 path에서는 먼저 RMSNorm과 attention을 후보로 둔다. 하지만 실제로 바꿀지는 profiler와 HLO를 보고 결정한다.

확인

  • JAX에서 custom kernel로 바로 내려가면 안 되는 이유는 무엇인가?
  • Pallas 후보를 고를 때 profiler, HLO, benchmark 중 무엇이 필요한가?
  • RMSNorm과 attention은 각각 어떤 종류의 병목을 의심할 수 있는가?