学习率范围测试(LR Finder)脚本

简介

深度学习中的学习率是模型训练中至关重要的超参数之一。合适的学习率可以加速模型的收敛,提高训练效率,而不恰当的学习率可能导致训练过慢或者无法收敛。为了找到合适的学习率,LR Finder成为了一种强大的工具。

学习率范围测试(LR Finder)是一种通过逐渐增加学习率来观察模型在不同学习率下的性能变化的方法。这个过程可以帮助我们找到一个合适的初始学习率,有助于训练过程的稳定和加速。在本文中,我们将深入探讨 LR Finder 的原理、实现和应用,以及如何在实际的分类项目中充分利用这个强大的工具。

学习率在深度学习中的关键

深度学习当中,学习率是一个至关重要的超参数,直接影响了模型的训练性能和收敛速度。因为它决定了每一次参数更新的步长,过大的学习率可能会导致模型在训练过程中发散,而过小的学习率则可能导致模型训练缓慢甚至停滞。

学习率调整方法

下面是一些常见的学习率调整方法,它们在不同的场景和问题上表现出色。

1.常数学习率

常数学习率是最简单的学习率调整方法,即在整个训练过程中,学习率始终保持不变,但是它只是比较适用于数据集较小、模型已经比较稳定的情况。

用法很简单,就是在训练开始时选择一个合适的学习率,比如0.001或0.01,后续的训练过程中就不再变化了。这个值的设定就比较依赖于的经验了。

2.学习率衰减

学习率衰减是随着训练的进行逐渐减小学习率。它可以分为定期调整和根据模型性能调整两种方式,它适用于需要在训练初期更加收敛,而在后期更加稳定的情况。

  • 定期调整:每隔一段训练轮次或步骤,学习率按照一定的衰减率进行调整。
  • 性能调整:根据模型在验证集上的性能进行调整,性能停滞时降低学习率,性能提高时保持学习率不变或轻微提高。

3.自适应学习率

自适应学习率方法根据模型参数的梯度信息和历史信息来自动调整学习率。主要应用于不同参数具有不同特性或数据分布不均匀的情况。

  • Adam: 结合了动量法和自适应学习率的方法,对不同参数应用不同的学习率。
  • Adagrad: 根据参数的历史梯度信息自适应地调整学习率。
  • RMSProp: 在 Adagrad 的基础上加入了一个衰减系数,以防止学习率过早地降低。

学习率范围测试(LR Finder)

学习率范围测试是一种通过逐渐增加学习率来观察模型性能的方法,从而找到一个合适的初始学习率,适合在训练初期选择合适的学习率范围。

实现的方法:

逐渐增加学习率,从一个较小的学习率开始,逐渐增加学习率直到性能开始下降,在学习率范围内观察模型的性能曲线,找到性能峰值对应的学习率。

class FindLR(_LRScheduler):
    """
    exponentially increasing learning rate

    Args:
        optimizer: optimzier(e.g. SGD)
        num_iter: totoal_iters
        max_lr: maximum  learning rate
    """
    def __init__(self, optimizer, max_lr=10, num_iter=100, last_epoch=-1):

        self.total_iters = num_iter
        self.max_lr = max_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * (self.max_lr / base_lr) ** (self.last_epoch / (self.total_iters + 1e-32)) for base_lr in self.base_lrs]

在 get_lr 方法中,通过指数增长的方式计算当前迭代次数下每个参数的学习率。这个方法返回一个学习率列表,其中每个元素对应网络中的一个参数。这个学习率列表将被用于更新优化器中的学习率。

import torch
import torch.optim as optim
from pyzjr.core.lr_scheduler import _LRScheduler
import matplotlib.pyplot as plt

class FindLR(_LRScheduler):
    def __init__(self, optimizer, max_lr=10, num_iter=100, last_epoch=-1):
        self.total_iters = num_iter
        self.max_lr = max_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * (self.max_lr / base_lr) ** (self.last_epoch / (self.total_iters + 1e-32)) for base_lr in self.base_lrs]

x = torch.arange(0, 100, 1)
y = torch.sin(0.1 * x) + 0.1 * torch.randn(100)

model = torch.nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

lr_finder = FindLR(optimizer, max_lr=10, num_iter=100)
lr_values = []

