【深度学习笔记】优化算法——Adam算法

Adam算法

🏷sec_adam

本章我们已经学习了许多有效优化的技术。
在本节讨论之前,我们先详细回顾一下这些技术:

  • 在 :numref:sec_sgd中,我们学习了:随机梯度下降在解决优化问题时比梯度下降更有效。
  • 在 :numref:sec_minibatch_sgd中,我们学习了:在一个小批量中使用更大的观测值集,可以通过向量化提供额外效率。这是高效的多机、多GPU和整体并行处理的关键。
  • 在 :numref:sec_momentum中我们添加了一种机制,用于汇总过去梯度的历史以加速收敛。
  • 在 :numref:sec_adagrad中,我们通过对每个坐标缩放来实现高效计算的预处理器。
  • 在 :numref:sec_rmsprop中,我们通过学习率的调整来分离每个坐标的缩放。

Adam算法 :cite:Kingma.Ba.2014将所有这些技术汇总到一个高效的学习算法中。
不出预料,作为深度学习中使用的更强大和有效的优化算法之一,它非常受欢迎。
但是它并非没有问题,尤其是 :cite:Reddi.Kale.Kumar.2019表明,有时Adam算法可能由于方差控制不良而发散。
在完善工作中, :cite:Zaheer.Reddi.Sachan.ea.2018给Adam算法提供了一个称为Yogi的热补丁来解决这些问题。
下面我们了解一下Adam算法。

算法

Adam算法的关键组成部分之一是:它使用指数加权移动平均值来估算梯度的动量和二次矩,即它使用状态变量

v t ← β 1 v t − 1 + ( 1 − β 1 ) g t , s t ← β 2 s t − 1 + ( 1 − β 2 ) g t 2 . \begin{aligned} \mathbf{v}_t & \leftarrow \beta_1 \mathbf{v}_{t-1} + (1 - \beta_1) \mathbf{g}_t, \\ \mathbf{s}_t & \leftarrow \beta_2 \mathbf{s}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2. \end{aligned} vtstβ1vt1+(1β1)gt,β2st1+(1β2)gt2.

这里 β 1 \beta_1 β1 β 2 \beta_2 β2是非负加权参数。
常将它们设置为 β 1 = 0.9 \beta_1 = 0.9 β1=0.9 β 2 = 0.999 \beta_2 = 0.999 β2=0.999
也就是说,方差估计的移动远远慢于动量估计的移动。
注意,如果我们初始化 v 0 = s 0 = 0 \mathbf{v}_0 = \mathbf{s}_0 = 0 v0=s0=0,就会获得一个相当大的初始偏差。
我们可以通过使用 ∑ i = 0 t β i = 1 − β t 1 − β \sum_{i=0}^t \beta^i = \frac{1 - \beta^t}{1 - \beta} i=0tβi=1β1βt来解决这个问题。
相应地,标准化状态变量由下式获得

v ^ t = v t 1 − β 1 t  and  s ^ t = s t 1 − β 2 t . \hat{\mathbf{v}}_t = \frac{\mathbf{v}_t}{1 - \beta_1^t} \text{ and } \hat{\mathbf{s}}_t = \frac{\mathbf{s}_t}{1 - \beta_2^t}. v^t=1β1tvt and s^t=1β2tst.

有了正确的估计,我们现在可以写出更新方程。
首先,我们以非常类似于RMSProp算法的方式重新缩放梯度以获得

g t ′ = η v ^ t s ^ t + ϵ . \mathbf{g}_t' = \frac{\eta \hat{\mathbf{v}}_t}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon}. gt=s^t +ϵηv^t.

与RMSProp不同,我们的更新使用动量 v ^ t \hat{\mathbf{v}}_t v^t而不是梯度本身。
此外,由于使用 1 s ^ t + ϵ \frac{1}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon} s^t +ϵ1而不是 1 s ^ t + ϵ \frac{1}{\sqrt{\hat{\mathbf{s}}_t + \epsilon}} s^t+ϵ 1进行缩放,两者会略有差异。
前者在实践中效果略好一些,因此与RMSProp算法有所区分。
通常,我们选择 ϵ = 1 0 − 6 \epsilon = 10^{-6} ϵ=106,这是为了在数值稳定性和逼真度之间取得良好的平衡。

