JAX Functional Programming Model

마지막 수정:

jaxtransformerfunctional-programmingpytree

JAX 트랙의 첫 번째 전환은 “모델 객체를 만든다”가 아니라 계산 가능한 함수와 명시적 state를 만든다는 것이다.

JAX Model Shape Parameters are data. Forward and train steps are pure functions.
params, batch -> loss, grads
parameter pytree {"blocks": [...], "token_embedding": ...}
+
pure forward forward(params, input_ids, config)
->
logits [batch, seq, vocab]
PyTorch module owns parameters and behavior model(input_ids)
JAX state is explicit and transformed by functions jit(value_and_grad(loss_fn))
JAX code becomes easier to compile and shard when mutable training state is made explicit.

PyTorch에서는 nn.Module이 parameter와 forward behavior를 함께 가진다.

model = TinyTransformerLM(config)
logits = model(input_ids)

JAX lab에서는 parameter pytree와 forward function을 분리한다.

params = init_params(key, config)
logits = forward(params, input_ids, config)

이 차이는 취향 문제가 아니다. JAX의 jit, grad, vmap, sharding은 모두 “함수를 다른 함수로 바꾸는” transformation이다. 그래서 학습 코드가 다음 형태에 가까울수록 JAX답다.

old_state, input -> new_state, output

실습 위치

labs/jax-transformer/src/tiny_transformer.py

먼저 TransformerConfig, init_params, forward를 읽는다. JAX 모델의 기준선은 class hierarchy가 아니라 parameter tree shape이다.