RMSNorm Backward Kernel

cudarmsnormbackwardautograd

backward를 하나만 해본다면 RMSNorm이 좋다.

Transformer와 직접 연결됨
LayerNorm보다 단순함
그래도 reduction이 필요함
PyTorch autograd.Function으로 연결하기 좋음
grad_out
saved x, weight, inv_rms
row reduction for grad_x
column reduction for grad_weight
RMSNorm backward는 grad_x와 grad_weight를 만들기 위해 저장된 forward 값과 reduction을 사용한다.

forward 복습

y_i = weight_i * x_i * inv_rms
inv_rms = 1 / sqrt(mean(x^2) + eps)

forward에서 inv_rms를 저장해두면 backward가 쉬워진다.

backward에서 필요한 것

backward는 grad_out을 받아 다음을 계산한다.

grad_x
grad_weight

grad_weight는 hidden column마다 rows 방향으로 누적한다.

grad_weight[col] = sum_row grad_out[row, col] * x[row, col] * inv_rms[row]

grad_x는 elementwise 항과 row-wise reduction 항이 섞인다.

Path 2에서의 수준

이 카드는 선택 심화다. 처음에는 다음을 목표로 한다.

1. forward에서 x, weight, inv_rms를 저장해야 함을 이해
2. backward output이 grad_x와 grad_weight임을 이해
3. grad_weight가 reduction이라는 점을 이해
4. PyTorch custom autograd에 어떻게 연결되는지 이해

최적화된 backward는 Path 3 이후에 다룬다.

확인

  • RMSNorm backward에서 forward 때 저장해두면 좋은 값은 무엇인가?
  • grad_weight는 왜 rows 방향 reduction인가?
  • backward kernel이 forward보다 어려운 이유는 무엇인가?