JAX RMSNorm과 Residual Path
마지막 수정:
RMSNorm은 JAX의 함수형 구현을 보기 좋은 첫 연산이다.
def rms_norm(params, x, eps=1e-6):
inv_rms = jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
return params["weight"] * x * inv_rms
PyTorch에서는 nn.Module 안에 weight가 들어간다. JAX에서는 weight가 pytree의 leaf이고, 함수가 그 leaf를 읽어 계산한다.
Residual path도 state를 바꾸지 않는다.
x = x + attention(norm(x))
x = x + mlp(norm(x))
중요한 점은 JAX에서 이런 elementwise chain이 XLA compile 이후 하나의 fused computation으로 합쳐질 수 있다는 것이다. 그래서 custom kernel을 만들기 전에 먼저 compiled program의 steady-state 시간을 측정해야 한다.
실습
python3 labs/jax-transformer/bench_rmsnorm.py --iters 100 --warmup 20
확인
- RMSNorm parameter는 JAX pytree에서 어떤 leaf로 존재하는가?
- residual connection이 parameter update와 다른 종류의 state인 이유는 무엇인가?
- JAX에서 RMSNorm custom kernel로 바로 내려가기 전에 XLA 결과를 봐야 하는 이유는 무엇인가?