【实践】LLM RLHF——PPO 实践、代码细致解读

LLM RLHF——PPO 代码笔记


本文主要是针对已经有了 PPO for LLM理论基础的,但是缺乏 PPO 实践的同学。本人在看了一些 PPO 视频和文章,对于 PPO 的理论有了一定的了解,但是不知道是如何用代码实践的。

学计算机就不能不学代码,学习写代码的第一步要先运行代码和读代码,而 PPO 更是网上对于理论解释的内容远多于代码解读,所以本文主要内容就是对 PPO 的代码进行逐行 debug 搞清楚流程。

读完本文你将知道:

  • PPO 算法的具体流程(你可以尝试自己实现)
  • PPO 算法的实现代码(仅用 pytorch 实现,方便掌握原理,懂了以后再使用一些写好的 RLHF 库)

代码来源:

OpenRLHF

训练Reward Model

数据集示例:

可以看到结构是一个 rejected 和一个 chosen 表示偏好,并且也给出了 rejected_scorechosen_score 作为 label

在 main 函数进行了 args 的解析后,就进入了 train 函数,进行训练前的准备

从本地加载模型

这部分是由这个函数给出的

Transformers 库加载预训练模型的 4 种方式

这里使用的是第 3 种方式,继承 PreTrainedModel 并自定义前向逻辑,适配特殊需求

具体代码

然后我们一个一个来看

  1. 加载模型的 config

    模型的 config 是包含了模型的各种参数

比如模型的词表大小,模型的 top-k、top_p 等参数

  1. 获取预训练模型的类

动态生成一个类(因为不同模型需要的类不同,所以动态生成)然后返回这个类,不得不感叹 python 的灵活性 hhh

具体看类里面

初始化就是设置了模型的结构:基础模型+一个 value head

等训练的时候,在详细介绍前向传播 forward() 的过程

  1. 得到类以后,加载模型参数

AutoModelForCausalLM. from_pretrained() 类似,只不是将 AutoModelForCausalLM 换成了自己定义的类

准备数据集

在加载完模型后,还需加载 tokenizer、optimizer

然后是加载数据集,我们重要点看看处理后的数据长什么样子

首先混合数据集,然后选择 args. max_samples 个样本,然后再用 Dataset 类包装

我们具体来看如何用 Dataset 包装的

  1. 初始化

  2. 处理数据

process_data:

就干了一件事:从原始数据中,提出来想要的数据,并用字典表示

具体怎么提取出来的呢?

根据 arg 中给的数据集中 rejected chosen 对应的键值来提取,然后要将原本的字典格式的数据转换为序列(字符串),也就是 apply_chat_template, 如下图所示

然后再对 dataset 的每一项进行如上操作后,就结束了。

进行训练

在处理完了数据后,就要进行训练,我们要重点看看训练的过程

首先从数据集中获取数据 data

data 的结构是四个 tensor(chosen、chosen_mask、rejected、rejected_mask)

然后前向传播

输入模型的数据是这样的

将 chosen(shape: 1,66)和 rejected (shape: 1,115) 填充后拼接起来得到 inputs_ids (shape: 2,115),然后 mask 也是同样的维度,用于 mask 掉 user 的部分,不计算 loss

然后开始 forward() 函数

这里一次把 2 个序列传给模型(一个 chosen 一个 rejected),得到 hidden_states(shape: 2 115 4096)

然后传入 value head 也就是一个线性层,得到 values (shape 2 115),因为线性层的输入是 2 115 4096 输出是 2 115 1,用 squeeze 移除大小为 1 的维度

最后计算 reward,是取序列中最后一个 token(不是填充 token)的 value 值

#通过mask 倒数,最后一个不为零的索引就是最后一个 token 的位置
eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
# 提取出第二维的index 位置的值
reward = values.gather(dim=1, index=eos_indices).squeeze(1)

最后返回 reward(shape 2) 和 outputs

然后传回到 fit() 函数,计算 loss,相当于是一个二分类任务的 loss

用的是这个损失函数

