JAX Array와 Transform 기본기
마지막 수정:
JAX를 처음 볼 때 가장 중요한 전환은 “딥러닝 프레임워크”보다 array program transformation 도구로 보는 것이다.
PyTorch에서 먼저 확인하던 습관은 JAX에서도 그대로 유지한다.
shape
dtype
device placement
하지만 JAX에서는 여기에 하나가 더 붙는다.
이 함수가 transformation 가능한가?
JAX의 핵심 API는 대부분 함수를 다른 함수로 바꾼다.
compiled = jax.jit(fn)
grads = jax.grad(loss_fn)
batched = jax.vmap(single_example_fn)
그래서 JAX Transformer 구현은 class hierarchy를 먼저 만드는 일이 아니다. 먼저 순수한 array 계산을 만들고, 그 계산을 jit, grad, vmap, sharding으로 변환할 수 있게 유지하는 일이다.
input array
-> pure function
-> output array
-> transform: jit / grad / shard
확인
- JAX에서
jit,grad,vmap은 공통적으로 무엇을 입력으로 받는가? - PyTorch의
nn.Module중심 사고와 JAX의 function transformation 사고는 어디서 갈라지는가? - Transformer 구현에서 shape와 dtype 외에 compile boundary를 같이 봐야 하는 이유는 무엇인가?