JAX PRNG and Parameter Pytrees

마지막 수정:

jaxprngpytreeinitialization

JAX에서는 random state가 전역으로 숨어 있지 않다. key를 직접 만들고 나누고 전달한다.

JAX Model Shape Parameters are data. Forward and train steps are pure functions.
params, batch -> loss, grads
parameter pytree {"blocks": [...], "token_embedding": ...}
+
pure forward forward(params, input_ids, config)
->
logits [batch, seq, vocab]
PyTorch module owns parameters and behavior model(input_ids)
JAX state is explicit and transformed by functions jit(value_and_grad(loss_fn))
JAX code becomes easier to compile and shard when mutable training state is made explicit.
key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
params = init_params(init_key, config)

이 방식은 처음에는 번거롭지만, 대규모 학습에서는 장점이 크다.

reproducibility:
어떤 randomness가 어디서 쓰였는지 코드에 드러남

parallelism:
rank/device별 key split을 명시할 수 있음

jit:
함수 안팎 state mutation이 줄어 compile boundary가 명확함

Parameter도 object field가 아니라 pytree다.

params = {
  "token_embedding": ...,
  "blocks": [
    {"norm1": ..., "attn": ..., "mlp": ...},
    ...
  ],
}

JAX optimizer는 이 pytree와 같은 구조의 gradient, momentum, variance tree를 만든다. 그래서 모델, gradient, optimizer state가 같은 tree structure를 공유한다.

실습 위치

labs/jax-transformer/src/tiny_transformer.py
labs/jax-transformer/train.py

init_paramsinit_adamw를 같이 읽으면 “JAX에서 model state와 optimizer state가 둘 다 pytree”라는 감각이 잡힌다.