PPO for LLMs
마지막 수정:
PPO는 RLHF 초기에 가장 대표적으로 쓰인 policy-gradient 알고리즘이다.
핵심 목적은 간단하다.
좋은 completion의 token 확률은 올리고,
나쁜 completion의 token 확률은 낮추되,
한 번에 너무 멀리 가지 않게 제한한다.
policy ratio
PPO는 rollout을 만든 old policy와 현재 update 중인 policy를 비교한다.
ratio = pi_new(token | state) / pi_old(token | state)
ratio가 1이면 새 policy와 old policy가 그 token을 같은 정도로 좋아한다는 뜻이다.
ratio > 1
새 policy가 token을 더 가능하게 만듦
ratio < 1
새 policy가 token을 덜 가능하게 만듦
clipping
advantage가 양수이면 그 token을 더 강화하고 싶다. 하지만 무한히 강화하면 policy가 망가진다.
ratio를 1 - epsilon ~ 1 + epsilon 범위로 제한
이 clipping이 PPO의 핵심 안정화 장치다. 좋은 token도 너무 빨리 강화하지 않고, 나쁜 token도 너무 과하게 억제하지 않는다.
value model과 advantage
PPO는 보통 value model을 둔다.
value model:
이 상태에서 앞으로 얻을 reward가 어느 정도인지 예측
advantage:
실제 rollout이 value estimate보다 얼마나 좋거나 나빴는지
이 baseline이 있으면 raw reward보다 variance가 줄어든다. 대신 policy, reference, reward, value model을 함께 다뤄야 하므로 메모리와 구현 복잡도가 커진다.
LLM PPO가 어려운 이유
LLM에서는 action이 token sequence다.
prompt token은 loss에서 빼야 한다
padding token도 빼야 한다
generated token별 logprob를 저장해야 한다
KL penalty를 token 또는 sequence 단위로 계산해야 한다
그래서 PPO는 알고리즘보다 구현 디테일에서 많이 흔들린다.
GRPO와의 연결
GRPO는 PPO의 clipping/reference-control 아이디어를 일부 유지하면서 value model을 없애는 쪽으로 간다. value baseline 대신 같은 prompt에서 뽑은 여러 completion의 reward를 비교한다.
PPO:
learned value model로 advantage 추정
GRPO:
group reward mean/std로 advantage 추정
참고 자료
- John Schulman et al., Proximal Policy Optimization Algorithms
- Nathan Lambert, The RLHF Book
- 로컬 참고:
reference-books/rlhf-book/ch06-Reinforcement-Learning.md
확인
- PPO의 policy ratio는 무엇을 비교하는가?
- clipping은 왜 필요한가?
- PPO가 GRPO보다 시스템적으로 무거운 이유는 무엇인가?