JAX PRNG and Parameter Pytrees
마지막 수정:
JAX에서는 random state가 전역으로 숨어 있지 않다. key를 직접 만들고 나누고 전달한다.
JAX Model Shape Parameters are data. Forward and train steps are pure functions.
params, batch -> loss, grads parameter pytree
+ {"blocks": [...], "token_embedding": ...} pure forward
-> forward(params, input_ids, config) logits
[batch, seq, vocab] PyTorch module owns parameters and behavior
model(input_ids) JAX state is explicit and transformed by functions
jit(value_and_grad(loss_fn)) key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
params = init_params(init_key, config)
이 방식은 처음에는 번거롭지만, 대규모 학습에서는 장점이 크다.
reproducibility:
어떤 randomness가 어디서 쓰였는지 코드에 드러남
parallelism:
rank/device별 key split을 명시할 수 있음
jit:
함수 안팎 state mutation이 줄어 compile boundary가 명확함
Parameter도 object field가 아니라 pytree다.
params = {
"token_embedding": ...,
"blocks": [
{"norm1": ..., "attn": ..., "mlp": ...},
...
],
}
JAX optimizer는 이 pytree와 같은 구조의 gradient, momentum, variance tree를 만든다. 그래서 모델, gradient, optimizer state가 같은 tree structure를 공유한다.
실습 위치
labs/jax-transformer/src/tiny_transformer.py
labs/jax-transformer/train.py
init_params와 init_adamw를 같이 읽으면 “JAX에서 model state와 optimizer state가 둘 다 pytree”라는 감각이 잡힌다.