Quantized Matmul and Arithmetic Intensity
마지막 수정:
Quantization을 적용했다고 해서 항상 낮은 precision으로 계산되는 것은 아니다.
quantized storage != quantized execution
이 차이를 모르면 “모델은 4-bit인데 왜 안 빨라지지?”라는 혼란이 생긴다.
Store small, compute high
- Storage
- INT4 / INT8 weights
- Execution
- dequantize then FP matmul
- Main gain
- memory footprint
extra dequant overhead
Fused weight-only kernel
- Storage
- packed INT4 weights
- Execution
- load packed weights, scale inside kernel
- Main gain
- bandwidth + decode latency
kernel support required
W8A8 / FP8 kernel
- Storage
- low-bit W and A
- Execution
- low-precision GEMM, high-precision accumulate
- Main gain
- throughput on compute-bound work
activation scaling is hard
Quantized tensor는 의미를 scale과 함께 가진다
INT4나 INT8 값은 그 자체로 real value가 아니다. Scale과 zero-point가 있어야 해석된다.
r ~= S(q - Z)
Matmul에서는 이 해석이 dot product 안으로 들어간다.
y = sum(w_i x_i)
두 operand가 모두 quantized라면 원칙적으로는 다음 구조가 된다.
w_i ~= S_w(q_wi - Z_w)
x_i ~= S_x(q_xi - Z_x)
y ~= S_w S_x sum((q_wi - Z_w)(q_xi - Z_x))
이때 weights를 symmetric으로 두면 Z_w = 0이 되어 correction term이 줄어든다. Activations는 ReLU, GeLU, LayerNorm 이후처럼 shifted distribution이 많아서 asymmetric이나 dynamic scale이 필요할 수 있다.
Weight-only는 보통 dequantization을 피할 수 없다
W4A16에서는 weight만 INT4다.
weights: INT4
activations: FP16/BF16
가장 단순한 실행은 다음과 같다.
packed INT4 weights
-> unpack
-> dequantize to FP16/BF16
-> FP matmul with FP16/BF16 activations
이 경우 weight storage와 memory bandwidth는 줄지만 compute 자체는 high precision 경로다. 심하면 dequantize overhead 때문에 FP baseline보다 느릴 수도 있다.
그래서 weight-only quantization의 성능은 runtime kernel에 크게 의존한다.
좋은 kernel은 dequantize를 matmul 안에 fuse한다
생산용 weight-only kernel은 보통 이렇게 움직인다.
1. packed INT4 weight를 연속 load
2. group scale을 load
3. tile 안에서 unpack / scale 적용
4. FP16 activation과 multiply
5. accumulate
핵심은 intermediate FP16 weight matrix를 HBM에 다시 쓰지 않는 것이다. Dequantized weight를 따로 materialize하면 memory traffic이 다시 커진다.
Group quantization이 input dimension을 따라 자주 쓰이는 이유도 여기에 있다.
contiguous G weights
-> one scale
-> vectorized load
-> partial sum에 scale 적용
즉 granularity는 정확도 문제이면서 kernel layout 문제다.
W8A8은 integer 또는 FP8 GEMM에 가까워진다
Activation까지 낮은 precision이면 더 직접적인 low-precision GEMM이 가능하다.
INT8 weights x INT8 activations
-> INT32 accumulate
-> rescale to output dtype
또는 FP8에서는 FP8 tensor core path를 쓸 수 있다.
FP8 weights x FP8 activations
-> higher precision accumulate
-> output scale / cast
이때 compute throughput 자체가 올라갈 수 있다. 그래서 W8A8/FP8은 prefill이나 high-batch decode처럼 compute-bound에 가까운 workload에서 강하다.
하지만 activation scale을 runtime에 계산해야 하거나, hardware가 해당 format을 native로 지원하지 않으면 기대한 이득이 사라진다.
Arithmetic intensity는 두 방향으로 움직인다
Arithmetic intensity는 대략 다음 비율이다.
arithmetic intensity = FLOPs / bytes moved
Quantization은 bytes moved를 줄인다. 따라서 같은 matmul이 같은 양의 연산을 한다면 arithmetic intensity는 올라간다.
FP16 weight load: 2 bytes/value
INT4 weight load: 0.5 bytes/value
-> bytes down
-> FLOPs per byte up
Decode는 보통 memory bandwidth-bound다. 매 token마다 거대한 weight를 읽지만 batch가 작으면 weight reuse가 낮다. Weight-only quantization은 이 상황에서 강하다.
반면 prefill은 많은 input token을 한 번에 처리하므로 matmul reuse가 크고 compute-bound에 가까워진다. 이때는 bytes를 줄이는 것만으로는 부족하고, activation까지 낮춰 실제 compute throughput을 높이는 W8A8/FP8이 더 중요해질 수 있다.
Dequantization overhead도 bytes와 ops다
Quantization은 공짜가 아니다.
scale load
zero-point correction
INT4 unpack
dequant multiply
output rescale
이 비용이 줄어든 memory traffic보다 크면 speedup이 없다.
그래서 실제 질문은 이것이다.
줄인 HBM traffic > 추가 dequant/metadata/kernel overhead ?
이 부등식이 성립할 때 quantization이 빨라진다.
Speedup을 검증하는 법
Quantized model을 만들었다면 정확도만 보면 안 된다.
1. model memory가 줄었는가?
2. operator가 실제 INT8/INT4/FP8 kernel로 실행되는가?
3. dequantized tensor를 HBM에 materialize하지 않는가?
4. prefill과 decode를 따로 benchmark했는가?
5. batch/concurrency가 바뀌면 bottleneck도 바뀌는가?
특히 runtime이 unsupported quantization scheme을 조용히 FP16/FP32로 upcast할 수 있다. 이 경우 checkpoint는 quantized처럼 보여도 execution은 그렇지 않다.
확인
- Quantized storage와 quantized execution은 어떻게 다른가?
- Weight-only quantization이 memory bandwidth에는 이득이 있지만 compute speedup은 제한적인 이유는 무엇인가?
- Quantization이 arithmetic intensity를 높이는 조건과, dequantization overhead 때문에 이득이 사라지는 조건은 무엇인가?