InstructGPT 精简总结:Training language models to follow instructions with human feedback

  • 2022年发表。
  • 出自论文:《Training language models to follow instructions with human feedback》,OpenAI。
  • 与 chatgpt 最相近的工作。
  • 在OpenAI官网中,又称为Aligning language models to follow instructions

1、提出动机

  • GPT-3虽牛,但仍会生成一些带偏见、不真实、有害的负面信息,有时候一本正经胡说八道。这从做研究的角度来看,确实没啥,因为你只要在某个数据集上碾压对手,那就是牛的。但对于工业实践来说,你带有这些问题的话,特别是对于大公司来说,肯定会被用户骂死,骂到产品下线为止。
  • 且GPT-1, GPT-2, GPT-3的主要任务还是续写即文字接龙,不太擅长与听你指令干活。比如,你输入“给我写一份方案”,GPT很可能输出的是“主题是关于如何入门深度学习”,而不是给你生成出一份方案。
  • InstructGPT因此而提出,它提出了一个叫 “align” 即对齐的概念,指的是使模型输出与人类真是意图更为接近,即对齐,更符合人类偏好。

2、模型结构

3、训练机制

InstructGPT的训练机制主要分为了3步:

  • 1、SFT:Supervised Fine-Tuning,有监督微调。
  • 2、训练奖励模型
  • 3、强化学习(PPO方法)

注:其中第2,3步合起来也就是常听到的RLHF了,基于人类反馈的强化学习。

在这里插入图片描述

3.1、数据的收集

在这里插入图片描述

3.2、SFT

Supervised Fine-Tuning

SFT做的事情其实就是语言模型做的预训练。和GPT-3区别在于,InstructGPT的数据为人工标的高质量数据。

  • gpt3中对于某个下游任务,采用的是few-shot的形式(任务描述+例子+prompt),通常采用固定的任务描述形式,且需要人为去测试哪种任务表述方式好。很明显了,这种方式和实际场景中用户的表述存在很大的gap。
  • InstructGPT在SFT中标注的数据正是为了消除这种gap的。
  • 具体数据来源从gpt3的真实用户请求中采样了大量下游任务的描述,然后标注人员对任务描述续写,从而得到高质量回答。真实用户请求又被称为某个任务的指令,所以也就是概念 “基于人类反馈的指令微调” 的由来。

3.3、RLHF

Reinforcement Learning from Human Feedback

3.3.1、Train reward model

这一步单独生成了一个reward model(打分模型)用于PPO里打分。
训练步骤如下:

  • 1、基于SFT训练好的模型,对输入(提问)生成多个输出(答案)(根据输出概率分布进行采样即可得到多个输出或是beam-search类的方法等),然后人工去给这些输出打分。【其实reward model = sft model + 二分类头。数据集的样子,每条数据包含三部分:instruction,chosen answer,reject answer】
  • 2、使用第一步得到的样本集,训练模型,模型的输入为一个提问+一个答案,模型的输出为对该答案的打分。【gpt模型下游的softmax改为MLP】
  • 3、既然是排序问题,那loss用的是排序问题里常用的pairwise ranking loss。其中,K就是模型生成K个答案,论文采用的是K=9;这个loss就是计算了两两的答案的损失。也就是希望真实排序靠前标签对应的打分比排序靠后的标签对应的打分要高;最后再除与所有可能的组合数(从K个里选2个)。长得也就是LR的loss形式:sigmoid+log。
  • 4、有了loss那就能反向传播更新模型参数了。
    -

3.3.2、Train policy with PPO

Proximal Policy Optimization:强化学习PPO模型是OpenAI 2017年的工作。
⭐⭐⭐这一步用到的模型是:
1、用于强化学习的SFT模型(用于生成答案)。
2、原SFT模型(用于计算loss中的KL散度,为了保证PPO学习出的模型的预测不至于偏离原预训练模型的预测太多,因为RL就是用SFT模型初始化的)
3、reward model(用于对生成的答案打分;相当于作为老师,来对模型的回答进行‘批改’,让模型愈发把知识‘对齐’到我们想要的样子。)。

⭐ 为啥需要PPO这一步呢?使用强化学习(而非监督学习)的方式更新语言模型,最大的优势是在于能够使得模型更加自由的探索更新方向,从而突破监督学习的性能天花板。
⭐ 在reward model训练完之后有奖励模型咯,用它来作为PPO训练的value function。

