JAX Transformer Forward Pass

마지막 수정:

jaxtransformerattentionrmsnorm

JAX Transformer forward pass는 PyTorch 모델과 같은 tensor algebra를 수행하지만, 구조는 함수 합성이다.

JAX Model Shape Parameters are data. Forward and train steps are pure functions.
params, batch -> loss, grads
parameter pytree {"blocks": [...], "token_embedding": ...}
+
pure forward forward(params, input_ids, config)
->
logits [batch, seq, vocab]
PyTorch module owns parameters and behavior model(input_ids)
JAX state is explicit and transformed by functions jit(value_and_grad(loss_fn))
JAX code becomes easier to compile and shard when mutable training state is made explicit.
def decoder_block(params, x, config):
    x = x + self_attention(params["attn"], rms_norm(params["norm1"], x), config)
    x = x + mlp(params["mlp"], rms_norm(params["norm2"], x))
    return x

이 코드는 self.norm1, self.attn 같은 object field를 참조하지 않는다. 필요한 parameter는 모두 argument로 들어온다.

Attention도 같은 방식이다.

qkv = x @ params["qkv"]
q, k, v = jnp.split(qkv, 3, axis=-1)
scores = q @ swap(k) * scale
probs = jax.nn.softmax(masked_scores, axis=-1)
out = probs @ v

JAX에서 중요한 것은 shape가 compile contract가 된다는 점이다. 같은 forward라도 batch/sequence/head shape가 바뀌면 새로운 compiled executable이 필요할 수 있다.

실습 위치

python3 labs/jax-transformer/train.py --steps 10

JAX 설치 환경에서는 이 명령으로 tiny Transformer가 next-token task를 학습하는지 확인한다.