LLM RLHF——PPO 代码笔记
文章目录
本文主要是针对已经有了 PPO for LLM理论基础的,但是缺乏 PPO 实践的同学。本人在看了一些 PPO 视频和文章,对于 PPO 的理论有了一定的了解,但是不知道是如何用代码实践的。
学计算机就不能不学代码,学习写代码的第一步要先运行代码和读代码,而 PPO 更是网上对于理论解释的内容远多于代码解读,所以本文主要内容就是对 PPO 的代码进行逐行 debug 搞清楚流程。
读完本文你将知道:
- PPO 算法的具体流程(你可以尝试自己实现)
- PPO 算法的实现代码(仅用 pytorch 实现,方便掌握原理,懂了以后再使用一些写好的 RLHF 库)
代码来源:
训练Reward Model
数据集示例:
可以看到结构是一个 rejected 和一个 chosen 表示偏好,并且也给出了 rejected_score
和 chosen_score
作为 label
在 main 函数进行了 args 的解析后,就进入了 train 函数,进行训练前的准备
从本地加载模型
这部分是由这个函数给出的
这里使用的是第 3 种方式,继承 PreTrainedModel 并自定义前向逻辑,适配特殊需求
具体代码
然后我们一个一个来看
- 加载模型的 config
模型的 config 是包含了模型的各种参数
比如模型的词表大小,模型的 top-k、top_p 等参数
- 获取预训练模型的类
动态生成一个类(因为不同模型需要的类不同,所以动态生成)然后返回这个类,不得不感叹 python 的灵活性 hhh
具体看类里面
初始化就是设置了模型的结构:基础模型+一个 value head
等训练的时候,在详细介绍前向传播 forward()
的过程
- 得到类以后,加载模型参数
和 AutoModelForCausalLM. from_pretrained()
类似,只不是将 AutoModelForCausalLM
换成了自己定义的类
准备数据集
在加载完模型后,还需加载 tokenizer、optimizer
然后是加载数据集,我们重要点看看处理后的数据长什么样子
首先混合数据集,然后选择 args. max_samples 个样本,然后再用 Dataset 类包装
我们具体来看如何用 Dataset 包装的
-
初始化
-
处理数据
process_data:
就干了一件事:从原始数据中,提出来想要的数据,并用字典表示
具体怎么提取出来的呢?
根据 arg 中给的数据集中 rejected chosen 对应的键值来提取,然后要将原本的字典格式的数据转换为序列(字符串),也就是 apply_chat_template
, 如下图所示
然后再对 dataset 的每一项进行如上操作后,就结束了。
进行训练
在处理完了数据后,就要进行训练,我们要重点看看训练的过程
首先从数据集中获取数据 data
data 的结构是四个 tensor(chosen、chosen_mask、rejected、rejected_mask)
然后前向传播
输入模型的数据是这样的
将 chosen(shape: 1,66)和 rejected (shape: 1,115) 填充后拼接起来得到 inputs_ids
(shape: 2,115),然后 mask 也是同样的维度,用于 mask 掉 user 的部分,不计算 loss
然后开始 forward()
函数
这里一次把 2 个序列传给模型(一个 chosen 一个 rejected),得到 hidden_states
(shape: 2 115 4096)
然后传入 value head 也就是一个线性层,得到 values (shape 2 115),因为线性层的输入是 2 115 4096 输出是 2 115 1,用 squeeze
移除大小为 1 的维度
最后计算 reward,是取序列中最后一个 token(不是填充 token)的 value 值
#通过mask 倒数,最后一个不为零的索引就是最后一个 token 的位置
eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
# 提取出第二维的index 位置的值
reward = values.gather(dim=1, index=eos_indices).squeeze(1)
最后返回 reward(shape 2) 和 outputs
然后传回到 fit()
函数,计算 loss,相当于是一个二分类任务的 loss
用的是这个损失函数
损失函数解释:
-
核心操作:
loss = -F.logsigmoid(chosen_reward - reject_reward - margin) # 若 margin 不为空 或 loss = -F.logsigmoid(chosen_reward - reject_reward) # 若 margin 为空
torch.nn.functional.logsigmoid(x)
相当于 l o g ( σ ( x ) ) log\bigl(\sigma(x)\bigr) log(σ(x)) 其中 σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+e−x1是sigmoid函数。- 直观理解:希望 (chosen_reward−reject_reward−margin)尽可能大(也就是
σ
(
…
)
≈
1
\sigma(\dots)\approx 1
σ(…)≈1,这样
logsigmoid(...)
越大越好,但这里加了一个负号,所以我们想让- logsigmoid(...)
越小越好(最小化损失)。
-
为什么能表示“被选样本比拒绝样本好”?
- 当
chosen_reward
大于reject_reward
(并且若有margin也满足足够大)时,(chosen_reward - reject_reward - margin)
为正且数值越大,sigmoid(...)
就越接近 1,logsigmoid(...)
就越接近 0;所以负的 logsigmoid 就接近 0,损失很小,说明模型学习到“chosen 优于 reject”。 - 当
chosen_reward
不足够大或甚至比reject_reward
小的时候,(chosen_reward - reject_reward - margin)
变得较小或为负,sigmoid(...)
就明显小于 1,logsigmoid(...)
为一个负值,负的 logsigmoid(…) 则成为正值,损失增加,这会通过反向传播来推动模型提高chosen_reward
或降低reject_reward
。
- 当
-
含
margin
的作用:
当需要强调“优势要足够明显”,就可以设置一个margin > 0
,这样只有当 chosen_reward−reject_reward≥marginchosen_reward - reject_reward \ge marginchosen_reward−reject_reward≥margin 时,损失才会足够小,否则仍会对模型进行惩罚,引导模型拉大两者差距。 -
最后的
loss.mean()
:- 将批量中的每个 pairwise 差异的损失取平均值,以得到一个标量损失,便于反向传播。
- 若在一个 batch 中,我们有多组
(chosen_reward, reject_reward)
的数据对,这些对的损失会同时被平均在一起。
计算完损失就会,反向传播,更新参数(因为用的 deepspeed 的 self.strategy.backward(loss, ...)
所以DeepSpeed 会根据其内部配置的梯度累加逻辑(gradient_accumulation_steps
等)来决定何时真正执行参数更新,不是每一步就更新的,而是累积到一定步数才更新)
至此对于 reward model 的训练过程我们就大概了解了,接下来是 PPO 的过程了
PPO 训练过程
需要的数据长这样子
只有 user 的 prompt
加载模型
PPO 需要加载的模型有:actor、critic、reward、initial model 一共四个模型
其中 actor 是 Actor 类、critic model 是CriticModel 类、reward 类是RewardModel 类、initial model 是 Actor 类,每个类有所不同,当我们用到他们的时候再具体介绍
加载数据
先打乱数据,然后根据 arg 截断数据
prompt_data
是这样的
最后传入 PromptDataset
进行处理:提取出每个 context_message 内容,然后应用 chat_template,如下图的代码
然后对每一项应用,就得到了数据集。
训练过程
在处理完数据后就传入 trainer 开始训练了
我们重点看看训练的过程是怎样的
循环前
在循环前,首先计算了一些数,如下图。
我们来解释下:
在这段代码中,核心思路是根据已经消耗的样本(consumed_samples
)推算出之前已经完成了多少个 episode,从而在恢复(resume)训练时能够接着上一次训练进度继续往后进行。也就是说,整个循环 for episode in range(start_episode, args.num_episodes):
是为了在断点续训或重新加载时,跳过已经完成的 episode,从 start_episode
开始一直训练到设定的总 episode 数 args.num_episodes
。
让我们一步步分析关键变量:
-
steps = consumed_samples // args.rollout_batch_size + 1
- 这里的
steps
通常表示已经进行了多少个“rollout_batch”。 consumed_samples // args.rollout_batch_size
就是“已消耗的样本数量”对应多少次 rollouts。- 之所以
+1
,是因为程序可能需要从下一步 step 开始继续往后执行(类似于下标从 1 开始,而不是 0)。
- 这里的
-
start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes
- 我们先用
consumed_samples // args.rollout_batch_size
得到已经“完成”或“使用”了多少个并行 rollouts; - 然后再用
// num_rollouts_per_episodes
,可以理解为:一个 episode 需要多少个并行 rollouts(即num_rollouts_per_episodes
),那么这部分已经完成的 rollouts 对应了多少个完整的 episode; - 换言之,
start_episode
就是已经完整执行过的 episode 数。
- 我们先用
-
consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size)
- 这里将
consumed_samples
重置为余数的作用,是为了处理“当前 episode 中已经消耗了一部分,但还没全部走完”的情况; - 也就是说,如果之前已经训练到第
start_episode
个 episode 了,而且在下一个 episode 的数据可能只用掉了部分,那就需要取余来表示“剩下还没用完的部分”。
- 这里将
-
for episode in range(start_episode, args.num_episodes):
- 最后这一行才是真正的“从哪一轮 episode 开始继续训练”。如果是全新训练,
consumed_samples
通常是 0,start_episode
也就会是 0; - 如果是断点续训,可能已经消耗了一些样本,就会得到一个大于 0 的
start_episode
,从而跳过前面已经完成的 episode,直接开始新的 episode。
- 最后这一行才是真正的“从哪一轮 episode 开始继续训练”。如果是全新训练,
因此,这个循环次数之所以这样设计,是为了在有断点恢复(或断点续训)需求时,不要重复训练已经做过的 episodes,从而实现训练过程的正确衔接和继续。同时,args.num_episodes
是预先设定的总 episode 数量,最终整个训练循环还是会走到 args.num_episodes
为止。
总结来说:
start_episode
通过“已消耗样本 / 每次 rollouts 数量 / 每个 episode 所需的 rollouts 数量”计算出上一次训练已经完成了多少个 episode;- 接着在
for episode in range(start_episode, args.num_episodes)
中,从这个 episode 继续往后训练至指定的最大 episode 数; - 这样就能保证在断点续训场景下的训练进度衔接以及数据一致性。
循环内
查看一下是如何根据 rand_prompts (len:1024)
这么多 prompt 来产生 experience
的
首先是在 generate_samples
中的循环
这里的 sequences (shape : 4 1626) ,4 是 4 个 prompt,1626 是 prompt+generate 的 token 数量
attention_mask (shape: 4,1626) 用于标识序列中哪些位置是有效的,哪些位置是填充或结束标记。在生成和训练过程中用于屏蔽无效位置。
action_mask (shape: 4,766) 766是 1626 减去输入 prompt 的 800 后剩下,其中也有部分是填充的 token 而不是有效的 token,用于标识在强化学习过程中哪些位置是有效的动作。在训练过程中用于屏蔽无效位置。
得到采样后的样本 samples
后返回,在逐项处理下得到 experience
具体怎么处理的呢?请往下看
- 首先根据产生的 sequence、attention_mask 计算每个 action(这里就是 token) 的概率,是使用
actor(sequence,……)
实现的(注意和之前 sample 的时候用的不同,sample 的时候是使用的generate()
方法实现的,这里是调用forward()
实现的)
action_log_probs (shape: 4 766)
-
使用参考模型计算和上面一样的每个 action 的概率,这里就不在赘述了
base_action_log_probs (shape: 4 766)
-
计算状态价值 V
这里我们可以详细看看状态价值网络是如何得到 V 的,状态价值网络和 reward model 的区别是什么
区别:
首先看网络结构,看样好像没有区别,都是在原本的网络中最后加上了一个线性层(value head)
所以区别就只是在 forward()
前向传播中了
上图是状态价值模型的前向传播,聪明同学是不是已经看出来区别了
区别就是 reward model 前向传播时计算出 value 后,只取用最后一个 token 对应的 value(如下图),因为只有最后一个 token 可以看得到所有的 token,所以用最后的 token 来当做整个句子的 reward,而状态价值模型是要返回每个 token 预期的可以达到收益。
是不是有一种书读百遍其义自见的感觉,代码读懂一部分,其他的就会触类旁通
得到 value(shape: 4 766)
- 计算 Reward 值
这里的过程和上一节的 reward model 前向传播的过程一样,就不再赘述了,得到 r (shape: 4)
- 计算初始模型(参考模型)和 Actor 模型的KL 散度
这个过程就不带大家看了,最后得到 kl (shape:4 766)
第一轮都是 0,说明 actor 和参考模型还没有区别
最后总结一下处理过程,如下图
得到 experience
部分值 reward(奖励) kl(KL散度) value(状态价值)
之后继续,还需要根据这些计算
A
G
A
E
A^{GAE}
AGAE
我们先来看看 KL 散度(shape: 4,766)是如何与 reward (shape :4)加的,我们想要的效果是,将 reward 加在最后一个 action 上,然后再全部加入 KL 散度
实现:
然后只要得到了状态价值函数、reward 就可以计算 A G A E A^{GAE} AGAE 了,接下来我们来详细看看 GAE 计算优势A 是如何计算的
公式:
具体代码:
我们得到了 A G A E A^{GAE} AGAE
距离 Loss 已经非常的近了
至此采样阶段已经结束了,到了利用阶段了,我们再返回来看我们之前是讲到了内循环,现在我们就知道内循环都干了啥,如下图
当然其中 ppo_train ()
这个我们还没有解析,接下来就看看得到了优势函数后,该如何计算 actor 的Loss、如何计算 critic 的 Loss
这里 train_step()
是进行一次 actor 训练和一次 critic 训练,如下图
Actor 训练过程(采样后的利用阶段)
看一下计算 loss 的过程
这里的 N=2(每次梯度更新用 2 个样本), Tn=766(生成的 token 数最长的是 766)
代码实现:
最后 loss 为
然后反向传播
至此 actor 的一轮训练就结束了
然后就是 critic 的训练
Critic 训练过程(采样后的利用阶段)
我们来看下 loss 是如何计算的
正常的每一步的 label 应该是这样的:
也就是上一步得到预期到最后 reward 的和作为真实值的估计,然后计算 loss
然后回报值 return 在代码中是这样的算的,看一下之前的 GAR 的计算公式,是不是优势函数+value就是每一步到最后的回报值,
然后 loss 这样算
然后反向传播
至此 critic 的一次训练就结束了
以上就是 PPO LLM 训练的过程,想必你对 PPO LLM的训练过程也有了大概的认识,有哪里不清楚可以给我发消息或者评论区讨论,欢迎大家交流