PyTorch JIT Extension으로 RMSNorm 연결하기

cudapytorchextensionrmsnorm

.cu 파일만 실행하는 것과 PyTorch tensor에 연결하는 것은 다른 단계다.

이 카드의 목표는 naive RMSNorm kernel을 PyTorch에서 호출할 수 있게 만드는 것이다.

rmsnorm_kernel.cu
  CUDA kernel

rmsnorm.cpp
  torch.Tensor binding

Python
  torch.utils.cpp_extension.load(...)
rmsnorm_kernel.cu
CUDA kernel
rmsnorm.cpp
C++ binding
load(...)
JIT compile
Python call
torch.Tensor in/out
`torch.utils.cpp_extension.load`는 C++/CUDA 파일을 빌드해 Python module처럼 불러온다.

Colab에서는 load가 편하다

setup.py를 만들 수도 있지만, Colab 실습에서는 JIT loader가 간단하다.

from torch.utils.cpp_extension import load

rmsnorm_ext = load(
    name="rmsnorm_ext",
    sources=["rmsnorm.cpp", "rmsnorm_kernel.cu"],
    extra_cuda_cflags=["-O3"],
    verbose=True,
)

PyTorch가 내부적으로 C++/CUDA 코드를 컴파일하고 Python module처럼 사용할 수 있는 객체를 돌려준다.

C++ binding의 역할

binding은 torch.Tensor에서 raw pointer를 꺼내 CUDA kernel을 launch한다.

torch::Tensor rmsnorm_forward(torch::Tensor x, torch::Tensor weight, double eps) {
    auto y = torch::empty_like(x);

    rmsnorm_forward_kernel<<<grid, block>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        y.data_ptr<float>(),
        rows,
        hidden,
        static_cast<float>(eps)
    );

    return y;
}

Python reference

먼저 PyTorch reference를 만든다.

def rmsnorm_ref(x, weight, eps=1e-6):
    inv_rms = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
    return weight * x * inv_rms

그다음 custom kernel output과 비교한다.

out = rmsnorm_ext.forward(x, weight, 1e-6)
ref = rmsnorm_ref(x, weight, 1e-6)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

autograd.Function

forward와 backward를 모두 연결하려면 torch.autograd.Function으로 감싼다.

class CustomRMSNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, eps):
        y, inv_rms = rmsnorm_ext.forward(x, weight, eps)
        ctx.save_for_backward(x, weight, inv_rms)
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, inv_rms = ctx.saved_tensors
        grad_x, grad_weight = rmsnorm_ext.backward(grad_out, x, weight, inv_rms)
        return grad_x, grad_weight, None

이 path의 산출물

최종 산출물은 다음이다.

Colab에서 naive RMSNorm forward/backward CUDA kernel을 빌드하고,
PyTorch reference와 비교한다.

이제 같은 kernel을 Path 3에서 profile하고 최적화할 수 있다.

확인

  • torch.utils.cpp_extension.load는 무엇을 해주는가?
  • C++ binding에서 data_ptr<float>()가 필요한 이유는 무엇인가?
  • custom backward를 PyTorch autograd에 연결하려면 어떤 class를 사용해야 하는가?