生成式大模型的RLHF技术(一):基础

一、概述

大语言模型(LLMs)在预训练的过程中通常会捕捉数据的特征,而这些训练数据通常既包含高质量的也包含低质量的,因此模型有时会产生不被期望的行为,如编造事实,生成有偏见或有毒的文本,甚至对人类有害的内容。因此,将LLMs与人类价值观(如helpful, honest, 和harmless, 即3H)对齐是非常重要的,目前采用的主流的技术即是基于人类反馈的强化学习技术(RLHF)。

通常来说,RLHF包括三个步骤: ①supervised fine-tuning (SFT):对LLMs进行微调,LLMs通过模仿人类标注的对话示例来学习通用的的类似人类的对话。 ②reward model (RM) training:对于模型对同一个prompt的多个回复,利用人类标注来进行排序以获取人类偏好,然后单独使用另一个语言模型作为reward model,在这个reward model上使用标注的数据进行训练(类似排序任务)。 ③proximal policy optimization (PPO):以训练得到的reward model作为reward function来继续训练优化LLMs,促进其与人类偏好的对齐。

cc7f6bee893fa175ae4ffd1f33e01a30.png
RLHF

RLHF的整个过程如上图所示,也可参考OpenAI的InstructGPT的做法:InstructGPT:语言模型的人类反馈指令对齐。InstructGPT使用的数据规模如下表所示:

f837e6b24e62c95af62f1625dbd73261.jpeg
数据规模

本文会介绍基于PPO(proximal policy optimization)的RLHF的技术细节,包括Reward Modeling、Policy Gradient Methods、Advantage Actor-Critic (A2C)等。本文主要介绍A2C框架,其中包括四个模型的调度: ①Policy Model/Actor Model:由SFT之后的模型初始化而来。作为策略(policy)模型,用于接收上文,做出动作,预测下一个字符。学习完毕之后,我们最终使用的就是这个模型。 ②Reference Model:和Actor Model同样初始化自SFT Model,训练过程中冻结参数,用于和Actor Model做对比,保证模型不要偏离原始SFT Model太多。 ③Reward Model:作为环境(env),训练过程中冻结参数,针对每一个状态,给出奖励分数。 ④Critic Model:由Reward Model初始化而来,用于近似价值函数,输入为当前的状态,估计当前状态的价值。

二、Reward Modeling

Reward model可以使用移除了最后一个unembedding层的预训练语言模型来作为基础架构,通常就是将最后一个token最终的embedding输入给一个线性层,然后得到一个标量值,即是reward的值。在InstructGPT中,作者尝试用了1.3B、6B和175B的GPT-3来做实验,最终综合考虑只用6B的模型来训练reward model。在训练reward model时,对于同一个输入prompt,有一个更偏好的输出和一个相对不偏好的输出。每一对偏好和不偏好的回复的损失为:

❝ ❞

这里的是sigmoid函数,代表reward model,其参数为,是一个reward model为和预测的标量得分。另外,也可以额外加上一个模仿学习(imitation learning)的损失,也就是一个语言模型的预训练损失,来让模型模仿句子对中更偏好的那一个:

❝ ❞

这里的和都是超参数,是训练集的经验分布,与除了顶部线性层不同以外是同一个模型(线性层的维度为词典的大小),是给定prompt和偏好回复后的似然。

在PPO阶段,使用训练得到的reward model来作为reward function训练policy模型时,还可以为reward function添加一个基于当前policy模型和SFT模型之间KL散度的惩罚项,此时reward function为:

❝ ❞

这里的是一个超参数。这个KL散度的惩罚项,通常来说有两个作用: ①作为一个entropy bonus,促进模型在policy空间中的探索,防止policy model过早地收敛到一个单一的模式。 ②可以确保强化学习policy model的输出不会与reward model在训练阶段遇到的样本严重偏离。

三、Reinforcement Learning

将强化学习应用于对话生成是一种艰难的挑战,这是因为其状态-动作空间(state-action space)是非常巨大和广阔的。在这样的场景中,我们将人类的互动作为环境(environment)。在每个时间步,agent(也就是LLMs)将从环境(即对话历史)中接收一个状态,其中包括所有的对话历史文本(也就是prompt和已经生成的回复)。接着,基于policy,agent的动作为生成下一个token。环境相应的也会反馈一个reward,这个reward来自于一个reward function,是从人类偏好数据中训练得到的,也就是前文的reward model。此后,agent将转换到下一个状态。整个过程将得到一个轨迹(trajectory)。对于LLMs的一个输入和输出来说,,,可采取的动作的所有选择即是模型词典中的所有token。强化学习的目的即是最大化一个轨迹的累积reward(也就是回报,return)。一种有限期无折扣回报( finite-horizon undiscounted return)为,即有限步数的累积reward的加和。另一种无限期折扣回报( infinite-horizon discounted return)为,这种计算方式考虑了整个轨迹上所获得的的所有回报,其中为折扣率。

  1. 「Policy Gradient Methods」

