RMSNorm Forward Kernel
RMSNorm은 최신 decoder-only LLM에서 자주 쓰이는 normalization이다.
LayerNorm과 달리 mean subtraction을 하지 않는다.
rms = sqrt(mean(x^2) + eps)
y = weight * x / rms
x[row, :]
sum(x^2)
inv_rms = rsqrt(mean + eps)
y = weight * x * inv_rms
shape
입력을 [rows, hidden]으로 본다.
x [rows, hidden]
weight [hidden]
y [rows, hidden]
여기서 rows는 batch와 sequence를 합친 축으로 볼 수 있다.
Naive forward
가장 단순한 버전은 thread 하나가 output 원소 하나를 맡고, row 전체를 loop로 읽어 sum(x^2)를 구한다.
__global__ void rmsnorm_forward_naive_kernel(
const float* x,
const float* weight,
float* y,
int rows,
int hidden,
float eps
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = rows * hidden;
if (idx < total) {
int row = idx / hidden;
int col = idx % hidden;
float ss = 0.0f;
for (int h = 0; h < hidden; h++) {
float v = x[row * hidden + h];
ss += v * v;
}
float inv_rms = rsqrtf(ss / hidden + eps);
y[idx] = weight[col] * x[idx] * inv_rms;
}
}
왜 naive인가
같은 row의 모든 output element가 같은 sum(x^2)를 필요로 한다. 그런데 naive kernel은 각 thread가 이 sum을 다시 계산한다.
맞지만 느림:
correctness 학습에 좋음
나중에 최적화:
row마다 한 번만 reduction하고 재사용
확인
- RMSNorm이 LayerNorm보다 단순한 이유는 무엇인가?
weight[col]은 왜 hidden 축에만 의존하는가?- naive RMSNorm forward에서 어떤 계산이 중복되는가?