Distilling the Knowledge in a Neural Network(2015.5)(d补)


文章目录

  • Abstract
  • 1 Introduction
  • 2 Distillation
    • 2.1 Matching logits is a special case of distillation
  • Results

论文链接

Abstract

提高几乎所有机器学习算法性能的一种非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。不幸的是,使用整个模型集合进行预测是很麻烦的,而且可能计算成本太高,无法部署到大量用户,特别是如果单个模型是大型神经网络。Caruana和他的合作者[1]已经证明,可以将集成中的知识压缩到一个更容易部署的单一模型中,并且我们使用不同的压缩技术进一步开发了这种方法。我们在MNIST上取得了一些令人惊讶的结果,并且我们表明,通过将模型集合中的知识提取到单个模型中,我们可以显著改善大量使用的商业系统的声学模型。我们还介绍了一种由一个或多个完整模型和许多专业模型组成的新型集成,这些模型学习区分完整模型所混淆的细粒度类。与混合专家不同,这些专家模型可以快速并行地训练

1 Introduction

许多昆虫都有一个幼虫形态,它最擅长从环境中获取能量和营养,而一个完全不同的成虫形态,它最擅长于旅行和繁殖的不同需求。在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同:对于语音和对象识别等任务,训练必须从非常大的、高度冗余的数据集中提取结构,但它不需要实时操作,它可以使用大量的计算量。然而,部署到大量用户时,对延迟和计算资源的要求要严格得多。与昆虫的类比表明,如果能更容易地从数据中提取结构,我们应该愿意训练非常繁琐的模型。繁琐的模型可以是单独训练的模型的集合,也可以是使用dropout等非常强的正则化器训练的单个非常大的模型[9]。一旦繁琐的模型得到训练,我们就可以使用另一种训练,我们称之为**“蒸馏”,将知识从繁琐的模型转移到更适合部署的小模型中**。这种策略的一个版本已经由Rich Caruana和他的合作者开创[1]。在他们的重要论文中,他们令人信服地证明了由大量模型集合获得的知识可以转移到单个小模型中

一个概念上的障碍可能阻碍了对这种非常有前途的方法进行更多的研究,那就是我们倾向于用学习到的参数值来识别训练模型中的知识,这使得我们很难看到如何在保持相同知识的情况下改变模型的形式。将知识从任何特定实例中解放出来的更抽象的知识视图是,它是从输入向量到输出向量的学习映射。对于学习区分大量类别的繁琐模型,正常的训练目标是最大化正确答案的平均对数概率,但学习的副作用是训练模型为所有不正确答案分配概率,即使这些概率非常小,其中一些概率也比其他概率大得多。错误答案的相对概率告诉我们很多关于这个繁琐的模型是如何泛化的。例如,宝马的图像可能只有很小的机会被误认为是垃圾车,但这种错误仍然比将其误认为胡萝卜的可能性高很多倍。

将繁琐模型的泛化能力转移到小模型上的一个显而易见的方法是将繁琐模型产生的类概率作为训练小模型的“软目标”。对于这个迁移阶段,我们可以使用相同的训练集或单独的“迁移”集。当繁琐的模型是由许多简单模型组成的大集合时,我们可以使用单个预测分布的算术或几何平均值作为软目标。当软目标具有高熵时,每个训练案例提供的信息量比困难目标大得多,训练案例之间的梯度方差也小得多,因此小模型通常可以在比原始繁琐模型少得多的数据上进行训练,并使用更高的学习率。

对于像MNIST这样的任务,繁琐的模型几乎总是产生非常高置信度的正确答案,关于学习函数的大部分信息存在于软目标中非常小的概率比率中。例如,一个2可能有10 −6的概率是3,10 −9的概率是7,而另一个版本可能是相反的。这是有价值的信息,它定义了数据的丰富相似性结构(即,它表示哪些2看起来像3,哪些看起来像7),但它在传递阶段对交叉熵成本函数的影响很小,因为概率非常接近于零。Caruana和他的合作者通过使用logits(最终softmax的输入)而不是softmax产生的概率作为学习小模型的目标来规避这个问题,他们最小化了繁琐模型产生的logits和小模型产生的logits之间的平方差。我们更通用的解决方案,称为“蒸馏”,是提高最终软最大值的温度,直到繁琐的模型产生合适的软目标集。然后,我们在训练小模型时使用相同的高温来匹配这些软目标。稍后我们将说明,匹配繁琐模型的对数实际上是蒸馏的一种特殊情况

用于训练小模型的转移集可以完全由未标记的数据组成[1],或者我们可以使用原始训练集。我们发现,使用原始的训练集效果很好,特别是如果我们在目标函数中添加一个小项,可以鼓励小模型预测真实目标,并匹配繁琐模型提供的软目标。通常,小模型不能完全匹配软目标,在正确答案的方向上犯错误是有帮助的。

2 Distillation

神经网络通常通过使用“softmax”输出层来产生类概率,通过将 zi 与其他 logits 进行比较,将每个类别的 logit z i 计算为概率 q i

T是温度,通常设为1。使用更高的T值会产生更柔和的类概率分布