在介绍最常用的A2C方法之前,先介绍下更基础一些的policy gradient方法。Policy gradient方法是一种强化学习技术,其直接优化agent的policy,也就是从state到action的映射。Policy gradient方法的核心思想是使用梯度上升方法直接优化policy。从本质上讲,这类方法调整policy model的参数,使其朝着最大限度地提高预期return的方向优化。Policy通常由参数化,我们将其表示为,即在状态下采取动作的概率。Policy gradient的参数更新方式为:

❝ ❞

这里的是学习率,表示当采用policy时的期望return,其梯度被称为policy gradient。一个policy gradient的通用形式为:

❝ ❞

这里的可以是, ,(是一个baseline)中的任意一个。这些不同的形式会得到policy gradient相同的期望值,但会得到不同的方差。

Return通过蒙特卡洛采样的方式来计算。如果得到的return是有利的,则会增大生成这些动作的概率。这种方法的优势在于其是无偏的,因为其只依赖于实际的return,而非需要估计它。然而这种方法具有非常高的方差,因为同一个prompt产生的不同的轨迹会计算得到不同的return值,这是由于环境(一个episode中的随机事件)和policy本身的随机性决定的。

为了降低这里的高方差,一种常用的策略是使用优势函数(advantage function)估计来替代原始return,即。优势函数表示在状态时采取动作相比于同一状态下所有动作的平均质量来说是否会更好。这里需要引入两个概念: ①动作的价值:在时刻,给定当前状态,采取动作,可以获得的return的期望,具体的,。 ②状态的价值:在时刻,给定当前状态,可以获得的return的期望,具体的,,其中是所有动作的集合。这里就是在上文状态下,求和所有下一个token出现概率与token对应价值的乘积。

有了这两个概念即可得到优势函数。优势函数的意义在于只有当前动作的return比平均水平更高时才能获得正的优势值,从而被增强,如果低于平均水平就会被抑制。

使用优势函数的policy gradient方法是强化学习邻域的重要支柱。优势函数的估计方法是多种多样的,其中有一种广泛采用的估计方法为广义优势估计(Generalized Advantage Estimation, GAE),下一节将着重介绍。

  1. 「Generalized Advantage Estimation」

优势函数定义为函数与价值函数的差值。考虑一个具体的动作,而价值函数是所有可能的动作的平均。而在实践中,我们使实际episode的return来估计函数,也就是蒙特卡洛采样的return,这会引入非常高的方差,因为未来的reward包含大量的噪声。一种减少这种噪声的方法是使用价值函数来估计未来(时间步之后)的return,也就是Temporal Difference (TD)的方法,这种方法通常有较大的偏差。本节要介绍的GAE算法是一种介于使用one-step TD return(高偏差)和完全蒙特卡洛return(高方差)之间的算法,可以平衡偏差和方差。接下来的内容即是对GAE算法的推导。

首先,我们用来表示TD- return,这是一种实际的reward和估计的return的集合:

❝ ❞

这里的是折扣率。折扣率的意义可以这样理解:每一步虽然都有一个即时的reward,但是每一步对后面的可能状态都是有影响的,即后面的动作获取的reward都能累计到前面的动作的贡献。不过直接加上去可能不好,毕竟不是前面的动作直接获取的reward,但是可以打个折扣再加上去,即乘个小于1的。

使用TD- return的优势称为-step优势,定义为:

❝ ❞

这里的,叫做TD error。在-step优势中,如果比较小,偏差会很高,因为优势估计只基于很少的步数,所以非常依赖价值函数的准确性。而在比较大时,方差会非常高,因为优势估计会把很多噪声reward加进来。

为了平衡偏差和方差,GAE定义优势函数为-step优势的指数移动平均,权重为:

❝ ❞

GAE可以平滑地平衡高偏差()和高方差():

❝ ❞

通过GAE,我们可以准确地得到优势函数的估计值。这一估计值在进行policy gradient估计时将扮演重要的角色:

❝ ❞

这里的是有限批量的样本。后面我们将用来表示。

  1. 「Proximal Policy Optimization」

PPO和TRPO是RL中的两种关键技术,旨在有效地训练policy而不损害其稳定性。这些方法的遵循“小而稳定的步骤”的思想,即轻微地进行policy的优化,而不是强制进行可能破坏整个学习过程的激进更新。

