JAX vs PyTorch Comparison Report

마지막 수정:

jaxpytorchcomparisontraining-systems

JAX 트랙의 최종 산출물은 “JAX 코드도 있다”가 아니다. 같은 tiny Transformer를 두 시스템이 어떻게 다르게 표현하는지 설명할 수 있어야 한다.

JAX Model Shape Parameters are data. Forward and train steps are pure functions.
params, batch -> loss, grads
parameter pytree {"blocks": [...], "token_embedding": ...}
+
pure forward forward(params, input_ids, config)
->
logits [batch, seq, vocab]
PyTorch module owns parameters and behavior model(input_ids)
JAX state is explicit and transformed by functions jit(value_and_grad(loss_fn))
JAX code becomes easier to compile and shard when mutable training state is made explicit.

비교 기준은 다음이다.

Model state:
PyTorch nn.Module parameters vs JAX parameter pytree

Randomness:
implicit/global-ish RNG usage vs explicit PRNG key split

Training step:
loss.backward/optimizer.step vs value_and_grad + explicit optimizer state

Compilation:
eager/torch.compile vs jit/XLA-first execution

Profiling:
CUDA op timeline vs compile/execution/block_until_ready separation

Distributed:
DDP wrapper/FSDP wrapper vs mesh/sharding/function transformation

이 비교를 통해 얻어야 할 결론은 우열이 아니다. 둘은 시스템을 보는 초점이 다르다.

PyTorch:
imperative debugging, library ecosystem, model authoring ergonomics가 강함

JAX:
program transformation, compiler-first execution, explicit sharding mental model이 강함

frontier-scale training stack을 읽을 때 JAX 경험이 중요한 이유는, 거대한 학습 시스템이 결국 “array program을 어떻게 transform하고 shard하고 compile할 것인가”의 문제로 자주 환원되기 때문이다.