LLM Sequence Distillation
마지막 수정:
앞 카드까지는 distillation target을 두 가지로 나눴다.
soft distillation: teacher의 확률분포를 맞춘다.
hard distillation: teacher가 고른 token이나 text를 맞춘다.
LLM으로 오면 이 차이가 더 커진다. classifier는 입력 하나에 대해 class 분포 하나를 내지만, autoregressive LLM은 긴 답변을 token by token으로 만든다.
그래서 LLM distillation의 실전 질문은 이것이다.
teacher의 모든 next-token 분포를 저장하고 맞출 것인가?
아니면 teacher가 만든 완성된 답변 sequence를 target으로 학습할 것인가?
1. Prompt set
s Start with tasks, questions, instructions, or problems.
2. Teacher generation
u = teacher(s) The large model produces a complete answer sequence.
3. Stored target
(s, u) The generated sequence becomes offline training data.
4. Student SFT
-log p(u | s) Train only on answer tokens with next-token cross-entropy.
-log p_student(u | s) = sum_t -log p_student(u_t | s, u_<t) Word / token KD
- Signal
- teacher distribution at every next-token position
- Needs
- logits or logprobs over the vocabulary
Sequence KD
- Signal
- one high-probability teacher output sequence
- Needs
- generated text is enough
LLM instruction distill
- Signal
- prompt plus teacher response, often filtered for quality
- Needs
- dataset curation and response-token masking
Word-level KD에서 Sequence-level KD로
Kim & Rush의 sequence-level knowledge distillation은 이 전환을 이해하는 좋은 다리다.
word-level KD는 각 위치마다 teacher의 next-token 분포를 student가 맞추는 방식이다.
prompt + previous tokens
-> teacher distribution over vocabulary
-> student distribution over vocabulary
-> KL or cross-entropy between distributions
이 방식은 dark knowledge를 잘 옮긴다. 하지만 긴 LLM 답변에서는 매 token 위치마다 vocabulary 전체 분포를 다뤄야 한다. 답변이 길고 vocabulary가 크면 저장, 전송, 학습 비용이 바로 커진다.
sequence-level KD는 관점을 바꾼다. 가능한 모든 출력 sequence의 분포를 직접 합산하는 것은 불가능하므로, teacher가 만든 고확률 출력 하나를 대표 target으로 놓는다.
s = prompt
u = teacher가 생성한 complete output sequence
loss = -log p_student(u | s)
= sum_t -log p_student(u_t | s, u_<t)
이제 student는 teacher의 full distribution을 보지 않는다. 대신 teacher가 실제로 생성한 sequence를 정답처럼 보고 next-token prediction으로 배운다.
LLM에서는 SFT처럼 보인다
현대 LLM distillation에서 이 방식은 거의 SFT 데이터셋처럼 보인다.
input: instruction / problem / prompt
target: teacher response
loss: response token에 대한 cross-entropy
중요한 점은 prompt token과 answer token을 구분하는 것이다. student가 벌점을 받아야 하는 대상은 prompt를 복사하는 능력이 아니라, prompt가 주어졌을 때 teacher response를 이어 쓰는 능력이다.
그래서 distillation loss는 보통 answer token 구간에만 걸린다.
[prompt tokens][teacher answer tokens]
^^^^^^^^^^^^^^^^^^^^^
loss is computed here
이 구조 때문에 LLM sequence distillation은 “teacher-generated synthetic data로 하는 supervised fine-tuning”처럼 구현된다.
Offline distillation이라는 말의 의미
sequence distillation은 보통 offline distillation이다.
1. teacher로 답변을 미리 생성한다.
2. 품질이 낮은 답변을 필터링하거나 재생성한다.
3. 저장된 dataset으로 student를 학습한다.
4. student training 중에는 teacher를 다시 돌리지 않는다.
이 분리는 실용적으로 매우 크다. teacher가 거대한 모델이어도 student 학습 시점에는 teacher를 GPU에 올릴 필요가 없다. API teacher를 쓴다면 generation cost와 training cost도 분리된다.
DeepSeek-R1의 distilled model들도 이 관점으로 읽을 수 있다. 큰 reasoning teacher가 만든 reasoning trace와 final answer를 curated dataset으로 만들고, 작은 dense model을 SFT한다. 여기서 핵심 신호는 logits가 아니라 teacher가 생성한 긴 sequence다.
무엇을 잃고 무엇을 얻는가
sequence distillation이 얻는 것은 단순성과 확장성이다.
teacher text만 있으면 된다.
closed API teacher도 쓸 수 있다.
teacher와 student tokenizer가 달라도 다루기 쉽다.
dataset을 검수하고 재사용할 수 있다.
teacher를 student training 중에 같이 실행하지 않아도 된다.
대신 잃는 것도 있다.
teacher가 고려한 대안 token의 확률 구조는 사라진다.
teacher가 한 번 뽑은 표현 방식에 student가 과하게 묶일 수 있다.
teacher output 품질이 낮으면 그 오류와 말투도 같이 증류된다.
student가 자기 prefix에서 벗어났을 때 teacher supervision을 직접 받지 못한다.
마지막 한계가 중요하다. offline sequence distillation은 teacher가 만든 prefix 위에서 학습한다. 하지만 실제 inference에서는 student가 자기 token을 만들어 가며 다른 prefix로 들어갈 수 있다. 이 train-test mismatch가 다음 단계에서 on-policy distillation을 생각하게 만드는 이유다.
정리
LLM sequence distillation의 본질은 다음 한 줄이다.
teacher distribution을 직접 복사하는 대신,
teacher가 만든 complete answer sequence를 training target으로 삼는다.
이 관점이 잡히면 reasoning distillation도 쉽게 읽힌다. reasoning distillation은 단순히 final answer만 target으로 삼는 것이 아니라, teacher가 만든 풀이 과정과 검증 습관까지 sequence 안에 넣어 student가 모방하게 만드는 방식이다.
참고 자료
- Yoon Kim, Alexander M. Rush, Sequence-Level Knowledge Distillation
- Geoffrey Hinton, Oriol Vinyals, Jeff Dean, Distilling the Knowledge in a Neural Network
- DeepSeek-AI, DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning
- 로컬 참고:
reference-books/rlhf-book/ch12-Synthetic-Data.md - 로컬 참고:
reference-books/reasoning-model-from-scratch/ch08-Distilling-reasoning-models-for-efficient-reasoning.md - 로컬 참고:
reference-books/deepseek from scratch/ch08-Knowledge-distillation-Making-powerful-models-practical.md
확인
- word-level KD와 sequence-level KD는 target이 어떻게 다른가?
- LLM sequence distillation이 SFT처럼 구현되는 이유는 무엇인가?
- offline sequence distillation이 teacher logits 없이도 가능한 이유는 무엇인가?