LLM面面观之RLHF平替算法DPO

1. 背景

最近本qiang~老看到一些关于大语言模型的DPO、RLHF算法,但都有些云里雾里,因此静下心来收集资料、研读论文,并执行了下开源代码,以便加深印象。

此文是本qiang~针对大语言模型的DPO算法的整理,包括原理、流程及部分源码

2. DPO vs RLHF

上图左边是RLHF算法,右边为DPO算法,两图的差异对比即可体现出DPO的改进之处。

1. RLHF算法包含奖励模型(reward model)和策略模型(policy model,也称为演员模型,actor model),基于偏好数据以及强化学习不断迭代优化策略模型的过程。

2. DPO算法不包含奖励模型和强化学习过程,直接通过偏好数据进行微调,将强化学习过程直接转换为SFT过程,因此整个训练过程简单、高效,主要的改进之处体现在于损失函数。

PS:

1. 偏好数据,可以表示为三元组(提示语prompt, 良好回答chosen, 一般回答rejected)。论文中的chosen表示为下标w(即win),rejected表示为下标l(即lose)

2. RLHF常使用PPO作为基础算法,整体流程包含了4个模型,且通常训练过程中需要针对训练的actor model进行采样,因此训练起来,稳定性、效率、效果不易控制。

1) actor model/policy model: 待训练的模型,通常是SFT训练后的模型作为初始化

2) reference model: 参考模型,也是经SFT训练后的模型进行初始化,且通常与actor model是同一个模型,且模型冻结,不参与训练,其作用是在强化学习过程中,保障actor model与reference model的分布差异不宜过大。

3) reward model: 奖励模型,用于提供每个状态或状态动作对的即时奖励信号。

4) Critic model: 作用是估计状态或状态动作对的长期价值,也称为状态值函数或动作值函数。

3. DPO算法仅包含RLHF中的两个模型,即演员模型(actor model)以及参考(reference model),且训练过程中不需要进行数据采样。

4. RLHF可以参考附件中的引文

3. DPO的损失函数

如何将RLHF的Reward model过程简化为上式,作者花了大量篇幅进行了推导,感兴趣的读者可以参考附件DPO的论文。

DPO算法的目的是最大化奖励模型(此处的奖励模型即为训练的策略),使得奖励模型对chosen和rejected数据的差值最大,进而学到人类偏好。

上式的后半部分通过对数函数运算规则,可以进行如下转化。

转化后的公式和源代码中的计算函数中的公式是一致的。

其中左半部分是训练的policy模型选择chosen优先于rejected,右半部分是冻结的reference模型选择chosen优先于rejected,二者的差值可类似于KL散度,保障actor模型的分布与reference模型的分布不会有较大的差异。

4. 微调流程

上图展示了DPO微调的大致流程,其中Trained LM即为策略模型,Frozen LM即为参考模型,二者均是先进行SFT微调得到的模型进行初始化,其中Trained LM需要进行训练,Frozen LM不参与训练。

两个模型分别针对chosen和rejected进行预测获取对应的得分,再通过DPO的损失函数进行损失计算,进而不断的迭代优化

5. 源码

源码参考代码:https://github.com/eric-mitchell/direct-preference-optimization

5.1 DPO损失函数

def preference_loss(policy_chosen_logps: torch.FloatTensor,
                    policy_rejected_logps: torch.FloatTensor,
                    reference_chosen_logps: torch.FloatTensor,
                    reference_rejected_logps: torch.FloatTensor,
                    beta: float,
                    label_smoothing: float = 0.0,
                    ipo: bool = False,
                    reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    # policy_chosen_logps: 训练模型对于chosen经过log后logits
	# policy_rejected_logps: 训练模型对于rejected经过log后logits
	# reference_chosen_logps: 训练模型对于chosen经过log后logits
	# reference_rejected_logps: 训练模型对于rejected经过log后logits
	# beta: policy和reference的差异性控制参数
	
	# actor模型选择chosen优先于rejected
    pi_logratios = policy_chosen_logps - policy_rejected_logps
	# reference模型选择chosen优先于rejected
    ref_logratios = reference_chosen_logps - reference_rejected_logps

    if reference_free:
        ref_logratios = 0
	
	# 差值可类似于KL散度,保障actor模型的分布与reference模型的分布不会有较大的差异
    logits = pi_logratios - ref_logratios  # also known as h_{\pi_\theta}^{y_w,y_l}

    if ipo:
        losses = (logits - 1/(2 * beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
    else:
        # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
		# label_smoothing为0,对应的DPO论文的算法
        losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing
	
	# chosen和rejected的奖励
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

    return losses, chosen_rewards, rejected_rewards

5.2 批次训练过程

def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True):
	"""Compute the SFT or DPO loss and other metrics for the given batch of inputs."""

	if loss_config.name in {'dpo', 'ipo'}:
		# policy模型针对chosen和rejected进行预测
		policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch)
		with torch.no_grad():
			# reference模型针对chosen和rejected进行预测
			reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(self.reference_model, batch)

		if loss_config.name == 'dpo':
			loss_kwargs = {'beta': loss_config.beta, 'reference_free': loss_config.reference_free, 'label_smoothing': loss_config.label_smoothing, 'ipo': False}
		elif loss_config.name == 'ipo':
			loss_kwargs = {'beta': loss_config.beta, 'ipo': True}
		else:
			raise ValueError(f'unknown loss {loss_config.name}')
		# 损失计算
		losses, chosen_rewards, rejected_rewards = preference_loss(
			policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, **loss_kwargs)

		reward_accuracies = (chosen_rewards > rejected_rewards).float()

	elif loss_config.name == 'sft':
		policy_chosen_logits = self.policy(batch['chosen_input_ids'], attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)
		policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False)

		losses = -policy_chosen_logps

	return losses.mean()