损失函数解释:

  1. 核心操作:

    loss = -F.logsigmoid(chosen_reward - reject_reward - margin) # 若 margin 不为空 或 loss = -F.logsigmoid(chosen_reward - reject_reward) # 若 margin 为空

    • torch.nn.functional.logsigmoid(x) 相当于 l o g ( σ ( x ) ) log\bigl(\sigma(x)\bigr) log(σ(x)) 其中 σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1是sigmoid函数。
    • 直观理解:希望 (chosen_reward−reject_reward−margin)尽可能大(也就是 σ ( …   ) ≈ 1 \sigma(\dots)\approx 1 σ()1,这样 logsigmoid(...) 越大越好,但这里加了一个负号,所以我们想让 - logsigmoid(...) 越小越好(最小化损失)。
  2. 为什么能表示“被选样本比拒绝样本好”?

    • chosen_reward 大于 reject_reward(并且若有margin也满足足够大)时,(chosen_reward - reject_reward - margin) 为正且数值越大,sigmoid(...) 就越接近 1,logsigmoid(...) 就越接近 0;所以负的 logsigmoid 就接近 0,损失很小,说明模型学习到“chosen 优于 reject”。
    • chosen_reward 不足够大或甚至比 reject_reward 小的时候,(chosen_reward - reject_reward - margin) 变得较小或为负,sigmoid(...) 就明显小于 1,logsigmoid(...) 为一个负值,负的 logsigmoid(…) 则成为正值,损失增加,这会通过反向传播来推动模型提高 chosen_reward 或降低 reject_reward
  3. margin 的作用:
    当需要强调“优势要足够明显”,就可以设置一个 margin > 0,这样只有当 chosen_reward−reject_reward≥marginchosen_reward - reject_reward \ge marginchosen_reward−reject_reward≥margin 时,损失才会足够小,否则仍会对模型进行惩罚,引导模型拉大两者差距。

  4. 最后的 loss.mean()

    • 将批量中的每个 pairwise 差异的损失取平均值,以得到一个标量损失,便于反向传播。
    • 若在一个 batch 中,我们有多组 (chosen_reward, reject_reward) 的数据对,这些对的损失会同时被平均在一起。

计算完损失就会,反向传播,更新参数(因为用的 deepspeed 的 self.strategy.backward(loss, ...) 所以DeepSpeed 会根据其内部配置的梯度累加逻辑(gradient_accumulation_steps 等)来决定何时真正执行参数更新,不是每一步就更新的,而是累积到一定步数才更新)

至此对于 reward model 的训练过程我们就大概了解了,接下来是 PPO 的过程了

PPO 训练过程

需要的数据长这样子

只有 user 的 prompt

加载模型

PPO 需要加载的模型有:actor、critic、reward、initial model 一共四个模型

其中 actor 是 Actor 类、critic model 是CriticModel 类、reward 类是RewardModel 类、initial model 是 Actor 类,每个类有所不同,当我们用到他们的时候再具体介绍

加载数据

先打乱数据,然后根据 arg 截断数据

prompt_data 是这样的
最后传入 PromptDataset 进行处理:提取出每个 context_message 内容,然后应用 chat_template,如下图的代码

然后对每一项应用,就得到了数据集。

训练过程

在处理完数据后就传入 trainer 开始训练了

我们重点看看训练的过程是怎样的

循环前

在循环前,首先计算了一些数,如下图。

我们来解释下:

在这段代码中,核心思路是根据已经消耗的样本(consumed_samples)推算出之前已经完成了多少个 episode,从而在恢复(resume)训练时能够接着上一次训练进度继续往后进行。也就是说,整个循环 for episode in range(start_episode, args.num_episodes): 是为了在断点续训或重新加载时,跳过已经完成的 episode,从 start_episode 开始一直训练到设定的总 episode 数 args.num_episodes

让我们一步步分析关键变量:

  1. steps = consumed_samples // args.rollout_batch_size + 1

    • 这里的 steps 通常表示已经进行了多少个“rollout_batch”。
    • consumed_samples // args.rollout_batch_size 就是“已消耗的样本数量”对应多少次 rollouts。
    • 之所以 +1,是因为程序可能需要从下一步 step 开始继续往后执行(类似于下标从 1 开始,而不是 0)。
  2. start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes

    • 我们先用 consumed_samples // args.rollout_batch_size 得到已经“完成”或“使用”了多少个并行 rollouts;
    • 然后再用 // num_rollouts_per_episodes,可以理解为:一个 episode 需要多少个并行 rollouts(即 num_rollouts_per_episodes),那么这部分已经完成的 rollouts 对应了多少个完整的 episode;
    • 换言之,start_episode 就是已经完整执行过的 episode 数
  3. consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size)

    • 这里将 consumed_samples 重置为余数的作用,是为了处理“当前 episode 中已经消耗了一部分,但还没全部走完”的情况;
    • 也就是说,如果之前已经训练到第 start_episode 个 episode 了,而且在下一个 episode 的数据可能只用掉了部分,那就需要取余来表示“剩下还没用完的部分”。
  4. for episode in range(start_episode, args.num_episodes):

    • 最后这一行才是真正的“从哪一轮 episode 开始继续训练”。如果是全新训练,consumed_samples 通常是 0,start_episode 也就会是 0;
    • 如果是断点续训,可能已经消耗了一些样本,就会得到一个大于 0 的 start_episode,从而跳过前面已经完成的 episode,直接开始新的 episode。

因此,这个循环次数之所以这样设计,是为了在有断点恢复(或断点续训)需求时,不要重复训练已经做过的 episodes,从而实现训练过程的正确衔接和继续。同时,args.num_episodes 是预先设定的总 episode 数量,最终整个训练循环还是会走到 args.num_episodes 为止。