在最简单的蒸馏形式中,通过在转移集上训练知识,并对转移集中的每个情况使用软目标分布,将知识转移到蒸馏模型中,该转移集中使用具有高温软最大值的繁琐模型产生的软目标分布。在训练蒸馏模型时使用相同的高温,但在训练完成后,它使用的温度为1。

当所有或部分转移集的正确标签已知时,可以通过训练蒸馏模型来产生正确的标签来显著改进该方法。一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是与软目标的交叉熵,该交叉熵是在蒸馏模型的软最大值中使用与从繁琐模型生成软目标相同的高温来计算的。第二个目标函数是带有正确标签的交叉熵。这是在蒸馏模型的softmax中使用完全相同的logits计算的,但温度为1。我们发现,在第二个目标函数上使用相当低的权重通常可以获得最佳结果。由于软目标产生的梯度大小为1/ t2,因此在使用硬目标和软目标时,将它们乘以t2是很重要的。这确保了在使用元参数进行实验时,如果用于蒸馏的温度发生变化,则硬目标和软目标的相对贡献大致保持不变。

2.1 Matching logits is a special case of distillation

传递集中的每种情况相对于蒸馏模型的每个logit z i贡献了一个交叉熵梯度dC/dz i。如果繁琐模型的logits v i产生软目标概率p i,并且迁移训练在温度T下进行,则该梯度为:
如果温度比对数的大小高,我们可以近似:
如果我们现在假设对数在每个转移情况下都是Σj zj = Σj vj = 0,式3化简为:

所以在高温极限下,蒸馏等于最小化1/2(zi - vi) 2,只要对数分别为零。在较低的温度下,蒸馏很少注意匹配比平均值负得多的对数。这是潜在的优势,因为这些逻辑几乎完全不受用于训练繁琐模型的成本函数的约束,因此它们可能非常嘈杂。另一方面,非常负的对数可以传达关于繁琐模型所获得的知识的有用信息。这些影响中哪一个占主导地位是一个经验问题。我们表明,当蒸馏模型太小而无法捕获繁琐模型中的所有知识时,中间温度效果最好,这强烈表明忽略大的负对数可能是有帮助的。

Results

我们训练了10个独立的模型来预测P(h t |s t;θ),使用完全相同的架构和训练过程作为基线。用不同的初始参数值随机初始化模型,我们发现这在训练模型中产生了足够的多样性,使得集合的平均预测明显优于单个模型。我们已经探索了通过改变每个模型看到的数据集来增加模型的多样性,但是我们发现这不会显著改变我们的结果,所以我们选择了更简单的方法。对于蒸馏,我们尝试了[1,2,5,10]的温度,并对硬目标的交叉熵使用了0.5的相对权重,其中粗体表示表1中使用的最佳值
表1显示,实际上,我们的蒸馏方法能够从训练集中提取更多有用的信息,而不是简单地使用硬标签来训练单个模型。使用10个模型的集合所获得的帧分类精度提高的80%以上被转移到蒸馏模型上,这与我们在MNIST上的初步实验中观察到的改进相似。由于目标函数不匹配,集成对WER的最终目标(在23k个单词的测试集上)给出了较小的改进,但同样,集成实现的WER改进被转移到蒸馏模型上。

我们最近意识到通过匹配已经训练好的大型模型的类概率来学习小型声学模型的相关工作[8]。然而,他们使用大型未标记数据集在1的温度下进行蒸馏,他们的最佳蒸馏模型仅将小模型的错误率降低了28%,这是大模型和小模型在使用硬标签训练时错误率之间的差距

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/223368.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Node.js安装和下载(保姆级教程,别再再说你不会了)

1.浏览器搜索node.js 2.打开官网(选择Other Download) ​ 3.根据你的计算机版本选择 4.找到你下载的程序(双击打开) 5.双击后的效果如下: 6.继续下一步 7.选择安装路径然后下一步 8.然后继续下一步 9. 直接下一步&am…

P6 Linux 系统中的文件类型

目录 前言 ​编辑 01 linux系统查看文件类型 02 普通文件 - 03 目录文件 d 04 字符设备文件 c 和块设备文件 b 05 符号链接文件 l 06 管道文件 p 07 套接字文件 s 总结 前言 🎬 个人…

数据增强改进,实现检测目标copypaste,增加目标数据量,提升精度

🗝️YOLOv8实战宝典--星级指南:从入门到精通,您不可错过的技巧   -- 聚焦于YOLO的 最新版本, 对颈部网络改进、添加局部注意力、增加检测头部,实测涨点 💡 深入浅出YOLOv8:我的专业笔记与技术总结   -- YOLOv8轻松上手, 适用技术小白,文章代码齐全,仅需 …

postgresql自带指令命令系列二

简介 在安装postgresql数据库的时候会需要设置一个关于postgresql数据库的PATH变量 export PATH/home/postgres/pg/bin:$PATH,该变量会指向postgresql安装路径下的bin目录。这个安装目录和我们在进行编译的时候./configure --prefix [指定安装目录] 中的prefix参…