5.3 LM的交叉熵计算

def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False) -> torch.FloatTensor:
    # 经模型后的logits进行批量计算logps
	
    assert logits.shape[:-1] == labels.shape
	
	# 基于先前的token预测下一个token
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = (labels != -100)

    # dummy token; we'll ignore the losses on these tokens later
    labels[labels == -100] = 0
	
	# 交叉熵函数
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)

5.4 其他注意

1. hugging face设置代理

源码会从hugging face中下载英文语料和模型,由于网络限制,因此设置代理映射,将HF_ENDPOINT设置为https://hf-mirror.com,即设置: os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

2. 如果仅想要熟悉DPO整体流程,可以下载较小的生成式模型,如BLOOM 560M,GPT2等

6. 总结

一句话足矣~

本文主要针对大语言模型的DPO算法的整理,包括原理、流程及部分源码。

此外,建议大家可以针对源码进行运行,源码的欢迎大家一块交流。

7. 参考

(1) RLHF:ChatGPT技术原理解析:从RL之PPO算法、RLHF到GPT4、instructGPT-CSDN博客

(2) DPO论文: https://arxiv.org/pdf/2305.18290v2.pdf

(3) DPO代码: https://github.com/eric-mitchell/direct-preference-optimization

(4) DPO理解1:https://medium.com/@joaolages/direct-preference-optimization-dpo-622fc1f18707

(5) DPO理解2: https://zhuanlan.zhihu.com/p/669825918

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/360564.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

python使用Schedule

目录 一:使用场景: 二:参数 三:实例 "Schedule"在Python中通常指的是时间调度或任务计划。Python中有多个库可以用来处理时间调度和任务计划,其中最流行的是schedule库。 一&#x…

Linux安装nexus maven 仓库

下载安装包 详情请参考https://blog.csdn.net/weixin_42585386/article/details/122108563 上传到服务器,解压缩 tar -xzvf nexus-3.31.1-01-unix.tar.gz启动服务 cd nexus-3.31.1-01/bin/ ./nexus -start登录私服 浏览器输入:http://192.168.80.10…

Docker打包镜像把jar包打成tar包

1.虚拟机中安装docker Cd /usr cd … 以后就在主目录操作 2.然后启动docker 启动命令:systemctl start docker 然后使用touch Dockerfile 创建Dockerfile 文件 使用vim Dockerfile 在此文件编辑配置等 编辑Dockerfile FROM openjdk:8-jre WORKDIR /home ADD emss-individua…

单元测试 | Junit4“单元测试“ ( Java中可用 )

目录: 使用JUnit4进行“单元测试” 作者简介 :一只大皮卡丘,计算机专业学生,正在努力学习、努力敲代码中! 让我们一起继续努力学习! 文章用于本人学习使用 , 同时希望能帮助大家。 欢迎大家点赞👍 收藏⭐ …

内部类 --java学习笔记

内部类 是类中的五大成分之一(成员变量、方法、构造器、内部类、代码块),如果一个类定义在另一个类的内部,那么这个类就是内部类当一个类的内部包含了一个整体的事务,且这个事务没必要单独设计时,就可以把…

一种手机短信验证码登录平台的解决方案

前提 爬取数据时,请求需要带上Cookie,这是很常见的一种防爬手段。更新Cookie,常用的方法就是Selenium模拟输入用户名和密码;偶尔会遇到图片验证码,现在打码平台很多且技术也很成熟,这个已经不成问题。所谓…

Qt QScrollArea 不显示滚动条 不滚动

使用QScrollArea时,发现添加的控件超出QScrollArea 并没有显示,且没有滚动条效果 原因是 scrollArea指的是scrollArea控件本身的大小,肉眼能看到的外形尺寸。 scrollAreaWidgetContents指的是scrollArea控件内部的显示区域,里面可…

