Optimized RMSNorm Benchmark
Path 3의 마무리는 RMSNorm이다. Path 2에서는 정확한 forward/backward를 만드는 것이 목표였고, 여기서는 reduction을 최적화한다.
x[row, :]
sum(x^2)
inv_rms = rsqrt(mean + eps)
y = weight * x * inv_rms
비교 대상
1. torch implementation
2. naive CUDA RMSNorm
3. block reduction RMSNorm
4. warp shuffle RMSNorm
모든 버전은 같은 input shape, 같은 dtype, 같은 tolerance로 비교한다.
왜 RMSNorm인가
RMSNorm은 Transformer에서 실제로 쓰이고, softmax보다 작게 시작하기 좋다. row-wise reduction과 elementwise scale이 함께 있어서 CUDA 최적화 기본기를 연결하기 좋다.
확인
- 최적화된 RMSNorm은 어떤 값을 줄이는가?
- forward와 backward 중 어느 쪽이 더 복잡한가?
- PyTorch extension benchmark에서 correctness와 speed를 왜 둘 다 봐야 하는가?