在传统的强化学习中,policy gradient的原则要求新旧policy在参数空间中保持接近。然而,参数空间中的这种接近并不一定等同于相似的性能,参数的轻微变化可能会极大地影响策略的有效性。这一部分原因要归结于神经网络是过参数化的,类似微调语言模型的LoRA方法,无需微调所有的模型参数即可将模型适配到特定的下游任务上。此外,如果不加限制地大步更新,就可能导致policy表现崩溃,这种情况通常被描述为“掉下悬崖(falling off the cliff)”。这种固有的风险是原始的policy gradient中样本效率的限制因素。

  • TRPO

TRPO的方法没有受到参数接近性(parameter closeness)的限制,而是对policy更新引入了一种不同的约束。其通过确保新旧policy model的KL散度在一个可接受的限制范围内来正则化policy的更新:

❝ ❞

这里的是更新之前旧的policy参数。

  • PPO-Penalty

PPO的方法有两个主要的变种:PPO-Penalty和PPO-Clip。TRPO是一种KL散度的硬性限制,而PPO-Penalty通过采用基于惩罚的方法而不是约束来以无约束优化问题的方式优化policy:

❝ ❞

这里的是惩罚因子。

  • PPO-Clip

PPO-Clip试图保持新policy接近旧policy,但不像TRPO那样对KL散度施加约束,而是在其目标函数中根据新旧policy的预测概率的比值进行截断。目标函数可以表示为:

❝ ❞

这里的是新policy与旧policy的概率的比例,是一个超参数,用于控制新policy能够偏离旧policy多远。函数将概率的比值限制在之间。函数充当正则化器,限制policy从一个迭代到下一个迭代的急剧变化的程度。防止过大的policy更新确保了学习过程的鲁棒性,同时保持了比普通policy gradient方法更高效的样本学习。

  • 价值函数估计

在PPO中,文章开篇提到的critic model通常用来作为价值函数,来评估每个状态的期望return。其学习目标为最小化其预测值与真实return之间的差异,目标函数通常采用MSE损失,即:

❝ ❞

这里的代表critic model(参数为)在状态的预测值,代表实际状态的return值,通常估计为。

  • 混合预训练梯度

为了缓解PPO训练后的模型在通用语言能力上的退化和灾难性遗忘问题,可以在强化学习训练过程中加入预训练数据,这种方法通常称为PPO-ptx,其损失函数为:

❝ ❞

这里的是一个超参数,是预训练数据的分布。

四、总结

综合前面的描述,整个PPO训练过程的算法可以表示为:

45668f15a62ca77f5b012aa980c6438c.jpeg
算法

另外也可参考整个流程的框架图:

283f5dc04956456e75b34fe191d0c949.jpeg
框架

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/166842.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

商业园区的万能管理法,还怪高级的咧!

随着社会的不断发展和科技的飞速进步,视频监控技术已经成为维护安全、提高效率以及实现智能化管理的关键工具。 在这个信息时代,人们对于安全和管理的需求不断提升,而视频监控系统作为一种强大而灵活的解决方案,正日益受到各行各业…

QQ同步通讯录,详细操作方法来了!

腾讯QQ是一款功能丰富的即时通信软件,能够让用户随时随地与好友保持联系,不受时间和地域限制,受到了广大用户的喜爱和信赖。 为了能够快速添加QQ好友,我们可以通过开启通讯录来实现。那么,qq同步通讯录如何操作呢&…

数字IC前端学习笔记:异步复位,同步释放

相关阅读 数字IC前端https://blog.csdn.net/weixin_45791458/category_12173698.html?spm1001.2014.3001.5482 异步复位 异步复位是一种常见的复位方式,可以使电路进入一个可知的状态。但是不正确地使用异步复位会导致出现意想不到的错误,复位释放便是…

新生儿奶藓:原因、科普和注意事项

引言: 新生儿奶藓是一种常见的婴儿皮肤问题,通常在生后的头几个月内出现。尽管奶藓对婴儿的健康没有太大影响,但了解其原因、科普相关信息以及采取适当的注意事项是帮助父母更好地照顾婴儿皮肤的关键。本文将深入探讨新生儿奶藓的原因、相关…

【pytorch深度学习 应用篇02】训练中loss图的解读,训练中的问题与经验汇总

文章目录 loss图解析train loss ↘ \searrow ↘ ↗ \nearrow ↗ 先降后升 loss图解析 train loss ↘ \searrow ↘ 不断下降,test loss ↗ \nearrow ↗ 不断上升:原因很多,我是把workers1,batchSize8192train loss ↘ \searro…

【Linux】vscode远程连接ubuntu,含失败解决方案

删除vscode远程连接 打开‪C:\Users\GIGA\.ssh\config文件,GIGA是windows下自己的用户名。 删除‪C:\Users\GIGA\.ssh\config文件里的所有内容,点击保存;然后刷新。 可以看出SSH 远程连接已经被删除了。 vscode远程连接ubuntu 在弹出的…