【笔记】React-Native跟Android交互--简单示例

/** * 使用命令 npx react-nativelatest init DemoRN创建项目 * * "react": "18.2.0", * "react-native": "0.73.2" * * 官网有详细教程:https://reactnative.dev/docs/native-modules-android */ 一、RN invoke androi…

大创项目推荐 题目:基于深度学习的中文汉字识别 - 深度学习 卷积神经网络 机器视觉 OCR

文章目录 0 简介1 数据集合2 网络构建3 模型训练4 模型性能评估5 文字预测6 最后 0 简介 🔥 优质竞赛项目系列,今天要分享的是 基于深度学习的中文汉字识别 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! &a…

C 变量

目录 1. C变量 2. C变量定义 2.1 变量初始化 2.2 C中的变量声明 3. C中的左值(Lvalues)和右值(Rvalues) 1. C变量 在C语言中,变量可以根据其类型分为以下几种基本类型: 整型变量:用…

2.室内设计学习 - CAD 2021 调整经典界面教程及基本设置

设置经典界面 1.在第二行的空白处右击,弹出对话框,并点击【关闭】,关闭掉。 2.菜单栏没有显示的情况下,在最上面的一排,点击向下的箭头展开下拉框,勾选 【显示菜单栏】 3.点击菜单【工具】-【工具栏】-【a…

【Docker】在Windows下使用Docker Desktop创建nginx容器并访问默认网站

欢迎来到《小5讲堂》,大家好,我是全栈小5。 这是《Docker容器》序列文章,每篇文章将以博主理解的角度展开讲解, 特别是针对知识点的概念进行叙说,大部分文章将会对这些概念进行实际例子验证,以此达到加深对…

AI人工智能可以怎么应用?——GPT4v图文识别问答功能

沃卡 AI 已支持 AI识图问答TTS语音对话文档总结对话Dall E3 对话文生图国内大模型集合AI 绘画思维导图,而且功能还在不断更新优化,丰富好用!一个系统满足您多个需求! 大家可以通过收藏网页www.woka.chat 直接进行访问&#xff0c…

10s 内得到一个干净、开箱即用的 Linux 系统

安装 使用官方脚本安装我的服务器不行 官方脚本 mkdir instantbox && cd $_ bash <(curl -sSL https://raw.githubusercontent.com/instantbox/instantbox/master/init.sh) 下面是我的完整安装过程 mkdir /opt/instantbox cd /opt/instantbox 1.脚本文件 (这个没…

Asp.net移除Server, X-Powered-By, 和X-AspNet-Version头

移除X-AspNet-Version很简单,只需要在Web.config中增加这个配置节: <httpRuntime enableVersionHeader"false" />移除Server在Global.asax文件总增加&#xff1a; //隐藏IIS版本 protected void Application_PreSendRequestHeaders() {HttpContext.Current.Res…

数据结构day7

1.思维导图 1.二叉树递归创建 2.二叉树先中后序遍历 3.二叉树计算节点 4.二叉树计算深度。 5.编程实现快速排序降序

C# Onnx yolov8 水表读数检测

目录 效果 模型信息 项目 代码 训练数据 下载 C# Onnx yolov8 水表读数检测 效果 模型信息 Model Properties ------------------------- date&#xff1a;2024-01-31T10:18:10.141465 author&#xff1a;Ultralytics task&#xff1a;detect license&#xff1a;AGPL-…

Linus Torvalds的20个事实

Linus Torvalds 是 Linux 操作系统的创造者&#xff0c;至今还在维护内核。本文是他的自传《Just for fun》的简短摘录&#xff0c;关于他个人的20个事实&#xff0c;比如他的老婆是他的学生。 Brief: Some known, some lesser known – here are 20 facts about the Linus Tor…

【Maven基础】依赖插件管理工具

Maven Maven 作用Maven 安装Maven 目录Maven config settings创建 Maven 项目运行 Java 文件Maven 坐标导入 Maven 项目依赖管理依赖配置 依赖传递排除依赖 依赖范围生命周期test跳过 Test Maven 作用 Maven 安装 Maven 目录 bin 存放可执行文件 config 存放 Maven 的配置文件 …

C++ 数论相关题目,博弈论,SG函数,集合-Nim游戏

给定 n 堆石子以及一个由 k 个不同正整数构成的数字集合 S 。 现在有两位玩家轮流操作&#xff0c;每次操作可以从任意一堆石子中拿取石子&#xff0c;每次拿取的石子数量必须包含于集合 S &#xff0c;最后无法进行操作的人视为失败。 问如果两人都采用最优策略&#xff0c;…
最新文章