总结来说:

  • start_episode 通过“已消耗样本 / 每次 rollouts 数量 / 每个 episode 所需的 rollouts 数量”计算出上一次训练已经完成了多少个 episode;
  • 接着在 for episode in range(start_episode, args.num_episodes) 中,从这个 episode 继续往后训练至指定的最大 episode 数;
  • 这样就能保证在断点续训场景下的训练进度衔接以及数据一致性。
循环内

查看一下是如何根据 rand_prompts (len:1024) 这么多 prompt 来产生 experience

首先是在 generate_samples 中的循环

这里的 sequences (shape : 4 1626) ,4 是 4 个 prompt,1626 是 prompt+generate 的 token 数量

attention_mask (shape: 4,1626) 用于标识序列中哪些位置是有效的,哪些位置是填充或结束标记。在生成和训练过程中用于屏蔽无效位置。

action_mask (shape: 4,766) 766是 1626 减去输入 prompt 的 800 后剩下,其中也有部分是填充的 token 而不是有效的 token,用于标识在强化学习过程中哪些位置是有效的动作。在训练过程中用于屏蔽无效位置。

得到采样后的样本 samples 后返回,在逐项处理下得到 experience 具体怎么处理的呢?请往下看

  1. 首先根据产生的 sequence、attention_mask 计算每个 action(这里就是 token) 的概率,是使用 actor(sequence,……) 实现的(注意和之前 sample 的时候用的不同,sample 的时候是使用的 generate() 方法实现的,这里是调用 forward() 实现的)



action_log_probs (shape: 4 766)

  1. 使用参考模型计算和上面一样的每个 action 的概率,这里就不在赘述了

    base_action_log_probs (shape: 4 766)

  2. 计算状态价值 V

这里我们可以详细看看状态价值网络是如何得到 V 的,状态价值网络和 reward model 的区别是什么

区别:
首先看网络结构,看样好像没有区别,都是在原本的网络中最后加上了一个线性层(value head)

所以区别就只是在 forward() 前向传播中了

上图是状态价值模型的前向传播,聪明同学是不是已经看出来区别了

区别就是 reward model 前向传播时计算出 value 后,只取用最后一个 token 对应的 value(如下图),因为只有最后一个 token 可以看得到所有的 token,所以用最后的 token 来当做整个句子的 reward,而状态价值模型是要返回每个 token 预期的可以达到收益。

是不是有一种书读百遍其义自见的感觉,代码读懂一部分,其他的就会触类旁通

得到 value(shape: 4 766)

  1. 计算 Reward 值

这里的过程和上一节的 reward model 前向传播的过程一样,就不再赘述了,得到 r (shape: 4)

  1. 计算初始模型(参考模型)和 Actor 模型的KL 散度

这个过程就不带大家看了,最后得到 kl (shape:4 766)

第一轮都是 0,说明 actor 和参考模型还没有区别

最后总结一下处理过程,如下图

得到 experience 部分值 reward(奖励) kl(KL散度) value(状态价值) 之后继续,还需要根据这些计算 A G A E A^{GAE} AGAE

我们先来看看 KL 散度(shape: 4,766)是如何与 reward (shape :4)加的,我们想要的效果是,将 reward 加在最后一个 action 上,然后再全部加入 KL 散度
实现:

然后只要得到了状态价值函数、reward 就可以计算 A G A E A^{GAE} AGAE 了,接下来我们来详细看看 GAE 计算优势A 是如何计算的

公式:

具体代码:

我们得到了 A G A E A^{GAE} AGAE

距离 Loss 已经非常的近了

至此采样阶段已经结束了,到了利用阶段了,我们再返回来看我们之前是讲到了内循环,现在我们就知道内循环都干了啥,如下图

当然其中 ppo_train () 这个我们还没有解析,接下来就看看得到了优势函数后,该如何计算 actor 的Loss、如何计算 critic 的 Loss

这里 train_step() 是进行一次 actor 训练和一次 critic 训练,如下图

Actor 训练过程(采样后的利用阶段)

看一下计算 loss 的过程

这里的 N=2(每次梯度更新用 2 个样本), Tn=766(生成的 token 数最长的是 766)

代码实现:

最后 loss 为

然后反向传播

至此 actor 的一轮训练就结束了

然后就是 critic 的训练

Critic 训练过程(采样后的利用阶段)

我们来看下 loss 是如何计算的
正常的每一步的 label 应该是这样的:

也就是上一步得到预期到最后 reward 的和作为真实值的估计,然后计算 loss

然后回报值 return 在代码中是这样的算的,看一下之前的 GAR 的计算公式,是不是优势函数+value就是每一步到最后的回报值,

然后 loss 这样算

然后反向传播

至此 critic 的一次训练就结束了

以上就是 PPO LLM 训练的过程,想必你对 PPO LLM的训练过程也有了大概的认识,有哪里不清楚可以给我发消息或者评论区讨论,欢迎大家交流

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值