深度学习论文阅读之【Distilling the Knowledge in a Neural Network】提炼神经网络中的知识

论文:link
代码:link

摘要

  提高几乎所有机器学习算法性能的一个非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。不幸的是,使用整个模型集合进行预测非常麻烦,并且计算成本可能太高,无法部署到大量用户,尤其是在单个模型是大型神经网络的情况下。 Caruana 和他的合作者 [1] 已经证明,可以将集成中的知识压缩到单个模型中,该模型更容易部署,并且我们使用不同的压缩技术进一步开发了这种方法。我们在 MNIST 上取得了一些令人惊讶的结果,并且表明我们可以通过将模型集合中的知识提炼为单个模型来显着改进频繁使用的商业系统的声学模型。我们还引入了一种由一个或多个完整模型和许多专业模型组成的新型集成,这些模型学习区分完整模型混淆的细粒度类别。与专家的混合不同,这些专业模型可以快速并行地进行训练。

1.Introduction

  许多昆虫都有幼虫形态和完全不同的成虫形态,幼虫形态可以经过优化,可以从环境中获取能量和营养,成虫形态可以满足不同的旅行和繁殖要求,在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同,对于语音和对象识别等任务,训练必须从非常大、高度冗余的数据集中提取结构,但它并不需要这样做,需要实时操作,并且会使用大量的计算量。
  然而,部署到大量用户对延迟和计算资源有更严格的要求。与昆虫的类比表明,如果可以更轻松地从数据中提取结构,我们应该愿意训练非常繁琐的模型。繁琐的模型可能是单独训练的模型的集合,也可能是使用非常强大的正则化器(例如 dropout)训练的单个非常大的模型[9]。一旦繁琐的模型经过训练,我们就可以使用不同类型的训练,我们称之为“蒸馏”,将知识从繁琐的模型转移到更适合部署的小模型。 Rich Caruana 及其合作者已经率先提出了该策略的一个版本 [1]。在他们的重要论文中,他们令人信服地证明,通过大型模型集合获得的知识可以转移到单个小型模型中。通常认为模型学习到的参数代表了知识,无法直接迁移,但教师网络预测结果中各类别概率的相对大小也隐式包含知识。

2.Distillation

  神经网络通常通过使用“softmax”输出层来生成类别概率,该输出层通过将 z i z_i zi与其他logits进行比较。将为每一个类别计算的logits的 z i z_i zi转换为概率 p i p_i pi.
q i = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) {q_i} = \frac{{\exp \left( {{z_i}/T} \right)}}{{{\sum _j}\exp \left( {{z_j}/T} \right)}} qi=jexp(zj/T)exp(zi/T)
其中 T 是温度,通常设置为 1。使用较高的 T 值会在类别上产生较软的概率分布。在最简单的蒸馏形式中,知识被转移到蒸馏模型中,方法是在转移集上进行训练,并使用转移集中每种情况的软目标分布,该软目标分布是通过使用其 softmax 中温度较高的繁琐模型生成的。训练蒸馏模型时使用相同的高温,但训练后它使用温度 1。
一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是与软目标的交叉熵,并且该交叉熵是使用蒸馏模型的 softmax 中与用于从繁琐模型生成软目标相同的高温来计算的。第二个目标函数是具有正确标签的交叉熵。这是在蒸馏模型的 softmax 中使用完全相同的对数计算的,但温度为 1。我们发现,通常通过在第二个目标函数上使用相当低的权重来获得最佳结果。由于软目标产生的梯度大小为 1/T 2 ,因此在使用硬目标和软目标时将其乘以 T 2 非常重要。这确保了如果在元参数实验时用于蒸馏的温度发生变化,硬目标和软目标的相对贡献保持大致不变。

2.1 匹配logits是蒸馏的一个特例

  传输集中每个案例都贡献一个交叉熵梯度 d C / d z i dC/d{z_i} dC/dzi,相当于蒸馏模型的每个logit z i z_i zi,并且繁琐的模型具有产生软目标概率 p i p_i pi的logits v i v_i vi,并且转移训练是在温度T下完成的,则该梯度由下式给出:
