torch.compile에서 Triton과 CUDA까지

마지막 수정:

cudapytorchtorch-compiletritonkernel

새 연산을 빠르게 만들고 싶다고 해서 바로 CUDA kernel부터 작성할 필요는 없다.

실전에서는 보통 다음 사다리로 내려간다.

PyTorch
-> torch.compile
-> Triton
-> CUDA

PyTorch

PyTorch 코드는 가장 쓰기 쉽다.

def elu(x):
    return torch.where(x < 0, torch.exp(x) - 1, x)

하지만 여러 작은 op가 따로 실행되면 kernel launch와 HBM read/write가 많아질 수 있다.

compare
exp
subtract
where

각 op가 별도 kernel로 실행되면 중간 결과가 HBM을 오갈 수 있다.

torch.compile

torch.compile은 PyTorch 연산 그래프를 잡아서 더 낮은 수준의 kernel로 바꿔준다.

@torch.compile
def elu(x):
    return torch.where(x < 0, torch.exp(x) - 1, x)

장점은 쉽다는 것이다.

easy
often fast
minimal code change

한계는 원하는 low-level scheduling을 직접 다 제어하기 어렵다는 점이다.

generated Triton 보기

torch.compile이 만든 코드를 보고 싶으면 TORCH_LOGS를 켤 수 있다.

export TORCH_LOGS="output_code"

이렇게 하면 생성된 Triton kernel을 참고할 수 있다. 처음부터 Triton을 직접 쓰기 어렵다면, generated kernel을 출발점으로 삼을 수 있다.

Triton

Triton은 CUDA보다 Python에 가깝게 block 단위 kernel을 쓸 수 있게 해준다.

harder than PyTorch
more flexible than torch.compile
easier than raw CUDA

Triton에서는 program id, block indices, mask, load/store 같은 개념을 직접 다룬다.

program_id -> 이 program이 맡을 block
mask       -> boundary check
tl.load    -> memory load
tl.store   -> memory store

CUDA

CUDA는 가장 어렵지만 가장 많은 제어권을 준다.

thread/block shape
shared memory
warp-level primitive
register pressure
SM resource usage

Triton으로 충분히 빠르지 않거나, shared memory와 warp scheduling을 더 세밀하게 제어해야 할 때 CUDA로 내려간다.

선택 기준

먼저 PyTorch로 reference를 만든다
torch.compile로 쉽게 빨라지는지 본다
부족하면 generated Triton을 읽고 Triton으로 실험한다
그래도 부족하면 CUDA로 직접 제어한다

확인

  • 왜 바로 CUDA부터 시작하지 않는가?
  • torch.compile은 어떤 종류의 최적화를 자동화해주는가?
  • Triton은 PyTorch와 CUDA 사이에서 어떤 위치인가?
  • CUDA로 내려가야 하는 신호는 무엇인가?