JAX Transformer Forward Pass
마지막 수정:
JAX Transformer forward pass는 PyTorch 모델과 같은 tensor algebra를 수행하지만, 구조는 함수 합성이다.
JAX Model Shape Parameters are data. Forward and train steps are pure functions.
params, batch -> loss, grads parameter pytree
+ {"blocks": [...], "token_embedding": ...} pure forward
-> forward(params, input_ids, config) logits
[batch, seq, vocab] PyTorch module owns parameters and behavior
model(input_ids) JAX state is explicit and transformed by functions
jit(value_and_grad(loss_fn)) def decoder_block(params, x, config):
x = x + self_attention(params["attn"], rms_norm(params["norm1"], x), config)
x = x + mlp(params["mlp"], rms_norm(params["norm2"], x))
return x
이 코드는 self.norm1, self.attn 같은 object field를 참조하지 않는다. 필요한 parameter는 모두 argument로 들어온다.
Attention도 같은 방식이다.
qkv = x @ params["qkv"]
q, k, v = jnp.split(qkv, 3, axis=-1)
scores = q @ swap(k) * scale
probs = jax.nn.softmax(masked_scores, axis=-1)
out = probs @ v
JAX에서 중요한 것은 shape가 compile contract가 된다는 점이다. 같은 forward라도 batch/sequence/head shape가 바뀌면 새로운 compiled executable이 필요할 수 있다.
실습 위치
python3 labs/jax-transformer/train.py --steps 10
JAX 설치 환경에서는 이 명령으로 tiny Transformer가 next-token task를 학습하는지 확인한다.