Temperature와 KL Distillation Loss
마지막 수정:
soft distillation의 핵심은 teacher와 student의 확률분포를 맞추는 것이다.
하지만 teacher 분포를 그대로 쓰면 문제가 생긴다. 강한 teacher는 대개 top answer에 확률을 몰아준다. 그러면 dog, fox 같은 작은 확률에 들어 있던 dark knowledge가 거의 보이지 않는다.
그래서 temperature를 쓴다.
softmax(logits / T)
T가 커질수록 분포는 부드러워진다.
T = 1
too sharpdark knowledge is mostly hidden
T = 3
useful softnesswrong-answer ranking is visible
T = 10
too flatsignal is close to uniform
q = softmax(z_teacher / T) p = softmax(z_student / T) T^2 * KL(q || p) Temperature가 하는 일
logit은 softmax에 들어가기 전의 raw score다.
cat: 5.0
dog: 3.0
fox: 1.0
bird: -1.0
T=1이면 일반 softmax다. teacher는 cat에 대부분의 확률을 준다. 이 경우 student가 보는 신호는 hard label과 크게 다르지 않다.
T=3처럼 온도를 높이면 logit 차이가 줄어든다.
cat: 5.0 / 3
dog: 3.0 / 3
fox: 1.0 / 3
bird: -1.0 / 3
이제 dog > fox > bird라는 오답 사이의 순서가 더 잘 보인다. student는 정답만 배우는 것이 아니라 teacher가 생각한 class similarity를 배운다.
하지만 T를 너무 크게 하면 분포가 거의 uniform이 된다. 그러면 모든 후보가 비슷해 보이고, 오히려 signal이 흐려진다.
too cold: 정답만 보인다.
just right: 관계가 보인다.
too hot: 구분이 흐려진다.
KL은 무엇을 맞추나
KL divergence는 두 분포가 얼마나 다른지 재는 방식이다. distillation에서는 teacher 분포를 target으로 두고 student 분포가 그쪽으로 가까워지게 한다.
teacher_soft = softmax(teacher_logits / T)
student_soft = softmax(student_logits / T)
soft_loss = KL(teacher_soft || student_soft)
직관적으로는 이렇다.
teacher: cat 52%, dog 27%, fox 14%, bird 7%
student: cat 44%, dog 27%, fox 19%, bird 10%
KL loss:
student가 teacher의 분포 모양을 더 닮도록 밀어준다.
student가 정답 class만 높이면 충분하지 않다. teacher가 dog와 fox에 남긴 상대적 확률도 맞춰야 한다.
왜 T²를 곱하나
temperature를 높이면 분포가 부드러워져서 dark knowledge는 잘 보인다. 그런데 동시에 gradient는 약해진다.
대략적으로 T가 커지면 soft-target loss의 gradient 크기가 1 / T²만큼 작아진다. 그러면 T=3에서 얻은 좋은 분포가 학습에는 너무 약하게 작용할 수 있다.
그래서 고전적인 distillation loss는 soft loss에 T²를 곱한다.
L_soft = T^2 * KL(
softmax(teacher_logits / T),
softmax(student_logits / T)
)
의미는 간단하다.
T는 정보의 모양을 바꾼다.
T^2는 그 정보가 너무 약해지지 않게 보정한다.
전체 loss에서의 위치
실전에서는 hard loss와 soft loss를 섞는 경우가 많다.
L = alpha * L_hard + beta * L_soft
여기서:
L_hard: 정답 또는 target token을 맞추는 cross-entropy
L_soft: teacher 분포를 맞추는 KL loss
hard loss는 grounding을 맡고, soft loss는 teacher의 uncertainty와 관계 구조를 옮긴다.
LLM에서는 주의할 점
이 카드는 classical soft distillation의 핵심이다. 그러나 LLM에서는 full vocabulary distribution을 모든 token 위치마다 저장하거나 teacher에서 받아오는 것이 매우 비싸다.
그래서 많은 LLM distillation은 다음 카드들에서 볼 sequence-level distillation, reasoning trace distillation처럼 teacher-generated text를 활용한다.
그래도 temperature와 KL은 중요하다. on-policy distillation이나 같은 tokenizer를 공유하는 teacher-student 설정에서는 여전히 분포 단위 feedback이 핵심 도구가 된다.
참고 자료
- Geoffrey Hinton, Oriol Vinyals, Jeff Dean, Distilling the Knowledge in a Neural Network
- 로컬 참고:
reference-books/deepseek from scratch/ch08-Knowledge-distillation-Making-powerful-models-practical.md - 로컬 참고:
reference-books/Rearchitecting LLMs/ch06-Knowledge-recovery-through-distillation.md - 로컬 참고:
reference-books/rlhf-book/ch12-Synthetic-Data.md
확인
- teacher 분포가 너무 뾰족하면 dark knowledge가 왜 잘 보이지 않는가?
- temperature를 너무 크게 하면 왜 signal이 흐려지는가?
- soft distillation loss에서
T²를 곱하는 이유는 무엇인가?