∂ C ∂ z i = 1 T ( q i − p i ) = 1 T ( e z i / T ∑ j e z j / T − e v i / T ∑ j e v j / T ) \frac{{\partial C}}{{\partial {z_i}}} = \frac{1}{T}\left( {{q_i} - {p_i}} \right) = \frac{1}{T}\left( {\frac{{{e^{{z_i}/T}}}}{{{\sum _j}{e^{{z_j}/T}}}} - \frac{{{e^{{v_i}/T}}}}{{{\sum _j}{e^{{v_j}/T}}}}} \right) ziC=T1(qipi)=T1(jezj/Tezi/Tjevj/Tevi/T)
如果温度与logits的大小相比比较高,我们可以近似:
∂ C ∂ z i ≈ 1 T ( 1 + e z i / T N + ∑ j z j / T − 1 + e v i / T N + ∑ j v j / T ) \frac{{\partial C}}{{\partial {z_i}}} \approx \frac{1}{T}\left( {\frac{{1 + {e^{{z_i}/T}}}}{{N + {\sum _j}{z_j}/T}} - \frac{{1 + {e^{{v_i}/T}}}}{{N + {\sum _j}{v_j}/T}}} \right) ziCT1(N+jzj/T1+ezi/TN+jvj/T1+evi/T)
如果我们现在假设每个转移情况的logits都是零均值的,则 ∑ j z j = ∑ j v j = 0 {\sum _j}{z_j} = \sum {}_j{v_j} = 0 jzj=jvj=0,原式可简化为:
∂ C ∂ z i ≈ 1 N T 2 ( z i − v i ) \frac{{\partial C}}{{\partial {z_i}}} \approx \frac{1}{{N{T^2}}}\left( {{z_i} - {v_i}} \right) ziCNT21(zivi)
  因此,在高温极限下,蒸馏相当于最小化 1 / 2 ( z i − v i ) 2 1/2(z_i-v_i)^2 1/2(zivi)2 ,前提是每个分动箱的 logits 分别为零均值。在较低的温度下,蒸馏很少关注比平均值负得多的匹配 logits。这是潜在的优势,因为这些逻辑几乎完全不受用于训练繁琐模型的成本函数的约束,因此它们可能非常嘈杂。另一方面,非常负的逻辑可能会传达有关通过繁琐模型获得的知识的有用信息。这些影响中哪一个占主导地位是一个经验问题。我们表明,当蒸馏模型太小而无法捕获繁琐模型中的所有知识时,中间温度效果最好,这强烈表明忽略大的负对数可能会有所帮助。

3.MNIST初步实验

  为了了解蒸馏的效果如何,我们在所有 60,000 个训练案例上训练了一个大型神经网络,该神经网络具有两个隐藏层,每个隐藏层包含 1200 个校正线性隐藏单元。该网络使用 dropout 和权重约束进行了强烈正则化,如 [5] 中所述。 Dropout 可以被视为训练共享权重的指数级大模型集合的一种方式。此外,输入图像在任何方向上抖动最多两个像素。该网络出现了 67 个测试错误,而具有两个隐藏层(由 800 个校正线性隐藏单元且无正则化)的较小网络出现了 146 个错误。但是,如果仅通过添加在 20 ℃ 的温度下匹配大网络产生的软目标的附加任务来对较小的网络进行正则化,则它会出现 74 个测试错误。这表明软目标可以将大量知识转移到蒸馏模型中,包括如何概括从翻译的训练数据中学到的知识,即使转移集不包含任何翻译。当蒸馏网络的两个隐藏层中每个都有 300 个或更多单位时,所有高于 8 的温度都会给出相当相似的结果。但当这从根本上减少到每层 30 个单位时,2.5 至 4 范围内的温度明显优于更高或更低的温度。然后,我们尝试从传输集中省略数字 3 的所有示例。所以从蒸馏模型的角度来看,3是一个它从未见过的神话数字。尽管如此,蒸馏模型仅出现 206 个测试错误,其中 133 个位于测试集中的 1010 个三元组上。大多数错误是由于第 3 类的学习偏差太低而引起的。如果此偏差增加 3.5(这会优化测试集的整体性能),则蒸馏模型会出现 109 个错误,其中 14 个错误位于 3 上。因此,在正确的偏差下,尽管在训练期间从未见过 3,但蒸馏模型在测试 3 中的正确率达到 98.6%。如果传输集仅包含训练集中的 7 和 8,则蒸馏模型的测试误差为 47.3%,但当 7 和 8 的偏差减少 7.6 以优化测试性能时,测试误差将降至 13.2%。

