torch.compile에서 Triton과 CUDA까지
마지막 수정:
새 연산을 빠르게 만들고 싶다고 해서 바로 CUDA kernel부터 작성할 필요는 없다.
실전에서는 보통 다음 사다리로 내려간다.
PyTorch
-> torch.compile
-> Triton
-> CUDA
PyTorch
PyTorch 코드는 가장 쓰기 쉽다.
def elu(x):
return torch.where(x < 0, torch.exp(x) - 1, x)
하지만 여러 작은 op가 따로 실행되면 kernel launch와 HBM read/write가 많아질 수 있다.
compare
exp
subtract
where
각 op가 별도 kernel로 실행되면 중간 결과가 HBM을 오갈 수 있다.
torch.compile
torch.compile은 PyTorch 연산 그래프를 잡아서 더 낮은 수준의 kernel로 바꿔준다.
@torch.compile
def elu(x):
return torch.where(x < 0, torch.exp(x) - 1, x)
장점은 쉽다는 것이다.
easy
often fast
minimal code change
한계는 원하는 low-level scheduling을 직접 다 제어하기 어렵다는 점이다.
generated Triton 보기
torch.compile이 만든 코드를 보고 싶으면 TORCH_LOGS를 켤 수 있다.
export TORCH_LOGS="output_code"
이렇게 하면 생성된 Triton kernel을 참고할 수 있다. 처음부터 Triton을 직접 쓰기 어렵다면, generated kernel을 출발점으로 삼을 수 있다.
Triton
Triton은 CUDA보다 Python에 가깝게 block 단위 kernel을 쓸 수 있게 해준다.
harder than PyTorch
more flexible than torch.compile
easier than raw CUDA
Triton에서는 program id, block indices, mask, load/store 같은 개념을 직접 다룬다.
program_id -> 이 program이 맡을 block
mask -> boundary check
tl.load -> memory load
tl.store -> memory store
CUDA
CUDA는 가장 어렵지만 가장 많은 제어권을 준다.
thread/block shape
shared memory
warp-level primitive
register pressure
SM resource usage
Triton으로 충분히 빠르지 않거나, shared memory와 warp scheduling을 더 세밀하게 제어해야 할 때 CUDA로 내려간다.
선택 기준
먼저 PyTorch로 reference를 만든다
torch.compile로 쉽게 빨라지는지 본다
부족하면 generated Triton을 읽고 Triton으로 실험한다
그래도 부족하면 CUDA로 직접 제어한다
확인
- 왜 바로 CUDA부터 시작하지 않는가?
torch.compile은 어떤 종류의 최적화를 자동화해주는가?- Triton은 PyTorch와 CUDA 사이에서 어떤 위치인가?
- CUDA로 내려가야 하는 신호는 무엇인가?