JAX Array와 Transform 기본기

마지막 수정:

jaxarraytransformtransformer

JAX를 처음 볼 때 가장 중요한 전환은 “딥러닝 프레임워크”보다 array program transformation 도구로 보는 것이다.

PyTorch에서 먼저 확인하던 습관은 JAX에서도 그대로 유지한다.

shape
dtype
device placement

하지만 JAX에서는 여기에 하나가 더 붙는다.

이 함수가 transformation 가능한가?

JAX의 핵심 API는 대부분 함수를 다른 함수로 바꾼다.

compiled = jax.jit(fn)
grads = jax.grad(loss_fn)
batched = jax.vmap(single_example_fn)

그래서 JAX Transformer 구현은 class hierarchy를 먼저 만드는 일이 아니다. 먼저 순수한 array 계산을 만들고, 그 계산을 jit, grad, vmap, sharding으로 변환할 수 있게 유지하는 일이다.

input array
  -> pure function
  -> output array
  -> transform: jit / grad / shard

확인

  • JAX에서 jit, grad, vmap은 공통적으로 무엇을 입력으로 받는가?
  • PyTorch의 nn.Module 중심 사고와 JAX의 function transformation 사고는 어디서 갈라지는가?
  • Transformer 구현에서 shape와 dtype 외에 compile boundary를 같이 봐야 하는 이유는 무엇인가?