JAX vs PyTorch Comparison Report
마지막 수정:
JAX 트랙의 최종 산출물은 “JAX 코드도 있다”가 아니다. 같은 tiny Transformer를 두 시스템이 어떻게 다르게 표현하는지 설명할 수 있어야 한다.
JAX Model Shape Parameters are data. Forward and train steps are pure functions.
params, batch -> loss, grads parameter pytree
+ {"blocks": [...], "token_embedding": ...} pure forward
-> forward(params, input_ids, config) logits
[batch, seq, vocab] PyTorch module owns parameters and behavior
model(input_ids) JAX state is explicit and transformed by functions
jit(value_and_grad(loss_fn)) 비교 기준은 다음이다.
Model state:
PyTorch nn.Module parameters vs JAX parameter pytree
Randomness:
implicit/global-ish RNG usage vs explicit PRNG key split
Training step:
loss.backward/optimizer.step vs value_and_grad + explicit optimizer state
Compilation:
eager/torch.compile vs jit/XLA-first execution
Profiling:
CUDA op timeline vs compile/execution/block_until_ready separation
Distributed:
DDP wrapper/FSDP wrapper vs mesh/sharding/function transformation
이 비교를 통해 얻어야 할 결론은 우열이 아니다. 둘은 시스템을 보는 초점이 다르다.
PyTorch:
imperative debugging, library ecosystem, model authoring ergonomics가 강함
JAX:
program transformation, compiler-first execution, explicit sharding mental model이 강함
frontier-scale training stack을 읽을 때 JAX 경험이 중요한 이유는, 거대한 학습 시스템이 결국 “array program을 어떻게 transform하고 shard하고 compile할 것인가”의 문제로 자주 환원되기 때문이다.