RMSNorm Backward Kernel
backward를 하나만 해본다면 RMSNorm이 좋다.
Transformer와 직접 연결됨
LayerNorm보다 단순함
그래도 reduction이 필요함
PyTorch autograd.Function으로 연결하기 좋음
grad_out
saved x, weight, inv_rms
row reduction for grad_x
column reduction for grad_weight
forward 복습
y_i = weight_i * x_i * inv_rms
inv_rms = 1 / sqrt(mean(x^2) + eps)
forward에서 inv_rms를 저장해두면 backward가 쉬워진다.
backward에서 필요한 것
backward는 grad_out을 받아 다음을 계산한다.
grad_x
grad_weight
grad_weight는 hidden column마다 rows 방향으로 누적한다.
grad_weight[col] = sum_row grad_out[row, col] * x[row, col] * inv_rms[row]
grad_x는 elementwise 항과 row-wise reduction 항이 섞인다.
Path 2에서의 수준
이 카드는 선택 심화다. 처음에는 다음을 목표로 한다.
1. forward에서 x, weight, inv_rms를 저장해야 함을 이해
2. backward output이 grad_x와 grad_weight임을 이해
3. grad_weight가 reduction이라는 점을 이해
4. PyTorch custom autograd에 어떻게 연결되는지 이해
최적화된 backward는 Path 3 이후에 다룬다.
확인
- RMSNorm backward에서 forward 때 저장해두면 좋은 값은 무엇인가?
grad_weight는 왜 rows 방향 reduction인가?- backward kernel이 forward보다 어려운 이유는 무엇인가?