Fake Quantization and STE
마지막 수정:
QAT는 보통 진짜 INT weight를 직접 학습하지 않는다.
대신 학습된 FP 모델에 fake quantization을 넣고 fine-tune한다.
stored parameter:
FP weight
forward에서 쓰는 값:
fake-quantized weight
optimizer가 업데이트하는 값:
FP weight
즉 QAT는 “양자화된 모델을 직접 학습”한다기보다, FP 모델을 quantized inference 환경에 적응시키는 과정이다.
0.137 6.85 7 0.140 loss dL/dw_fake pretend slope = 1 0.137 -> 0.136 Fake quantization은 양자화된 척하는 forward다
Fake quantization은 forward에서 quantize와 dequantize를 이어 붙인다.
w_fp
-> divide by scale
-> round to integer
-> clamp to integer range
-> multiply by scale
-> w_fake
예를 들어:
w_fp = 0.137
scale = 0.02
0.137 / 0.02 = 6.85
round(6.85) = 7
7 * 0.02 = 0.14
Forward에서는 0.137이 아니라 0.14로 계산한다. 모델은 학습 중에 quantization error를 계속 경험한다.
하지만 weight storage는 아직 FP다.
0.137 -> optimizer update -> 0.1364 -> 0.1358
이 연속적인 FP weight가 있어야 작은 gradient update를 누적할 수 있다.
Round는 gradient를 끊는다
문제는 round()다.
round(0.10) = 0
round(0.20) = 0
round(0.49) = 0
round(0.51) = 1
round()는 계단 함수다. 대부분의 구간에서 출력이 변하지 않기 때문에 gradient가 거의 0이다.
그대로 두면 backpropagation이 막힌다.
loss
-> round()
-> gradient 0
-> weight update 없음
QAT가 학습되려면 이 문제를 우회해야 한다.
STE는 backward에서 하는 실용적인 거짓말이다
STE는 Straight-Through Estimator의 약자다.
아이디어는 단순하다.
forward:
round를 실제로 적용한다
backward:
round가 없었던 것처럼 gradient를 통과시킨다
즉:
forward: x -> round(x)
backward: d round(x) / dx ~= 1
수학적으로 정확한 미분은 아니다. 하지만 QAT의 목적은 정확한 round 미분을 구하는 것이 아니라, 모델이 quantization noise 속에서도 loss를 줄이는 방향으로 움직이게 만드는 것이다.
Clipped STE는 boundary 밖 gradient를 막는다
기본 STE는 gradient를 전부 통과시킬 수 있다. 하지만 clipping boundary 밖의 값은 이미 saturate되어 있다.
too large -> q_max
too small -> q_min
이 값들을 더 크게 밀어도 quantized output은 바뀌지 않는다. 그래서 clipped STE는 quantizable range 안에서는 gradient를 통과시키고, clipped region에서는 gradient를 0으로 둔다.
inside range:
gradient passes
outside range:
gradient stops
이렇게 하면 optimizer가 boundary에 붙은 값을 더 바깥으로 미는 데 capacity를 낭비하지 않는다.
QAT가 실제로 배우는 것
QAT 중 모델은 이런 압력을 받는다.
loss를 낮춰야 한다
하지만 forward에서는 rounded value만 보인다
그래서 weight는 integer grid에 유리한 위치로 이동한다.
나쁜 위치:
grid boundary 근처, rounding에 민감함
좋은 위치:
grid point 근처, quantize 후에도 값이 안정적임
QAT는 quantization error를 제거하지 않는다. 모델이 그 error를 고려한 weight 배치를 찾게 한다.
마지막에는 진짜 quantized model로 export한다
QAT가 끝나면 더 이상 fake quantization이 필요 없다.
QAT fine-tuned FP weight
-> scale / zero-point 고정
-> 실제 INT8 / INT4 weight 생성
-> runtime이 읽을 quantized checkpoint로 export
정리하면:
학습 중:
FP weight + fake quantized forward
배포 시:
INT weight + scale metadata
확인
- QAT 중 optimizer가 업데이트하는 weight는 FP인가 INT인가?
- STE가 필요한 이유는
round()의 어떤 성질 때문인가? - Forward에서는 양자화된 척하고 backward에서는 미분 가능한 척한다는 말은 무슨 뜻인가?