前言
从PPO到GRPO。
GRPO是由deepseek团队提出,主要来自于DeepSeekMath
链接:https://arxiv.org/pdf/2402.03300
从PPO到GRPO
回顾PPO
回顾下 PPO的细节:
https://zhuanlan.zhihu.com/p/1937882242291598044
更稳定的做法是对loss函数做clip,此时loss函数为:
Lppo-clip(θ)=E^t[min(πθ(at∣st)πθold(at∣st)A^t,clip(πθ(at∣st)πθold(at∣st),1−ϵ,1+ϵ)A^t)]
L_{\text{ppo-clip}}(\theta) = \hat{\mathbb{E}}_t \left[ \min\left( \frac{\pi_\theta(a_t \vert s_t)}{\pi_{\theta_{\text{old}}}(a_t \vert s_t)} \hat{A}_t, \text{clip}\left( \frac{\pi_\theta(a_t \vert s_t)}{\pi_{\theta_{\text{old}}}(a_t \vert s_t)}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_t \right) \right]
Lppo-clip(θ)=E^t[min(πθold(at∣st)πθ(at∣st)A^t,clip(πθold(at∣st)πθ(at∣st),1−ϵ,1+ϵ)A^t)]
如果按照sts_tst为用户的输入query和上下文,ata_tat为LLM ouput的token, 则可以表示为:
JPPO(θ)=Eq∼P(Q), o∼πθold(O∣q)1∣o∣∑t=1∣o∣min[πθ(ot∣q,o<t)πθold(ot∣q,o<t)At, clip(πθ(ot∣q,o<t)πθold(ot∣q,o<t),1−ε,1+ε)At], \mathcal{J}_{\text{PPO}}(\theta) = \mathbb{E}_{q \sim P(Q),\ o \sim \pi_{\theta_{\text{old}}}(O|q)} \frac{1}{|o|} \sum_{t=1}^{|o|} \min\left[ \frac{\pi_\theta(o_t \mid q, o_{< t})}{\pi_{\theta_{\text{old}}}(o_t \mid q, o_{< t})} A_t,\ \operatorname{clip}\left( \frac{\pi_\theta(o_t \mid q, o_{< t})}{\pi_{\theta_{\text{old}}}(o_t \mid q, o_{< t})}, 1 - \varepsilon, 1 + \varepsilon \right) A_t \right], JPPO(θ)=Eq∼P(Q), o∼πθold(O∣q)∣o∣1t=1∑∣o∣min[πθold(ot∣q,o<t)πθ(ot∣q,o<t)At, clip(πθold(ot∣q,o<t)πθ(ot∣q,o<t),1−ε,1+ε)At],
为了减轻对reward model的过度优化,标准方法是在每个token level reward中添加KL散度的惩罚项:
rt=rφ(q,o≤t)−βlogπθ(ot∣q,o<t)πref(ot∣q,o<t),
r_t = r_\varphi(q, o_{\leq t}) - \beta \log \frac{\pi_\theta(o_t \vert q, o_{< t})}{\pi_{\text{ref}}(o_t \vert q, o_{< t})},
rt=rφ(q,o≤t)−βlogπref(ot∣q,o<t)πθ(ot∣q,o<t),
由于PPO中Value Model通常是与PolicyModel规模相当的模型,这带来了巨大的内存和计算负担。
此外,在RL训练中,Value Model被用作计算优势函数的baseline以减少方差。而在大语言模型的情况下,通常只有最后一个token被Reward Model分配奖励分数,这可能会使训练在每个token上都准确的价值函数变得复杂。
GRPO
主要的操作是采样多个输出的平均reward作为baseline,而不用单独的Value Model去预测baseline,节省了Value Model,从而显著的减少了训练资源的使用。
流程基本上和PPO大差不差,主要的区分点在于不需要Value Model去预估baseline,而是直接采样一个Group sample output,根据采样多个输出的reward的平均作为baseline。(这里和RLOO思路差不多: https://zhuanlan.zhihu.com/p/1932821169993675011)
作为基线,用于回答相同问题。对于每个问题 qqq,从旧策略πθold\pi_{\theta_{\text{old}}}πθold 中采样一组输出 $ {o_1, o_2, \cdots, o_G}$,然后通过最大化以下目标来优化策略模型:
JGRPO(θ)=E[q∼P(Q),{oi}i=1G∼πθold(O∣q)]1G∑i=1G1∣oi∣∑t=1∣oi∣{min[πθ(oi,t∣q,oi,<t)πθold(oi,t∣q,oi,<t)A^i,t,clip(πθ(oi,t∣q,oi,<t)πθold(oi,t∣q,oi,<t),1−ε,1+ε)A^i,t]−βDKL[πθ∥πref]} \begin{align*} \mathcal{J}_{\text{GRPO}}(\theta) &= \mathbb{E} \left[ q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(O|q) \right] \\ &\quad \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left\{ \min \left[ \frac{\pi_\theta(o_{i,t}|q, o_{i, < t})}{\pi_{\theta_{\text{old}}}(o_{i,t}|q, o_{i,< t})} \hat{A}_{i,t}, \text{clip} \left( \frac{\pi_\theta(o_{i,t}|q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t}|q, o_{i,< t})}, 1 - \varepsilon, 1 + \varepsilon \right) \hat{A}_{i,t} \right] - \beta \mathbb{D}_{\text{KL}} \left[ \pi_\theta \|\pi_{\text{ref}} \right] \right\} \end{align*} JGRPO(θ)=E[q∼P(Q),{oi}i=1G∼πθold(O∣q)]G1i=1∑G∣oi∣1t=1∑∣oi∣{min[πθold(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)A^i,t,clip(πθold(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t),1−ε,1+ε)A^i,t]−βDKL[πθ∥πref]}
其中 $ \varepsilon $ 和 $ \beta $ 是超参数,$ \hat{A}_{i,t} $ 是仅基于每组内部输出的相对奖励计算的优势。
这里还会有个区别于PPO的部分,直接在Loss中添加KL(这里相当于把PPO的两种形式相结合),还有个不一样的细节:
GRPO使用无偏估计来估计 KL 散度:
DKL[πθ∥πref]=πref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)−logπref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)−1,
\mathbb{D}_{KL} \left[ \pi_\theta \|\pi_{\text{ref}} \right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,< t})}{\pi_\theta(o_{i,t} \mid q, o_{i,< t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,< t})}{\pi_\theta(o_{i,t} \mid q, o_{i,< t})} - 1,
DKL[πθ∥πref]=πθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−logπθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−1,
这个值一定为正。
可以通过这个图再梳理一下,非常清晰:
Ref
- https://pub.towardsai.net/group-relative-policy-optimization-grpo-illustrated-breakdown-explanation-684e71b8a3f2
- https://arxiv.org/pdf/2402.03300