nginx静态网站部署

Nginx是一个HTTP的web服务器,可以将服务器上的静态文件(如HTML、图片等)通过HTTP协议返回给浏览器客户端 案例:将ace-master这个静态网站部署到Nginx服务器上 通过Xftp将ace-master到linux服务器/opt/static目录下,为…

Spring高级bean的实例化方法

bean的实例化方法 构造方法 实例化bean第一种:使用默认无参构造函数(常用) 第二种创建bean实例:静态工厂实例化(了解) 第三种:实例工厂(了解)与FactoryBean(实用)

这些好用的录屏专家,你都知道吗?(干货)

在数字时代,录制屏幕已经成为沟通、教育和创作的重要工具。无论您是一位教育者、企业家还是内容创作者,能够熟练地使用录屏软件将帮助您传达信息和创作内容。在本文中,我们将介绍三款优秀的录屏专家,以帮助您找到最适合自己需求的…

如何通过算法模型进行数据预测

当今数据时代背景下更加重视数据的价值,企业信息化建设会越来越完善,越来越体系化,以数据说话,通过数据为企业提升渠道转化率、改善企业产品、实现精准运营,为企业打造自助模式的数据分析成果,以数据驱动决…

springboot学习笔记

目录 概述 常见的SSM搭建项目弊端 什么是springboot 特点 1.简化部署 2.简化配置,注解代替xml 3.简化依赖配置 4.应用监控 springboot与springmvc,springcloud关系 创建springboot项目 spring4提供的注解 Spring的发展 Java配置 1.核心注解…

构造函数,原型对象,实例对象

1.构造函数、原型对象、实例对象三者分别是什么? 构造函数:用来创建对象的函数,创建实例对象的模板 。构造函数的函数名尽量首字母大写(为了区分普通函数和构造函数)原型对象:每一个函数在创建的时候,系统都会给分配一…

wpf devexpress 绑定数据编辑器

定义视图模型 打开前一个项目 打开RegistrationViewModel.cs文件添加如下属性到RegistrationViewModel类 [POCOViewModel] public class RegistrationViewModel {public static RegistrationViewModel Create() {return ViewModelSource.Create(() > new RegistrationVie…

振弦式渗压计的安装方式及注意要点

振弦式渗压计的安装方式及注意要点 振弦式渗压计是一种高精度、高效率的地下水位测量仪器。它可以测量地下水位的高度,计算地下水的压力,从而推算出地下水的流量。对于地下水资源管理和保护、治理工程等方面具有非常重要的意义。在安装振弦式渗压计时&a…

什么是媒体见证?媒体宣传有哪些好处?

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 一,什么是媒体见证? 媒体见证是指企业举办活动,发布会,邀请媒体现场采访的一种宣传方式,媒体到场后,对其进行记录…

金蝶云星空对接打通旺店通·旗舰奇门采购退料单查询接口与创建货品档案接口

金蝶云星空对接打通旺店通旗舰奇门采购退料单查询接口与创建货品档案接口 来源系统:金蝶云星空 金蝶K/3Cloud在总结百万家客户管理最佳实践的基础上,提供了标准的管理模式;通过标准的业务架构:多会计准则、多币别、多地点、多组织、多税制应用…

ModuleNotFoundError: No module named ‘pycocotools‘

cuda 12.1 pytorch 2.0.1 python 3.11 运行代码,报该错误,尝试了以下方法解决: 方法一 # step 1: 安装cython pip install Cython# step 2: 安装pycocotools pip install githttps://github.com/philferriere/cocoapi.git#eggpycocotools…

MacOs 删除第三方软件

AppStore下载的软件 如果删除AppStore下载的软件,直接长按软件,点击删除或拖到废纸篓就可以完成软件的删除 第三方软件 但是第三方下载的软件,无法拖进废纸篓,长按软件也没有右上角的小叉 可以通过以下方法实现对软件的卸载 …

EMQX vs Mosquitto | MQTT Broker 对比

物联网开发者需要为自己的物联网项目选择合适的 MQTT 消息产品或服务,从而构建可靠高效的基础数据层,保障上层物联网业务。目前市面上有很多开源的 MQTT 产品,在性能功能等方面各有优点。本文将选取目前最为流行的两个开源 MQTT Broker&#…

详细介绍:国产操作系统银行麒麟V10的下载和安装

📚📚 🏅我是默,一个在CSDN分享笔记的博主。📚📚 ​​ 🌟在这里,我要推荐给大家我的专栏《Linux》。🎯🎯 🚀无论你是编程小白,还是有一…