Fake Quantization and STE

마지막 수정:

Draft
설명: 점검 전 위젯: 점검 전 링크: 점검 전
quantizationqatfake-quantizationstetraininggradients

QAT는 보통 진짜 INT weight를 직접 학습하지 않는다.

대신 학습된 FP 모델에 fake quantization을 넣고 fine-tune한다.

stored parameter:
FP weight

forward에서 쓰는 값:
fake-quantized weight

optimizer가 업데이트하는 값:
FP weight

즉 QAT는 “양자화된 모델을 직접 학습”한다기보다, FP 모델을 quantized inference 환경에 적응시키는 과정이다.

Forward quantized value is used
FP weight 0.137
Divide by scale 6.85
Round 7
Dequantize 0.140
Forward compute loss
Backward gradient reaches FP weight
Loss gradient dL/dw_fake
STE through round pretend slope = 1
Update FP weight 0.137 -> 0.136
0.10 0.12 0.14 0.16 FP shadow weight
Fake quantization makes the forward pass feel the integer grid. STE keeps the backward pass useful by letting gradients update the FP shadow weight.

Fake quantization은 양자화된 척하는 forward다

Fake quantization은 forward에서 quantize와 dequantize를 이어 붙인다.

w_fp
-> divide by scale
-> round to integer
-> clamp to integer range
-> multiply by scale
-> w_fake

예를 들어:

w_fp = 0.137
scale = 0.02

0.137 / 0.02 = 6.85
round(6.85) = 7
7 * 0.02 = 0.14

Forward에서는 0.137이 아니라 0.14로 계산한다. 모델은 학습 중에 quantization error를 계속 경험한다.

하지만 weight storage는 아직 FP다.

0.137 -> optimizer update -> 0.1364 -> 0.1358

이 연속적인 FP weight가 있어야 작은 gradient update를 누적할 수 있다.

Round는 gradient를 끊는다

문제는 round()다.

round(0.10) = 0
round(0.20) = 0
round(0.49) = 0
round(0.51) = 1

round()는 계단 함수다. 대부분의 구간에서 출력이 변하지 않기 때문에 gradient가 거의 0이다.

그대로 두면 backpropagation이 막힌다.

loss
-> round()
-> gradient 0
-> weight update 없음

QAT가 학습되려면 이 문제를 우회해야 한다.

STE는 backward에서 하는 실용적인 거짓말이다

STE는 Straight-Through Estimator의 약자다.

아이디어는 단순하다.

forward:
round를 실제로 적용한다

backward:
round가 없었던 것처럼 gradient를 통과시킨다

즉:

forward:  x -> round(x)
backward: d round(x) / dx ~= 1

수학적으로 정확한 미분은 아니다. 하지만 QAT의 목적은 정확한 round 미분을 구하는 것이 아니라, 모델이 quantization noise 속에서도 loss를 줄이는 방향으로 움직이게 만드는 것이다.

Clipped STE는 boundary 밖 gradient를 막는다

기본 STE는 gradient를 전부 통과시킬 수 있다. 하지만 clipping boundary 밖의 값은 이미 saturate되어 있다.

too large -> q_max
too small -> q_min

이 값들을 더 크게 밀어도 quantized output은 바뀌지 않는다. 그래서 clipped STE는 quantizable range 안에서는 gradient를 통과시키고, clipped region에서는 gradient를 0으로 둔다.

inside range:
gradient passes

outside range:
gradient stops

이렇게 하면 optimizer가 boundary에 붙은 값을 더 바깥으로 미는 데 capacity를 낭비하지 않는다.

QAT가 실제로 배우는 것

QAT 중 모델은 이런 압력을 받는다.

loss를 낮춰야 한다
하지만 forward에서는 rounded value만 보인다

그래서 weight는 integer grid에 유리한 위치로 이동한다.

나쁜 위치:
grid boundary 근처, rounding에 민감함

좋은 위치:
grid point 근처, quantize 후에도 값이 안정적임

QAT는 quantization error를 제거하지 않는다. 모델이 그 error를 고려한 weight 배치를 찾게 한다.

마지막에는 진짜 quantized model로 export한다

QAT가 끝나면 더 이상 fake quantization이 필요 없다.

QAT fine-tuned FP weight
-> scale / zero-point 고정
-> 실제 INT8 / INT4 weight 생성
-> runtime이 읽을 quantized checkpoint로 export

정리하면:

학습 중:
FP weight + fake quantized forward

배포 시:
INT weight + scale metadata

확인

  • QAT 중 optimizer가 업데이트하는 weight는 FP인가 INT인가?
  • STE가 필요한 이유는 round()의 어떤 성질 때문인가?
  • Forward에서는 양자화된 척하고 backward에서는 미분 가능한 척한다는 말은 무슨 뜻인가?