Fused Kernels
마지막 수정:
Kernel fusion은 연속된 여러 연산을 하나의 GPU kernel로 합치는 최적화다.
before:
kernel A -> write HBM
kernel B -> read HBM -> write HBM
kernel C -> read HBM -> write HBM
after:
one fused kernel
keep intermediate values local
write final result once
왜 빠른가
GPU에서 느린 일 중 하나는 중간 결과를 HBM에 쓰고 다시 읽는 것이다.
pointwise 연산이 연속되어 있으면 중간 값을 굳이 HBM에 저장할 필요가 없는 경우가 많다.
x
-> multiply
-> add
-> activation
-> dropout
각 원소가 독립적으로 처리된다면, thread가 값을 register에 들고 연속 계산한 뒤 최종 결과만 저장할 수 있다.
load x once
compute several ops locally
store y once
launch overhead도 줄어든다
여러 PyTorch op가 각각 kernel launch를 만들면 CPU host가 GPU에 여러 번 일을 지시해야 한다.
CPU launches kernel 1
CPU launches kernel 2
CPU launches kernel 3
fused kernel은 launch 횟수도 줄인다.
CPU launches one fused kernel
Transformer에서 어디에 쓰나
fused kernel은 pointwise 연산이 이어지는 곳에서 특히 자연스럽다.
LayerNorm / RMSNorm
bias + activation
dropout + residual add
scale + mask + softmax 일부
FlashAttention도 넓은 의미에서는 fused/tiled kernel engineering의 대표 사례다. naive attention은 QK^T, softmax, PV 사이에 큰 intermediate matrix를 HBM에 저장하지만, FlashAttention은 tile 단위로 score를 계산하고 online softmax 상태와 output을 kernel 안에서 갱신한다.
언제 어렵나
모든 연산을 무조건 합치면 좋은 것은 아니다.
register 사용량 증가
kernel 코드 복잡도 증가
재사용성 감소
컴파일 시간 증가
occupancy 감소 가능
그래서 fusion은 HBM traffic과 launch overhead를 줄이는 이득이 resource pressure 증가보다 클 때 좋다.
확인
- fused kernel은 무엇을 줄이는가?
- pointwise op가 fusion에 잘 맞는 이유는 무엇인가?
- fusion이 오히려 나빠질 수 있는 이유는 무엇인가?
- FlashAttention은 어떤 의미에서 fusion과 tiling의 대표 사례인가?