动态调整学习率Lr

动态调整学习率Lr

  • 0 引入
  • 1 代码例程
    • 1.1 工作方式解释
  • 2 动态调整学习率的几种方法
    • 2.1 lr_scheduler.LambdaLR
    • 2.2 lr_scheduler.StepLR
    • 2.3 lr_scheduler.MultiStepLR
    • 2.4 lr_scheduler.ExponentialLR
    • 2.2.5 lr_scheduler.CosineAnnealingLR
    • 2.6 lr_scheduler.ReduceLROnPlateau
    • 2.7 lr_scheduler.CyclicLR
  • 参考

0 引入

在训练深度学习模型时,不可避免的要调整超参,而学习率首当其冲是大家最先想要调整的一个超参。而且学习率对于模型训练效果来说也相当重要。
然鹅,学习率过低会导致学习速度太慢,学习率过高又容易导致难以收敛。
因此,很多炼丹师都会采用动态调整学习率的方法。刚开始训练时,学习率大一点,以加快学习速度;之后逐渐减小来寻找最优解。
那么在Pytorch中,如何**在训练过程里动态调整学习率呢?**本文将带你深入理解优化器和学习率调整策略。

1 代码例程

1.1 工作方式解释

自定义学习率调度器:torch.optim.lr_scheduler.LambdaLR

torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda = lr_lambda)
  • optimizer:优化器对象,表示需要调整学习率的优化器。
  • Ir_lambda:一个函数,接受一个整数参数epoch,并返回一个浮点数值,表示当前epoch下的学习率变化系数。

每个epoch 更新一次lr:
image.png

每个batch 更新一次学习率:
image.png

举例子: 新的学习率 = 原始学习率 * 学习率因子
image.png

学习率调度器:

image.png

2 动态调整学习率的几种方法

一般地,模型训练和人刻意练习一项技能一样,需要反复地学习才能加固认知。模型也是一样,需要反复学习多次给定的数据集,即要训练多个epoch。在Pytorch中,给我们提供了很多种动态调整学习率的策略,这些调整策略多是基于epoch进行的。下面一一进行讲解。

2.1 lr_scheduler.LambdaLR

目的:将每个参数组的学习率设置为初始 lr 乘以给定的Lambda函数根据epoch计算出来的值。

CLASS torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=- 1, verbose=False)

参数详情
optimizer (Optimizer):关联的优化器
lr_lambda (function or list):给定当前epoch计算乘法因子的函数,如果是函数列表,则每一个函数对应一个参数组,这样就可以为不同的参数组设置不同的学习率。
last_epoch (int):最后一个epoch的索引,默认值为-1.
verbose(bool):如果为True,为每次更新打印输出,默认值False

举例子: 线性递减的lr

# 假设优化器有两个参数组.
lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()

2.2 lr_scheduler.StepLR

目的:每step_size个epochs通过γ降低每个参数组的学习率。

CLASS torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=- 1, verbose=False)

参数详情

optimizer (Optimizer):关联的优化器
step_size (int):学习率更新的周期,每step_size个epochs更新一次
gamma (float):降低学习率的乘法因子,默认值0.1
last_epoch(int):最后一个epoch的索引,默认值-1
verbose(bool):如果为True,为每次更新打印输出,默认值False

举例子

# 假设所有参数组的初始学习率为0.05,γ为0.1
# lr = 0.05     if epoch < 30
# lr = 0.005    if 30 <= epoch < 60
# lr = 0.0005   if 60 <= epoch < 90
# ...
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()

2.3 lr_scheduler.MultiStepLR

目的:当epoch到达设定的标志位时,通过γ降低每个参数组的学习率。请注意,这种衰减可能与此调度程序外部对学习率的其他更改同时发生。

CLASS torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=- 1, verbose=False)

参数详情
optimizer (Optimizer):关联的优化器
milestones (list):epoch索引的列表,必须是升序的,因为epoch是越来越大的
gamma (float):降低学习率的乘法因子,默认值0.1
last_epoch(int):最后一个epoch的索引,默认值-1
verbose(bool):如果为True,为每次更新打印输出,默认值False