最后,我们简单更新:

x t ← x t − 1 − g t ′ . \mathbf{x}_t \leftarrow \mathbf{x}_{t-1} - \mathbf{g}_t'. xtxt1gt.

回顾Adam算法,它的设计灵感很清楚:
首先,动量和规模在状态变量中清晰可见,
它们相当独特的定义使我们移除偏项(这可以通过稍微不同的初始化和更新条件来修正)。
其次,RMSProp算法中两项的组合都非常简单。
最后,明确的学习率 η \eta η使我们能够控制步长来解决收敛问题。

实现

从头开始实现Adam算法并不难。
为方便起见,我们将时间步 t t t存储在hyperparams字典中。
除此之外,一切都很简单。

%matplotlib inline
import torch
from d2l import torch as d2l


def init_adam_states(feature_dim):
    v_w, v_b = torch.zeros((feature_dim, 1)), torch.zeros(1)
    s_w, s_b = torch.zeros((feature_dim, 1)), torch.zeros(1)
    return ((v_w, s_w), (v_b, s_b))

def adam(params, states, hyperparams):
    beta1, beta2, eps = 0.9, 0.999, 1e-6
    for p, (v, s) in zip(params, states):
        with torch.no_grad():
            v[:] = beta1 * v + (1 - beta1) * p.grad
            s[:] = beta2 * s + (1 - beta2) * torch.square(p.grad)
            v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
            s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
            p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr)
                                                       + eps)
        p.grad.data.zero_()
    hyperparams['t'] += 1

现在,我们用以上Adam算法来训练模型,这里我们使用 η = 0.01 \eta = 0.01 η=0.01的学习率。

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(adam, init_adam_states(feature_dim),
               {'lr': 0.01, 't': 1}, data_iter, feature_dim);
loss: 0.244, 0.015 sec/epoch

在这里插入图片描述

此外,我们可以用深度学习框架自带算法应用Adam算法,这里我们只需要传递配置参数。

trainer = torch.optim.Adam
d2l.train_concise_ch11(trainer, {'lr': 0.01}, data_iter)
loss: 0.254, 0.015 sec/epoch

在这里插入图片描述

Yogi

Adam算法也存在一些问题:
即使在凸环境下,当 s t \mathbf{s}_t st的二次矩估计值爆炸时,它可能无法收敛。
:cite:Zaheer.Reddi.Sachan.ea.2018 s t \mathbf{s}_t st提出了的改进更新和参数初始化。
论文中建议我们重写Adam算法更新如下:

s t ← s t − 1 + ( 1 − β 2 ) ( g t 2 − s t − 1 ) . \mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2) \left(\mathbf{g}_t^2 - \mathbf{s}_{t-1}\right). stst1+(1β2)(gt2st1).

每当 g t 2 \mathbf{g}_t^2 gt2具有值很大的变量或更新很稀疏时, s t \mathbf{s}_t st可能会太快地“忘记”过去的值。
一个有效的解决方法是将 g t 2 − s t − 1 \mathbf{g}_t^2 - \mathbf{s}_{t-1} gt2st1替换为 g t 2 ⊙ s g n ( g t 2 − s t − 1 ) \mathbf{g}_t^2 \odot \mathop{\mathrm{sgn}}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}) gt2sgn(gt2st1)
这就是Yogi更新,现在更新的规模不再取决于偏差的量。

s t ← s t − 1 + ( 1 − β 2 ) g t 2 ⊙ s g n ( g t 2 − s t − 1 ) . \mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2 \odot \mathop{\mathrm{sgn}}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}). stst1+(1β2)gt2sgn(gt2st1).

论文中,作者还进一步建议用更大的初始批量来初始化动量,而不仅仅是初始的逐点估计。

def yogi(params, states, hyperparams):
    beta1, beta2, eps = 0.9, 0.999, 1e-3
    for p, (v, s) in zip(params, states):
        with torch.no_grad():
            v[:] = beta1 * v + (1 - beta1) * p.grad
            s[:] = s + (1 - beta2) * torch.sign(
                torch.square(p.grad) - s) * torch.square(p.grad)
            v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
            s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
            p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr)
                                                       + eps)
        p.grad.data.zero_()
    hyperparams['t'] += 1

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(yogi, init_adam_states(feature_dim),
               {'lr': 0.01, 't': 1}, data_iter, feature_dim);