consistency model

Consistency is All You Need - wrong.wang什么都不用做生成却快了十倍其实也并非完全不可能https://wrong.wang/blog/20231111-consistency-is-all-you-need/[学科基础] 从布朗运动到扩散模型采样算法 - 知乎引言 扩散模型是近年来新出现的一种生成模型,很多工作将…

现货白银简单介绍

在贵金属投资领域,现货白银是当前国际上最为流行、交投最为活跃的白银投资方式,其交易市场遍布全球,包括伦敦、苏黎世、纽约、芝加哥及香港等主要市场,是一种以杠杆交易和做市商的形式进行的现货交易。 现货白银可以说是当下交易模…

Python (二) 读写excel文件

程序员的公众号:源1024,获取更多资料,无加密无套路! 最近整理了一波电子书籍资料,包含《Effective Java中文版 第2版》《深入JAVA虚拟机》,《重构改善既有代码设计》,《MySQL高性能-第3版》&…

1996-2021年世界各国WGI全球治理指标:政治稳定、制度控制、国家治理、控制腐败、自由指数数据

1996-2021年世界各国WGI全球治理指标:政治稳定、制度控制、国家治理、控制腐败、自由指数数据 1、时间:1996-2021年 2、指标:Voiceand Accountability、Political Stability No Violence、Government Effectiveness、Regulatory Quality、R…

tomcat控制台中文信息显示乱码

问题现象 我的tomcat版本是10.1版本。 在cmd下启动tomcat,会新打开控制台输出窗口: 控制台窗口输出的中文信息是乱码: 问题原因 产生这个问题的原因是:控制台窗口的编码和输出到控制台窗口的日志信息编码不一致。 查看tomc…

《opencv实用探索·十一》opencv之Prewitt算子边缘检测,Roberts算子边缘检测和Sobel算子边缘检测

1、前言 边缘检测: 图像边缘检测是指在图像中寻找灰度、颜色、纹理等变化比较剧烈的区域,它们可能代表着物体之间的边界或物体内部的特征。边缘检测是图像处理中的一项基本操作,可以用于人脸识别、物体识别、图像分割等多个领域。 边缘检测…

如何在服务器上运行python文件

目录 前置准备 详细步骤 一,在服务器安装Anaconda 下载安装包 上传文件到服务器 安装环境 二,创建虚拟环境 创建环境 三,测试执行python文件 执行python文件 查看进程状态 总结 前置准备 如何在个人服务器上运行python文件&#x…

elk+kafka+filebeat

elk1 cd /opt 把filebeat投进去 tar -xf filebeat-6.7.2-linux-x86_64.tar.gz mv filebeat-6.7.2-linux-x86_64 filebeat cd filebeat/ yum -y install nginx systemctl restart nginx vim /usr/share/nginx/html/index.html this is nginx cp filebeat.yml filebeat.yml.…

Matlab之统计数据分布并绘制直方图函数histogram

一、功能 直方图是一种将数据分组到条柱中的条形图。该函数可以统计数据在划分区间内的数量分布,同时以直方图的形式展示统计结果。 二、语法 1、histogram(X) 创建直方图X的图。该函数使用 一种自动分箱算法,返回具有统一宽度…

数组解构、对象解构与forEach方法遍历数组

解构赋值 1. 数组解构 1.1 基本语法 1.2 变量多 单元值少的情况 1.3 变量少 单元值多的情况 1.4 防止undefined传值情况 使用默认值 1.5 按需导入 忽略某些值 1.6 支持多维数组的解构 2. 对象解构 2.1 基本语法 2.2 给新的变量名赋值 2.3 数组对象解构 2.4 多级对象解构 cons…

网络安全威胁——跨站脚本攻击

跨站脚本攻击 1. 定义2. 跨站脚本攻击如何工作3. 跨站脚本攻击类型4. 如何防止跨站脚本攻击 1. 定义 跨站脚本攻击(Cross-site Scripting,通常称为XSS),是一种典型的Web程序漏洞利用攻击,在线论坛、博客、留言板等共享…

vscode插件离线下载

离线下载插件地址:https://marketplace.visualstudio.com/VSCode

win11 关闭快速启动,解决重启后部分应用没有关闭的问题

鼠标右击win11开始菜单选择windows终端(管理员)打开输入:powercfg /h off按下回车即可

AOC computer monitor

【窗口增亮】关闭就没掉了

近期Google paly再次卡审?需要开发者提供更多关于应用的信息以通过谷歌审查?

谷歌政策更新得越来越频繁,也越来越严格,加大了对应用的审核力度。 最近,不少开发者表示,谷歌卡审又出新花样了。与之前收到暂停审核电话验证邮件(需要在48-72小时内,拨打你开发者账号的号码,应…

8、Broker进一步了解

1、Broker消息分发服务以及构建ConsumeQueue和IndexFile与消息清除 前面分析如何进行刷盘,本章分析Broker的消息分发以及构建ConsumerQueue和IndexFile,两者构建是为了能够提高效率,减少消息查找时间以及减少网络带宽与存储空间。 ConsumeQ…
最新文章