例子

# 假设所有参数组的初始学习率为0.05,γ为0.1
# lr = 0.05     if epoch < 30
# lr = 0.005    if 30 <= epoch < 80
# lr = 0.0005   if epoch >= 80
scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()

2.4 lr_scheduler.ExponentialLR

目的:对于每个epoch,使用γ降低每个参数组的学习率。以 gama 为底,epoch为指数,
在这里插入图片描述

参数详情
optimizer (Optimizer):关联的优化器
gamma (float):降低学习率的乘法因子,默认值0.1
last_epoch(int):最后一个epoch的索引,默认值-1
verbose(bool):如果为True,为每次更新打印输出,默认值False

2.2.5 lr_scheduler.CosineAnnealingLR

目的:使用余弦退火策略动态调整每个epoch的学习率。

CLASS torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=- 1, verbose=False)

参数详情:
optimizer (Optimizer):关联的优化器
T_max (int):最大的迭代次数
eta_min(float):最小的学习率,默认值0
last_epoch(int):最后一个epoch的索引,默认值-1
verbose(bool):如果为True,为每次更新打印输出,默认值False

计算公式如下:
image.png
其中,\eta_{t}表示更新后的学习率,ηmin 表示最小的学习率,ηmax表示最大的学习率,Tcur表示当前epoch的索引,Tmax表示最大epoch的索引。

2.6 lr_scheduler.ReduceLROnPlateau

目的:当指标停止改进时便开始降低学习率。一旦学习停滞,将学习率降低 2-10 倍之后,模型通常会继续学习。 该调度程序读取一个指标数量,如果一定数量的 epoch 后仍没有改善,则学习率会降低。

CLASS torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False)

参数详情
optimizer (Optimizer):关联的优化器
mode(str):有min和max两种可选。在min模式下,当模型指标停止下降时开始调整学习率;在max模式下,当模型指标停止上升时开始调整学习率
factor (float):降低学习率的乘法因子.lr_new = lr ∗ factor
patience(int):耐心值,不是指标停止改变后立马就调整学习率,而是patience个epoch之后指标仍没有改变,便开始调整学习率,即该策略对指标的变化有一定的容忍度。默认值10
threshold(float):衡量新的最佳阈值,只关注重大变化。默认值:1e-4
cooldown(int):在 lr 减少后恢复正常操作之前要等待的 epoch 数
min_lr (float or list):每个参数组设定的最小学习率
eps (float):如果调整之后的学习率和调整之前的学习率差距小于eps,则忽略此次调整
verbose(bool):如果为True,为每次更新打印输出,默认值False

例子

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min')
for epoch in range(10):
    train(...)
    val_loss = validate(...)
    # Note that step should be called after validate()
    scheduler.step(val_loss)

2.7 lr_scheduler.CyclicLR

目的:根据周期性学习率策略调整每个参数组的学习率。该策略以恒定频率在两个边界之间循环调整学习率,详见论文 Cyclical Learning Rates for Training Neural Networks。 两个边界之间的距离可以在每次迭代或每个周期的基础上进行缩放。循环学习率策略在每个Batch之后改变学习率。 step()方法应该在一个Batch用于训练后调用。

该方法有三种循环策略:
triangular“:没有幅度缩放的基本三角循环
”triangular2“:一个基本的三角循环,每个循环将初始幅度缩放一半
”exp_range“:每个循环之后,将初始幅度按照 \gamma^{cycle-iterations}进行降低

例子

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
data_loader = torch.utils.data.DataLoader(...)
for epoch in range(10):
    for batch in data_loader:
        train_batch(...)
        scheduler.step()

参考

深度学习学习率调整方案如何选择? - Summer Clover的回答 - 知乎
https://www.zhihu.com/question/315772308/answer/1636730368

[]深度学习学习率调整方案如何选择?
https://www.zhihu.com/question/315772308