loss: 0.245, 0.015 sec/epoch

在这里插入图片描述

小结

  • Adam算法将许多优化算法的功能结合到了相当强大的更新规则中。
  • Adam算法在RMSProp算法基础上创建的,还在小批量的随机梯度上使用EWMA。
  • 在估计动量和二次矩时,Adam算法使用偏差校正来调整缓慢的启动速度。
  • 对于具有显著差异的梯度,我们可能会遇到收敛性问题。我们可以通过使用更大的小批量或者切换到改进的估计值 s t \mathbf{s}_t st来修正它们。Yogi提供了这样的替代方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/447451.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

从element-plus 引入ILoadingInstance 出现类型错误

具体报错如下图所示: 1、引入ILoadingInstance 出现错误: 解决问题如下所示: 可能是因为element-plus 多次升级原因,将原来的内部代码多次改下了,原来是loading.type文件,现在变成loading.js,包…

卷积神经网络必备基础

卷积神经网络(Convolutional Neural Network, CNN) 传统的全连接神经网络并不适用于图像处理,这是因为:每个像素点都是一个输入特征,随着层数的增加,参数以指数级增长,而图片的像素点往往非常…

【STM32】HAL库 CubeMX 教程 --- 高级定时器 TIM1 定时

实验目标: 通过CUbeMXHAL,配置TIM1,1s中断一次,闪烁LED。 一、常用型号的TIM时钟频率 1. STM32F103系列: 所有 TIM 的时钟频率都是72MHz;F103C8不带基本定时器,F103RC及以上才带基本定时器。…

2024.3.10 win11系统设置环境变量的方法

2024.3.10 win11系统设置环境变量的方法 win11和其他版本略有区别,以安装maven为例进行操作。 一、鼠标右键点击下拉菜单中选择“个性化” 二、点击个性化中选项后在左侧菜单选择“系统” 三、在右侧系统项目中选择“系统信息” 四、在系统信息中选择“高级系统…

Android将自己写的maven库上传至jitpack(2024靠谱版)

浏览了一堆陈年旧贴,终于实验成功了 第一步 将自建空项目同步至github并保证能正常运行第二步新增一个library类型的modul第三步 在新建的library里面写一些测试用的代码第四步在library的gradle文件增加插件和发布脚本第五步新建一个配置文件第六步 把所有更改push…

面试题:分布式锁用了 Redis 的什么数据结构

在使用 Redis 实现分布式锁时,通常使用 Redis 的字符串(String)。Redis 的字符串是最基本的数据类型,一个键对应一个值,它能够存储任何形式的字符串,包括二进制数据。字符串类型的值最多可以是 512MB。 Re…

基于java+springboot+vue实现的火车票订票系统(文末源码+Lw)294

摘要 火车票订票系统可以对火车票订票系统信息进行集中管理,可以真正避免传统管理的缺陷。火车票订票系统是一款运用软件开发技术设计实现的应用系统,在信息处理上可以达到快速的目的,不管是针对数据添加,数据维护和统计&#xf…

SaulLM-7B: A pioneering Large Language Model for Law

SaulLM-7B: A pioneering Large Language Model for Law 相关链接:arxiv 关键字:Large Language Model、Legal Domain、SaulLM-7B、Instructional Fine-tuning、Legal Corpora 摘要 本文中,我们介绍了SaulLM-7B,这是为法律领域量…

Keepalived+LVS构建高可用集群

目录 一、Keepalive基础介绍 1. Keepalive与VRRP 2. VRRP相关技术 3. 工作原理 4. 模块 5. 架构 6. 安装 7. Keepalived 相关文件 7.1 配置组成 7.2 全局配置 7.3 VRRP实例配置(lvs调度器) 7.4 虚拟服务器与真实服务器配置 二、Keepalived…

用*把棱形画出来

输入一个整数n表示棱形的对角半长度&#xff0c;请你用*把这个棱形画出来。 输入&#xff1a;1输出&#xff1a;*输入&#xff1a;3输出&#xff1a;**** *********输入输出格式 输入描述: 输入一个整数n&#xff08;n < 10&#xff09;。 输出描述: 按题目要求输出字符棱…