for epoch in range(100):
    outputs = model(x.unsqueeze(1).float())
    loss = torch.nn.functional.mse_loss(outputs.squeeze(), y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    lr_finder.step()

    current_lr = optimizer.param_groups[0]['lr']
    lr_values.append(current_lr)

    print(f"Iteration: {epoch}, Loss: {loss.item():.4f}, LR: {current_lr:.8f}")

plt.switch_backend('TkAgg')
plt.plot(lr_values)
plt.xlabel('Iteration')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Range Test')
plt.show()


迭代与学习率的曲线图:

下面我们将采用pyzjr中定义的 lr_finder 对CIFAR-100进行学习率范围测试:

import torch.nn as nn

import matplotlib
matplotlib.use('Agg')
from utils import get_network
from dataset import get_train_loader
from pyzjr.dlearn.learnrate import lr_finder

CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

if __name__ == '__main__':
    class parser_args():
        def __init__(self):
            self.net = "vgg16"
            self.batch_size = 64
            self.base_lr = 1e-7
            self.max_lr = 10
            self.num_iter = 100
            self.Cuda = True

    args = parser_args()

    mean = CIFAR100_TRAIN_MEAN
    std = CIFAR100_TRAIN_STD
    train_loader = get_train_loader(mean, std, batch_size=4)

    net = get_network(args)  # 网络模型定义

    loss_function = nn.CrossEntropyLoss()
    lrfinder = lr_finder(net, train_loader, loss_function)
    lrfinder.update()   # 不断迭代
    lrfinder.plotshow()   # 绘制图像并显示

    lrfinder.save(path="result.jpg")   # 保存图像到指定路径

输出图像result.jpg:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/137801.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

springboot项目基本配置

接口入口日志 参数校验 业务逻辑执行 异常捕获-统一异常处理 统一数据返回体 接口返回日志 使用的是springboot2.x版本。 Mybatisplus 官网地址&#xff1a;https://baomidou.com/ 导入依赖 <dependency><groupId>com.baomidou</groupId><artifactId&g…

前端工具nvm实现node自由

node的自由之路 前言 大家使用vue框架开发的朋友可能会遇到首次运行公司项目环境的时候&#xff0c;会出现使用npm install命令安装依赖包的时候出现各种各样的问题&#xff0c;其中很重要的一个错误原因就是因为你的nodejs版本和当时搭建环境的版本不一致造成的。今天就来给…

c语言-数据结构-栈和队列的实现和解析

目录 一、栈 1、栈的概念 1.2 栈的结构 2、栈的创建及初始化 3、压栈操作 4、出栈操作 5、显示栈顶元素 6、显示栈空间内元素的总个数 7、释放栈空间 8、测试栈 二、队列 1、队列的概念 1.2 队列的结构 2、队列的创建及初始化 3、入队 4、出队 5、显示队头、队…

creo之混合和扫描混合

案例一&#xff1a;杯子 步骤&#xff1a; 在top平面画一个草图圆角矩形&#xff1a; 然后形状–》混合 然后绘制新增的截面2&#xff1a; 用中心线将圆分割成八分&#xff0c;因为底部的圆角矩形是八份线段组成&#xff0c;所以我们要和他一样分成八份&#xff1a;先画中心线…

详解Java:抽象类和接口

前言&#xff1a;在前文中我们学习认知到了多态的使用和相关知识&#xff0c;算是打开了Java世界的大门&#xff0c;而本次要分享的抽象类和接口则是我们在面向对象编程中最常用的编程结构之一 目录 一.抽象类 abstract 抽象类特性 二.接口 语法规则 接口使用 接口特…

第十七章jQuery中的事件与动画

一。常用事件&#xff1a; 1.鼠标事件&#xff1a; mouseover()&#xff1a;在鼠标进入内容后一直显示事件 mouseout()&#xff1a;在鼠标离开内容后一直显示事件 mouseenter()&#xff1a;在进入刹那间显示事件 mouseleave()&#xff1a;在退出刹那间显示事件 案例&#xf…

windows系统升级powershell7

利用winget在线升级脚本 REM --accept-package-agreements 接受包的所有许可协议 REM --accept-source-agreements 在源操作期间接受所有源协议 winget install Microsoft.PowerShell --accept-source-agreements 应该在路径C:\Program Files\PowerShel…

功能案例 -- 拖拽上传文件,生成缩略图

直接看效果 实现代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>拖拽上传文件</title>&l…

两个序列(数论)

两个序列 Problem:B Time Limit:1000ms Memory Limit:65535K Description Gugu 有两个长度无限长的序列A,BA0a^0/0!,A1a^1/1!,A2a^2/2!,A3a^3/3!…. B00, B1b^1/1!,B20,B3b^3/3!,B40, B5b^5/5! … Douge 看到这道这两个序列很觉得很麻烦&#xff0c;所以他想到一个好点子&…

EMNLP 2023 | 基于知识图谱嵌入的关系感知集成学习算法

©PaperWeekly 原创 作者 | 黄鉦皓 单位 | 清华大学 研究方向 | 图神经网络 本文介绍《基于知识图谱嵌入的关系感知集成学习算法》&#xff08;Relation-aware Ensemble Learning for Knowledge Graph Embedding&#xff09;&#xff0c;该论文提出的 RelEns-DSC 方法针对…

【Springboot】基于注解式开发Springboot-Vue3整合Mybatis-plus实现分页查询(一)——后端实现思路

系列文章目录 基于注解式开发Springboot-Vue3整合Mybatis-plus实现分页查询(二&#xff09;——前端el-pagination实现 文章目录 系列文章目录系统版本实现功能操作步骤1. 新建Mybatis的全局分页配置文件2. 编写OrderMapper :继承Mybatis-plus提供的BaseMapper3. 编写OrderSer…

Python---字典---dict

1、为什么需要字典 如果想要存储一个人的信息&#xff0c;姓名&#xff1a;Tom&#xff0c;年龄&#xff1a;20周岁&#xff0c;性别&#xff1a;男&#xff0c;如何快速存储。 person [Tom, 20, 男] 在日常生活中&#xff0c;姓名、年龄以及性别同属于一个人的基本特征。 但…

【哈夫曼树的构造】

文章目录 如何构造哈夫曼树哈夫曼树构造算法的实现 如何构造哈夫曼树 哈夫曼算法口诀&#xff1a; 1.构造森林全是根&#xff1b;2.选用两小造新树&#xff1b; 3.删除两小添新人&#xff1b;4.重复2,3剩单根&#xff1b; 例&#xff1a;有4个新结点a,b,c,d&#xff0c;权值为…

upload-labs关卡7(基于黑名单的空格绕过)通关思路

文章目录 前言一、回顾上一关知识点二、靶场第七关通关思路1、看源代码2、空格绕过3、检查文件是否成功上传 总结 前言 此文章只用于学习和反思巩固文件上传漏洞知识&#xff0c;禁止用于做非法攻击。注意靶场是可以练习的平台&#xff0c;不能随意去尚未授权的网站做渗透测试…

【Python】二维码和条形码的识别

我主要的问题就在于无法识别图片 注意事项&#xff1a; 1、从文件中加载图像的时候注意图片尽量用英文来命名&#xff0c;因为中文无法识别到图片 2、使用绝对地址的时候要用两个双斜杠&#xff0c;因为用一个会被识别为Unicode 转义&#xff0c;但是并没有后续的合法 Unico…

matlab simulink PSO算法优化simulink的PID参数

1、内容简介 略 13-可以交流、咨询、答疑 PSO算法优化simulink的PID参数 2、内容说明 标准的PSO算法优化simulink的PID参数 PSO、粒子群算法、simulink参数优化 3、仿真分析 4、参考论文 略 链接&#xff1a;https://pan.baidu.com/s/1yQ1yDfk-_Qnq7tGpa23L7g 提取码&…

docker安装RocketMQ

1、RocketMQ基本概念 1.1 消息模型&#xff08;Message Model&#xff09; RocketMQ主要由Producer、Broker、Consumer三部分组成&#xff0c;其中Producer负责生产消息&#xff0c;Consumer负责消费消息&#xff0c;Broker负责存储消息。Broker在实际部署过程中对应一台服务…

node实战——koa实现文件下载和图片/pdf/视频预览(node后端储备知识)

文章目录 ⭐前言⭐koa-send库实现下载⭐mime-types库实现图片预览&#x1f496; 渲染图片&#x1f496;渲染404&#x1f496;预览pdf&#x1f496;预览视频 ⭐总结⭐结束 ⭐前言 大家好&#xff0c;我是yma16&#xff0c;本文分享关于node实战——koa实现文件下载和图片预览。…

单链表(6)

删除第一个val的值&#xff08;考试重点&#xff09; 思路&#xff1a;例如删除val值为3的数据&#xff0c;直接让数据2的p->next指向数据4就可以了。 所以删除必须依赖前驱。也就是要写删除函数&#xff0c;则先要完成返回key的前驱地址的函数 也就是先知道前驱地址&#…

kubenetes-容器运行时接口CRI

一、CRI 容器运行时&#xff08;Container Runtime&#xff09;&#xff0c;运行于Kubernetes&#xff08;K8s&#xff09; 集群的每个节点中&#xff0c;负责容器的整个生命周期。其中Docker是目前应用最广的。随着容器云的发展&#xff0c;越来越多的容器运行时涌现。 为了解…
最新文章