discussion

  我们已经证明,蒸馏对于将知识从集成或从大型高度正则化模型转移到较小的蒸馏模型非常有效。在 MNIST 上,即使用于训练蒸馏模型的传输集缺少一个或多个类的任何示例,蒸馏也能表现得非常好。对于 Android 语音搜索所使用的深度声学模型版本,我们已经证明,通过训练深度神经网络集合所实现的几乎所有改进都可以被提炼为相同大小的单个神经网络,部署起来要容易得多。对于非常大的神经网络,甚至训练一个完整的集合也是不可行的,但是我们已经证明,经过很长时间训练的单个非常大的网络的性能可以通过学习大量的专家来显着提高网络,每个网络都学会区分高度混乱的集群中的类别。我们还没有证明我们可以将专家的知识提炼回单一的大网络中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/497401.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

阿里云服务器租用价格表-2024最新(附报价单)

2024年阿里云服务器优惠价格表,一张表整理阿里云服务器最新报价,阿里云服务器网aliyunfuwuqi.com整理云服务器ECS和轻量应用服务器详细CPU内存、公网带宽和系统盘详细配置报价单,大家也可以直接移步到阿里云CLUB中心查看 aliyun.club 当前最新…

数据结构——链表(练习题)

大家好,我是小锋我们继续来学习链表。 我们在上一节已经把链表中的单链表讲解完了,大家感觉怎么样我们今天来带大家做一些练习帮助大家巩固所学内容。 1. 删除链表中等于给定值 val 的所有结点 . - 力扣(LeetCode) 我们大家来分…

登录拦截器

目录 🎈1.登陆拦截器的使用 🎊2.ThreadLocal的简单使用 🎃3.登录拦截器拦截和放行配置 1.登陆拦截器的使用 创建一个拦截器类,必须让其实现HandlerInterceptor接口 1.获取前端的token 2.判断token是否为空 3.若为空&#xff…

蓝桥杯基础练习详细解析(四)——Fibonacci费伯纳西数列(题目分析、代码实现、Python)

试题 基础练习 Fibonacci数列 提交此题 评测记录 资源限制 内存限制:256.0MB C/C时间限制:1.0s Java时间限制:3.0s Python时间限制:5.0s 问题描述 Fibonacci数列的递推公式为:FnFn-1Fn-2,其…

CI/CD实战-jenkins结合ansible

配置主机环境 在jenkins上断开并删除docker1节点 重新给master添加构建任务 将server3,server4作为测试主机,停掉其上后面的docker 在server2(jenkins)主机上安装ansible 设置jenkins用户到目标主机的免密 给测试主机创建用户并…

swagger/knife4j 接口文档增加图标 springboot

