JAX Gradient Accumulation
마지막 수정:
Gradient accumulation의 목표는 global batch를 키우되 한 번에 올리는 activation memory를 줄이는 것이다.
microbatch 1 -> grads
microbatch 2 -> grads
...
mean grads -> optimizer update
PyTorch에서는 여러 번 loss.backward()를 호출한 뒤 마지막에 optimizer.step()을 한다. JAX에서는 gradient도 pytree이므로 명시적으로 더하고 평균낸다.
grad_sum = tree_zeros_like(params)
grad_sum = tree_map(lambda a, b: a + b, grad_sum, grads)
grads = tree_map(lambda g: g / accum_steps, grad_sum)
JAX답게 구현하려면 Python loop로 먼저 이해하고, 이후 jax.lax.scan으로 compile 가능한 반복으로 바꿀 수 있다.
먼저 명확한 Python loop
-> shape와 state 흐름 확인
-> lax.scan으로 compile boundary 안으로 이동
확인
- gradient accumulation은 memory와 batch size 중 무엇을 trade-off하는가?
- JAX에서 gradient 누적이 pytree 연산으로 자연스럽게 표현되는 이유는 무엇인가?
lax.scan으로 옮기기 전에 Python loop로 먼저 검증하는 이유는 무엇인가?