PPO算法更新参数的大致流程如下图:
图源于:https://mp.weixin.qq.com/s/1v4Uuc1YAZ9MRr1UWMH9xw在这里插入图片描述

  • 输入x是第三个数据集里的prompt,而输出y则是强化学习的SFT模型的输出。y会随着模型参数的更新而不一样,这里有区别与监督学习,监督学习中,训练多个epoch时同一个x对应的y是一样的,但强化学习中同样的x对应的y是会随着模型更新而变化的。

  • 这里我简单讲讲,因为主要思想也挺简单的,即使我没搞过强化学;李沐老师说如果前面标的数据够多,其实这一步可能并不需要。

  • PPO的主要思想。 强化学习的模型称为策略模型,又称为策略。其中rθ是RM,为了确保RM打分不至于被过度优化,增加了个log项,那是KL散度,为了保证PPO学习出的强化学习SFT的预测不至于偏离原SFT模型的预测太多,因为RL就是用SFT模型初始化的。这就是PPO的主要思想。

  • 语言模型预训练项。 另外地,InstructGPT在PPO的loss基础上加上了预训练损失,为了防止强化学习SFT模型只对打分这个任务过度拟合,导致泛化性能损失,所以加上了预训练语言模型的损失来确保模型在公开NLP数据集上的表现。加上了这一项,PPO就变为了PPO-ptx。

  • 有了loss就能反向传播更新模型参数了。

😄 RL 模型的优化目标是使得损失函数越大越好,损失函数可以分为三个部分,打分部分、KL 散度部分以及预训练部分。
在这里插入图片描述

在这里插入图片描述
解释版:
在这里插入图片描述

4、与先前GPT的区别

  • 解决 GPT-3 的输出与人类意图之间的 Align 问题;
  • InstructGPT的输出比GPT3的输出更好,更真实可靠,更丰富。
  • InstructGPT对有害结果的生成控制地更好,但对于偏见问题没有明显改善。
  • InstructGPT泛化性更好,在缺乏人类指令数据上也表现不错;
  • InstructGPT基于SFT后的模型在公开数据集上表现也不错。







参考链接:
[1] 李沐:InstructGPT论文精读

### 如何通过人类反馈训练语言模型以遵循指令 通过人类反馈训练语言模型以使其能够更好地遵循指令是一个复杂的过程,涉及多个阶段和技术方法。以下是对此过程的详细介绍: #### 数据收集与准备 为了使模型具备理解并执行指令的能力,数据的质量至关重要。最初的数据来源于互联网文本,这些数据用于预训练大型语言模型的基础架构[^2]。然而,这种基础模型的目标仅仅是预测下一个单词,并未针对特定的任务优化。 在项目初期,由于缺乏现成的支持指令交互的模型,研究人员采用了人工生成的方式获取高质量的输入-输出对。具体来说,OpenAI 雇佣了承包商来手动创建三类提示及其对应的理想响应[^3]。这种方法虽然劳动密集型较高,但它为后续更高效的自动化流程奠定了坚实的基础。 #### 微调 (Fine-tuning) 一旦积累了足够的高质量样本集之后,则可以利用它们进一步微调预先存在的大体量参数化神经网络结构。此步骤旨在调整原有通用性较强的语言建模能力向更加专注于理解和回应各种形式的具体指示方向转变。即让机器学习如何依据给定条件产生恰当的结果而不是单纯依赖统计规律推测可能延续的内容片段[^1]。 #### 基于偏好排序的学习机制引入 除了传统的监督式教学外,在本案例里还特别强调了一种新颖的技术手段—基于偏好的强化学习(RLHF, Reinforcement Learning from Human Feedback) 。它的工作原理大致如下:先由真人审阅若干候选答案版本;接着根据其主观判断给出相对优劣评价等级作为奖励信号反馈至算法内部;最后借助这些累积起来的经验指导系统逐步改进直至达到预期效果为止。 ```python def train_with_human_feedback(model, dataset): """ A simplified pseudo-code representation of training a model using human feedback. Args: model: Pretrained language model instance. dataset: Dataset containing input-output pairs along with preference rankings. Returns: Trained model that better follows instructions based on human preferences. """ for epoch in range(num_epochs): for batch in dataset: outputs = model(batch['input']) # Obtain human-generated ranking or direct labels as rewards reward_signal = get_reward_from_humans(outputs) # Update model parameters via reinforcement learning techniques update_model_parameters(model, reward_signal) return model ``` 上述伪代码展示了如何结合人类反馈来进行模型训练的一个简化框架。其中`get_reward_from_humans()`函数代表从真实用户处获得关于不同输出选项之间比较的信息源流转换操作序列集合体实例对象列表数组等等概念术语定义说明文档链接地址引用标记.
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

#苦行僧

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值