Jobs Portal求职招聘系统源码v3.5版本

源码介绍: Jobs Portal 求职招聘系统 是为求职者和公司发布职位而开发的交互式求职招聘源码。它使求职者能够发布简历、搜索工作、查看个人工作列表。它将提供各种公司在网站上放置他们的职位空缺资料&#xff0c;并且还可以选择搜索候选人简历。除此之外&#xff0c;还有一个…

Jupyter Notebook使用教程——从Anaconda环境构建到Markdown、LaTex语法介绍

0. 前言 按照国际惯例&#xff0c;首先声明&#xff1a;本文只是我自己学习的理解&#xff0c;虽然参考了他人的宝贵见解及成果&#xff0c;但是内容可能存在不准确的地方。如果发现文中错误&#xff0c;希望批评指正&#xff0c;共同进步。 你是否在视频教程或说明文档或Githu…

TC7.0简单编程十六进制跟十进制转化函数

写脚本的时候&#xff0c;没用内存的功能什么的&#xff0c;基本跟十六进制用得都比较少。最近因为易语言的一个代码要转化过来&#xff0c;看到易语言里面有现成的函数 16到10 跟10 到16&#xff0c;就想着TC是否也有这样的函数。找来找去没找到。其实TC也有这样的函数来的。藏…

misc49

下载附件是个txt文件&#xff0c;打开发现是个压缩包的头 后缀改成zip后打开 base解码无果&#xff0c;我们尝试字母解码 然后音符解码得到 ❀✿✼❇❃❆❇✿❁❇✻✿❀✾✿✻❀❊❆❃❀❊✻❅❀❄✼❂❊❊✾❇❁✽✽✼❁❂❀❀❀❉❃❂❀❉❃❂❊❊✾✼✻✻❀❆✻✻❀❀✻✻✿…

C#MQTT编程10--MQTT项目应用--工业数据上云

1、文章回顾 这个系列文章已经完成了9个内容&#xff0c;由浅入深地分析了MQTT协议的报文结构&#xff0c;并且通过一个有效的案例让伙伴们完全理解理论并应用到实际项目中&#xff0c;这节继续上马一个项目应用&#xff0c;作为本系列的结束&#xff0c;奉献给伙伴们&#x…

Visual Studio单步调试中监视窗口变灰的问题

在vs调试中&#xff0c;写了这样一条语句 while((nfread(buf, sizeof(float), N, pf))>0) 然而&#xff0c;在调试中&#xff0c;只要一执行while这条语句&#xff0c;监视窗口中的变量全部变为灰色&#xff0c;不能查看&#xff0c;是程序本身并没有报错&#xff0c;能够继…

简单了解一个数据包在网络的一生

在主题之前&#xff0c;我想先谈谈目前计算机的网络模型&#xff0c;主要谈谈 TCP/IP 模型&#xff1a; 应用层&#xff1a;产生最原始的数据&#xff0c;常见协议如 http、ftp、websocket、DNS、QUIC 传输层&#xff1a;传递应用层的数据给网络层&#xff0c;必要时进行切割&…

[Angular 基础] - 表单:响应式表单

[Angular 基础] - 表单&#xff1a;响应式表单 之前的笔记&#xff1a; [Angular 基础] - routing 路由(下) [Angular 基础] - Observable [Angular 基础] - 表单&#xff1a;模板驱动表单 开始 其实这里的表单和之前 Template-Driven Forms 没差很多&#xff0c;不过 Tem…

从16-bit 到 1.58-bit :大模型内存效率和准确性之间的最佳权衡

通过量化可以减少大型语言模型的大小&#xff0c;但是量化是不准确的&#xff0c;因为它在过程中丢失了信息。通常较大的llm可以在精度损失很小的情况下量化到较低的精度&#xff0c;而较小的llm则很难精确量化。 什么时候使用一个小的LLM比量化一个大的LLM更好? 在本文中&a…

C 嵌套循环

C 语言允许在一个循环内使用另一个循环&#xff0c;下面演示几个实例来说明这个概念。 语法 C 语言中 嵌套 for 循环 语句的语法&#xff1a; for (initialization; condition; increment/decrement) {statement(s);for (initialization; condition; increment/decrement){s…
最新文章