[pytorch 动态调整学习率,学习率自动下降,根据loss下降]
https://blog.csdn.net/qq_41554005/article/details/119879911

[深度学习之动态调整学习率LR]
https://blog.csdn.net/Just_do_myself/article/details/123751148

[pytorch优化器与学习率设置详解]
https://zhuanlan.zhihu.com/p/435669796

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/139664.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

后门程序2

System\CurrentControlSet\Services\Disk\Enum Windows 操作系统注册表中的一个路径。这个路径通常包含有关磁盘设备的信息。在这个特定的路径下&#xff0c;可能存储了有关磁盘枚举的配置和参数 Enum&#xff08;枚举&#xff09;子键通常包含了系统对磁盘的枚举信息&#xf…

网络安全/黑客技术(0基础入门到进阶提升)

前言 前几天发布了一篇 网络安全&#xff08;黑客&#xff09;自学 没想到收到了许多人的私信想要学习网安黑客技术&#xff01;却不知道从哪里开始学起&#xff01;怎么学 今天给大家分享一下&#xff0c;很多人上来就说想学习黑客&#xff0c;但是连方向都没搞清楚就开始学习…

高通SDX12:ASoC 音频框架浅析

一、简介 ASoC–ALSA System on Chip ,是建立在标准ALSA驱动层上,为了更好地支持嵌入式处理器和移动设备中的音频Codec的一套软件体系。 本文基于高通SDX12平台,对ASoC框架做一个分析。 二、整体框架 1. 硬件层面 嵌入式Linux设备的Audio subsystem可以划分为Machine(板…

关于企业海外Social平台营销布局,你需要了解这三件事

01 企业Social营销布局模式 Social营销走到现在&#xff0c;早已进入了标准配置期。任何企业和组织&#xff0c;进行营销宣传的时候都会在社媒社交平台上创建账号和运营。目前&#xff0c;海外Social平台营销模式基本分为四类&#xff1a; 官方社媒账号运营&#xff1a;以Hoot…

HTML5学习系列之主结构

HTML5学习系列之主结构 前言HTML5主结构定义页眉定义导航定义主要区域定义文章块定义区块定义附栏定义页脚 具体使用总结 前言 学习记录 HTML5主结构 定义页眉 head表示页眉&#xff0c;用来表示标题栏&#xff0c;引导和导航作用的结构元素。 <header role"banner…

使用Python和requests库的简单爬虫程序

这是一个使用Python和requests库的简单爬虫程序。我们将使用代理来爬取网页内容。以下是代码和解释&#xff1a; import requests from fake_useragent import UserAgent # 每行代理信息 proxy_host "jshk.com.cn" # 创建一个代理器 proxy {http: http:// proxy_…

SQLServer添加Oracle链接服务器

又一次在项目中用到了在SQLServer添加Oracle链接服务器&#xff0c;发现之前文章写的也不太好使&#xff0c;那就再总结一次吧。 1、安装OracleClient 安装64位&#xff0c;多数SQLServer是64位&#xff0c;所以OracleClient也安装64位的&#xff1b; 再一个一般安装的Oracl…

Python二级 每周练习题26

如果你感觉有收获&#xff0c;欢迎给我打赏 ———— 以激励我输出更多优质内容 练习一: 从键盘输入任意字符串&#xff0c;按照下面要求分离字符串中的字符&#xff1a; 1、分别取出该字符串的第偶数位的元素&#xff08;提醒注意&#xff1a;是按照从左往右数的 方式确定字…

新增配置字

新增的配置字 “使用的集成功放通道数量”&#xff0c;我的理解 channelnum 和 有没有AVAS有关&#xff0c;油车的话没有AVAS&#xff1b; a2b :有外置功放肯定有a2b; amp1:内置功放&#xff1b; amp2:Avas 用的内置功放&#xff1b;

如何防止听力下降?

听力受损是不可逆的&#xff0c;一旦听力下降了是无法恢复的&#xff0c;所以当我们出现听力障碍的时候&#xff0c;我们更应该注意我们的耳朵&#xff0c;想想如何能保护我们的残余听力&#xff01; 今天来告诉大家&#xff0c;哪些事是有易于听力的&#xff0c;一起来看看吧…

一款免费好用的制作电子杂志网站,发现新大陆~

你是不是也厌倦了传统纸质杂志的限制&#xff0c;想要尝试一种全新的阅读体验&#xff1f;那么&#xff0c;今天我要向你推荐的这款免费好用的制作电子杂志网站&#xff0c;绝对能让你眼前一亮&#xff01; 这款网站就是FLBOOK在线制作电子杂志平台&#xff0c;并且界面简洁、操…

Kubernetes介绍和环境部署

文章目录 Kubernetes一、Kubernetes介绍1.Kubernetes简介2.Kubernetes概念3.Kubernetes功能4.Kubernetes工作原理5.kubernetes组件6.Kubernetes优缺点 二、Kubernetes环境部署环境基本配置1.所有节点安装docker2.所有节点安装kubeadm、kubelet、kubectl添加yum源containerd配置…

【带头学C++】----- 五、字符串操作函数 ---- 5.1 字符串操作函数

5.1字符串操作函数(以str开头的字符串处理函数默认遇到\0结束操作) 5.1.1 测量字符串的长度strlen() strlen() 函数用于计算一个字符串的长度。 #include <string.h> //注意&#xff1a;该头文件必须包含 size_t strlen(const char *s); // s指的是需要测量字符串的首地…

【T690 之十一】基于方寸EVB2开发板,结合 Eclipse+gdb+gdbserver 调试 CCAT 的流程总结

目录 1. 准备工作1.1 Eclipse1.2 工程编译1.3 烧写固件 2. 创建工程2.1 搭建调试工程2.2 配置Dbug调试信息 3. 调试4. 手动调试过程4. 总结 备注&#xff1a; 1&#xff0c;假设您已对方寸微电子的T690系列芯片的使用方式都有了一定的了解&#xff0c;可以根据此文的配置进行Li…

前端中 JS 发起的请求可以暂停吗?

给大家推荐一个实用面试题库 1、前端面试题库 &#xff08;面试必备&#xff09; 推荐&#xff1a;★★★★★ 地址&#xff1a;web前端面试题库 这个问题非常有意思&#xff0c;我一看到就想了很多可以回复的答案&#xff0c;但是评论区太窄&#xff0c;就直接开…

印刷包装服务预约小程序的作用是什么

印刷包装厂家非常多&#xff0c;其主要服务为名片印刷、礼品纸袋定制、画册宣传单印刷等&#xff0c;这些服务对大多数企业都有很高的需求&#xff0c;同时具备批量、长期合作属性&#xff0c;同时具备跨区域合作性&#xff0c;所以品牌可扩展度高。 但高需求的同时&#xff0…

查询数据表格中的数据

1.创建这个表至少20个 1&#xff09;创建数据库&#xff1a;create database 四川信息职业技术; 2&#xff09;创建数据表 3&#xff09;插入数据&#xff08;第一条代码修改了一下手机号码的字段类型&#xff09; 2.统计表中的人数 如果你想根据某个特定的列来统计人数&…

mysql、mysql+python

一、 window端mysql免费版&#xff1a; &#xff08;未特别描述则不做更改直接点下一步&#xff09; 下载地址&#xff1a;https://downloads.mysql.com/archives/installer mysql安装好后添加path&#xff1a; 将MySQL安装目录的bin文件夹的路径复制&#xff0c;点新建添加…

C++:对象成员方法的使用

首先复习一下const : //const: //Complex* const pthis1 &ca; //约束指针自身 不能指向其他对象 // pthis1 &cb; err //pthis1->real; //const Complex* const pthis1 &ca;//指针指向 指针自身 都不能改 //pthis1->real; 只可读 …

算法细节类错误

1.使用全局变量时&#xff0c;若有多组测试数据 应该注意在循坏中重新初始化全局变量 例如&#xff1a;
最新文章