PyTorch JIT Extension으로 RMSNorm 연결하기
.cu 파일만 실행하는 것과 PyTorch tensor에 연결하는 것은 다른 단계다.
이 카드의 목표는 naive RMSNorm kernel을 PyTorch에서 호출할 수 있게 만드는 것이다.
rmsnorm_kernel.cu
CUDA kernel
rmsnorm.cpp
torch.Tensor binding
Python
torch.utils.cpp_extension.load(...)
rmsnorm_kernel.cu
CUDA kernel
CUDA kernel
rmsnorm.cpp
C++ binding
C++ binding
load(...)
JIT compile
JIT compile
Python call
torch.Tensor in/out
torch.Tensor in/out
Colab에서는 load가 편하다
setup.py를 만들 수도 있지만, Colab 실습에서는 JIT loader가 간단하다.
from torch.utils.cpp_extension import load
rmsnorm_ext = load(
name="rmsnorm_ext",
sources=["rmsnorm.cpp", "rmsnorm_kernel.cu"],
extra_cuda_cflags=["-O3"],
verbose=True,
)
PyTorch가 내부적으로 C++/CUDA 코드를 컴파일하고 Python module처럼 사용할 수 있는 객체를 돌려준다.
C++ binding의 역할
binding은 torch.Tensor에서 raw pointer를 꺼내 CUDA kernel을 launch한다.
torch::Tensor rmsnorm_forward(torch::Tensor x, torch::Tensor weight, double eps) {
auto y = torch::empty_like(x);
rmsnorm_forward_kernel<<<grid, block>>>(
x.data_ptr<float>(),
weight.data_ptr<float>(),
y.data_ptr<float>(),
rows,
hidden,
static_cast<float>(eps)
);
return y;
}
Python reference
먼저 PyTorch reference를 만든다.
def rmsnorm_ref(x, weight, eps=1e-6):
inv_rms = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
return weight * x * inv_rms
그다음 custom kernel output과 비교한다.
out = rmsnorm_ext.forward(x, weight, 1e-6)
ref = rmsnorm_ref(x, weight, 1e-6)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
autograd.Function
forward와 backward를 모두 연결하려면 torch.autograd.Function으로 감싼다.
class CustomRMSNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, eps):
y, inv_rms = rmsnorm_ext.forward(x, weight, eps)
ctx.save_for_backward(x, weight, inv_rms)
ctx.eps = eps
return y
@staticmethod
def backward(ctx, grad_out):
x, weight, inv_rms = ctx.saved_tensors
grad_x, grad_weight = rmsnorm_ext.backward(grad_out, x, weight, inv_rms)
return grad_x, grad_weight, None
이 path의 산출물
최종 산출물은 다음이다.
Colab에서 naive RMSNorm forward/backward CUDA kernel을 빌드하고,
PyTorch reference와 비교한다.
이제 같은 kernel을 Path 3에서 profile하고 최적화할 수 있다.
확인
torch.utils.cpp_extension.load는 무엇을 해주는가?- C++ binding에서
data_ptr<float>()가 필요한 이유는 무엇인가? - custom backward를 PyTorch autograd에 연결하려면 어떤 class를 사용해야 하는가?