JAX Transformer 모델 직접 작성하기

이 경로는 PyTorch Transformer 모델을 JAX 방식으로 다시 작성하는 모델 구현 경로입니다.

목표는 PyTorch 코드를 기계적으로 옮기는 것이 아닙니다. JAX에서는 모델을 nn.Module 객체로 소유하기보다, parameter를 pytree data로 들고 pure function이 그 state를 받아 계산합니다.

PyTorch:
model = TinyTransformerLM(config)
logits = model(input_ids)

JAX:
params = init_params(key, config)
logits = forward(params, input_ids, config)

이 path가 끝나면 JAX에서 Transformer를 “상태를 명시적으로 받는 함수들의 합성”으로 설명할 수 있어야 합니다.

구현 순서는 PyTorch path와 같은 기능 단위를 따릅니다. 다만 각 단계에서 “객체가 state를 가진다”가 아니라 “array program이 transform된다”는 관점을 유지합니다.

array / transform 기본기
  -> parameter pytree / PRNG
  -> token embedding / tied LM head
  -> RMSNorm / residual
  -> causal self-attention
  -> MLP / decoder block
  -> pure forward function
  1. JAX Array와 Transform 기본기 — Transformer를 작성하기 전에 JAX array, device placement, transformable function이라는 기본 관점을 잡는다.
  2. JAX Functional Programming Model — JAX 모델 구현의 출발점인 pure function, explicit state, transformable program 구조를 PyTorch nn.Module 방식과 비교한다.
  3. JAX PRNG and Parameter Pytrees — JAX의 explicit PRNG key와 parameter pytree가 모델 초기화, reproducibility, jit-friendly state 표현에 어떤 의미를 갖는지 이해한다.
  4. JAX Token Embedding과 LM Head — JAX parameter pytree에서 token embedding을 꺼내고, tied LM head를 pure function으로 계산한다.
  5. JAX RMSNorm과 Residual Path — RMSNorm과 residual connection을 JAX pure function으로 작성하고 XLA fusion 관점에서 본다.
  6. JAX Causal Self-Attention 구현 — QKV projection, causal mask, softmax, attention output을 JAX array operation으로 직접 작성한다.
  7. JAX MLP와 Decoder Block 조립 — attention, RMSNorm, MLP를 JAX 함수 합성으로 묶어 decoder block을 만든다.
  8. JAX Transformer Forward Pass — RMSNorm, causal attention, MLP, decoder block, tied LM head를 JAX 함수 합성으로 구현하는 방식을 이해한다.