JAX Functional Programming Model
마지막 수정:
JAX 트랙의 첫 번째 전환은 “모델 객체를 만든다”가 아니라 계산 가능한 함수와 명시적 state를 만든다는 것이다.
JAX Model Shape Parameters are data. Forward and train steps are pure functions.
params, batch -> loss, grads parameter pytree
+ {"blocks": [...], "token_embedding": ...} pure forward
-> forward(params, input_ids, config) logits
[batch, seq, vocab] PyTorch module owns parameters and behavior
model(input_ids) JAX state is explicit and transformed by functions
jit(value_and_grad(loss_fn)) PyTorch에서는 nn.Module이 parameter와 forward behavior를 함께 가진다.
model = TinyTransformerLM(config)
logits = model(input_ids)
JAX lab에서는 parameter pytree와 forward function을 분리한다.
params = init_params(key, config)
logits = forward(params, input_ids, config)
이 차이는 취향 문제가 아니다. JAX의 jit, grad, vmap, sharding은 모두 “함수를 다른 함수로 바꾸는” transformation이다. 그래서 학습 코드가 다음 형태에 가까울수록 JAX답다.
old_state, input -> new_state, output
실습 위치
labs/jax-transformer/src/tiny_transformer.py
먼저 TransformerConfig, init_params, forward를 읽는다. JAX 모델의 기준선은 class hierarchy가 아니라 parameter tree shape이다.