RMSNorm Forward Kernel

cudarmsnormtransformernormalization

RMSNorm은 최신 decoder-only LLM에서 자주 쓰이는 normalization이다.

LayerNorm과 달리 mean subtraction을 하지 않는다.

rms = sqrt(mean(x^2) + eps)
y = weight * x / rms
x[row, :]
sum(x^2)
inv_rms = rsqrt(mean + eps)
y = weight * x * inv_rms
RMSNorm forward는 row마다 sum(x^2)를 구한 뒤 normalize와 scale을 적용한다.

shape

입력을 [rows, hidden]으로 본다.

x      [rows, hidden]
weight [hidden]
y      [rows, hidden]

여기서 rows는 batch와 sequence를 합친 축으로 볼 수 있다.

Naive forward

가장 단순한 버전은 thread 하나가 output 원소 하나를 맡고, row 전체를 loop로 읽어 sum(x^2)를 구한다.

__global__ void rmsnorm_forward_naive_kernel(
    const float* x,
    const float* weight,
    float* y,
    int rows,
    int hidden,
    float eps
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = rows * hidden;

    if (idx < total) {
        int row = idx / hidden;
        int col = idx % hidden;

        float ss = 0.0f;
        for (int h = 0; h < hidden; h++) {
            float v = x[row * hidden + h];
            ss += v * v;
        }

        float inv_rms = rsqrtf(ss / hidden + eps);
        y[idx] = weight[col] * x[idx] * inv_rms;
    }
}

왜 naive인가

같은 row의 모든 output element가 같은 sum(x^2)를 필요로 한다. 그런데 naive kernel은 각 thread가 이 sum을 다시 계산한다.

맞지만 느림:
  correctness 학습에 좋음

나중에 최적화:
  row마다 한 번만 reduction하고 재사용

확인

  • RMSNorm이 LayerNorm보다 단순한 이유는 무엇인가?
  • weight[col]은 왜 hidden 축에만 의존하는가?
  • naive RMSNorm forward에서 어떤 계산이 중복되는가?