1.在资源目录下增加图标文件 2.配置/favicon.ico 资源 Configuration public class WebConfig implements WebMvcConfigurer {Overridepublic void addResourceHandlers(ResourceHandlerRegistry registry) {registry.addResourceHandler("/favicon.ico").addResour…

小程序利用WebService跟asp.net交互过程发现的问题并处理

最近在研究一个项目,用到asp.net跟小程序交互,简单的说就是小程序端利用wx.request发起请求。获取asp.net 响应回来的数据。但经常会报错。点击下图的测试按钮 出现如下错误: 百思不得其解,试了若干方法,都不行。 因为…

axios发送get请求但参数中有数组导致请求路径多出了“[]“的处理办法

一、情况 使用axios发送get请求携带了数组参数时,请求路径中就会多出[]字符,而在后端也会报错 二、解决办法 1、安装qs 当前项目的命令行中安装 npm install qs2、引入qs库(使用qs库来将参数对象转换为字符串) // 全局 import qs from qs Vue.proto…

Vite 为什么比 Webpack 快?

目录 1. Webpack 的构建原理 2. Script 的模块化(主流浏览器对 ES Modules 的支持) 3. Webpack vs Vite 开发模式的差异 对 ES Modules 的支持 底层语言的差异 热更新的处理 1. Webpack 的构建原理 前端之所以需要类似于 Webpack 这样的构建工具&…

智慧工地安全生产与风险预警大平台的构建,需要哪些技术?

随着科技的不断发展,智慧工地已成为现代建筑行业的重要发展趋势。智慧工地方案是一种基于先进信息技术的工程管理模式,旨在提高施工效率、降低施工成本、保障施工安全、提升施工质量。一般来说,智慧工地方案的构建,需要通过集成物…

kubernetes(K8S)学习(一):K8S集群搭建(1 master 2 worker)

K8S集群搭建(1 master 2 worker) 一、环境资源准备1.1、版本统一1.2、k8s环境系统要求1.3、准备三台Centos7虚拟机 二、集群搭建2.1、更新yum,并安装依赖包2.2、安装Docker2.3、设置hostname,修改hosts文件2.4、设置k8s的系统要求…

【IP 组播】PIM-SM

目录 原理概述 实验目的 实验内容 实验拓扑 1.基本配置 2.配置IGP 3.配置PIM-SM 4.用户端DR与组播源端DR 5.从RPT切换到SPT 6.配置PIM-Silent接口 原理概述 PIM-SM 是一种基于Group-Shared Tree 的组播路由协议,与 PIM-DM 不同,它适合于组播组成…

SpringMVC第一个helloword项目

文章目录 前言一、SpringMVC是什么?二、使用步骤1.引入库2.创建控制层3.创建springmvc.xml4.配置web.xml文件5.编写视图页面 总结 前言 提示:这里可以添加本文要记录的大概内容: SpringMVC 提示:以下是本篇文章正文内容&#xf…

如何创建纯净版Django项目并启动?——让Django更加简洁

目录 1. Django的基本目录结构 2. 创建APP 2.1 创建app 2.2 配置文件介绍 3. 迁移数据库文件 3.2 连接数据库 3.1 创建迁移文件 3.2 同步数据库 4. 纯净版Django创建 4.1 剔除APP 4.2 剔除中间件 4.3 剔除模板引擎 5. 最终 1. Django的基本目录结构 在我们创建Django项…

吴恩达机器学习笔记 三十 什么是聚类 K-means

聚类(clustering)是一种无监督学习算法,关注多个数据点并自动找到相似的数据点,在数据中找到一种特定的结构。无监督学习算法的数据集中没有标签 y ,所以不能说哪个是“正确的 y ”。 K-means算法 K-means算法就是在重复做两件事&#xff1a…

k8s 如何获取加入节点命名

当k8s集群初始化成功的时候&#xff0c;就会出现 加入节点 的命令如下&#xff1a; 但是如果忘记了就需要找回这条命令了。 kubeadm join 的命令格式如下&#xff1a;kubeadm join --token <token> --discovery-token-ca-cert-hash sha256:<hash>--token 令牌--…

【NLP笔记】大模型prompt推理(提问)技巧

文章目录 prompt概述推理&#xff08;提问&#xff09;技巧基础prompt构造技巧进阶优化技巧prompt自动优化 参考链接&#xff1a; Pre-train, Prompt, and Predict: A Systematic Survey of Prompting Methods in Natural Language Processing预训练、提示和预测&#xff1a;NL…

Qt打印系统库的日志 - QLoggingCategory

Qt的动态库通过源码可以可以看到含有大量的qCInfo 和 qCDebug 等大量的日志&#xff0c; 但是我们正常运行Qt程序&#xff0c;这些动态库或插件里面的日志是不会输出到我们的控制台里面的。 所以本章主要记录怎么输出这些日志出来。 一&#xff1a; 步骤 主要使用的是Qt的 函…

磐启/PAN7030/2.4GHz 无线收发SOC芯片/ESSOP10/SOP16

1 概述 PAN7030 是一款集成 8 位 OTP MCU 和 2.4GHz 无线收发电路芯片&#xff0c;适合应用于玩具小车、 遥控器等领域。 PAN7030 内置 8 位 OTP MCU&#xff0c;包括 1.25KW 的程序存储器、80 字节数据存储器、16 位定 时器和 8 位/11 位 PWM 定时器、看门狗、电压比较器等…

OpenCV 如何使用 XML 和 YAML 文件的文件输入和输出

返回&#xff1a;OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;如何利用OpenCV4.9离散傅里叶变换 下一篇: 目标 本文内容主要介绍&#xff1a; 如何使用 YAML 或 XML 文件打印和读取文件和 OpenCV 的文本条目&#xff1f;如何对 OpenCV …
最新文章