人工智能-优化算法之学习率调度器

学习率调度器

到目前为止,我们主要关注如何更新权重向量的优化算法,而不是它们的更新速率。 然而,调整学习率通常与实际算法同样重要,有如下几方面需要考虑:

  • 首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果。我们之前看到问题的条件数很重要。直观地说,这是最不敏感与最敏感方向的变化量的比率。

  • 其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。

  • 另一个同样重要的方面是初始化。这既涉及参数最初的设置方式,又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的。

  • 最后,还有许多优化变体可以执行周期性学习率调整。这超出了本章的范围,我们建议读者阅读 (Izmailov et al, 2018)来了解个中细节。例如,如何通过对整个路径参数求平均值来获得更好的解。

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。 在本章中,我们将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。 由于大多数代码都是标准的,我们只介绍基础知识,而不做进一步的详细讨论

%matplotlib inline
import math
import torch
from torch import nn
from torch.optim import lr_scheduler
from d2l import torch as d2l


def net_fn():
    model = nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10))

    return model

loss = nn.CrossEntropyLoss()
device = d2l.try_gpu()

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

# 代码几乎与d2l.train_ch6定义在卷积神经网络一章LeNet一节中的相同
def train(net, train_iter, test_iter, num_epochs, loss, trainer, device,
          scheduler=None):
    net.to(device)
    animator = d2l.Animator(xlabel='epoch', xlim=[0, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])

    for epoch in range(num_epochs):
        metric = d2l.Accumulator(3)  # train_loss,train_acc,num_examples
        for i, (X, y) in enumerate(train_iter):
            net.train()
            trainer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            trainer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            train_loss = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % 50 == 0:
                animator.add(epoch + i / len(train_iter),
                             (train_loss, train_acc, None))

        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch+1, (None, None, test_acc))

        if scheduler:
            if scheduler.__module__ == lr_scheduler.__name__:
                # UsingPyTorchIn-Builtscheduler
                scheduler.step()
            else:
                # Usingcustomdefinedscheduler
                for param_group in trainer.param_groups:
                    param_group['lr'] = scheduler(epoch)

    print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')

让我们来看看如果使用默认设置,调用此算法会发生什么。 例如设学习率为0.3并训练30次迭代。 留意在超过了某点、测试准确度方面的进展停滞时,训练准确度将如何继续提高。 两条曲线之间的间隙表示过拟合。

lr, num_epochs = 0.3, 30
net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=lr)
train(net, train_iter, test_iter, num_epochs, loss, trainer, device)

train loss 0.128, train acc 0.951, test acc 0.885

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/208929.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux系统之centos7编译安装Python 3.8

前言 CentOS (Community Enterprise Operating System) 是一种基于 Red Hat Enterprise Linux (RHEL) 进行源代码再编译并免费提供给用户的 Linux 操作系统。 CentOS 7 采用了最新的技术和软件包,并提供了强大的功能和稳定性。它适用于各种服务器和工作站应用场景&a…

postgresql-shared_buffers参数详解

shared_buffers 是 PostgreSQL 中一个非常关键的参数,用于配置服务器使用的共享内存缓冲区的大小。这些缓冲区用于存储数据页,以便数据库可以更快地访问磁盘上的数据。 这个参数在 PostgreSQL 的性能方面有着重要的影响。增加 shared_buffers 可以提高数…

spring boot mybatis TypeHandler 看源码如何初始化及调用

目录 概述使用TypeHandler使用方式在 select | update | insert 中加入 配置文件中指定 源码分析配置文件指定Mapper 执行query如何转换 结束 概述 阅读此文 可以达到 spring boot mybatis TypeHandler 源码如何初始化及如何调用的。 spring boot 版本为 2.7.17,my…

【论文阅读】CAN网络中基于时序信道的隐蔽认证算法

文章目录 摘要一、引言和动机A 相关工作 二、背景及实验设置A 以前工作中的时钟偏差和局限性B.最坏到达时间C.安装组件 三、优化流量分配A.问题陈述B.优化帧调度 四、协议和结果A.主协议B.对手模型C. 优化流量和单一发送者的结果D.多发送方情况和噪声信道 摘要 以前的研究工作…

TCP三次握手过程

什么是TCP tcp是一个面向连接的、可靠的、基于字节流的传输层通信协议 面向连接:TCP连接是一对一的,不能实现一对多或多对一,TCP在通信前要首先建立连接,连接成功后才能开始进行通信可靠的:TCP连接要保证通信过程的可靠…

WebSocket 前端使用vue3+ts+elementplus 实现连接

1.配置连接 websocket.ts文件如下 import { ElMessage } from "element-plus";interface WebSocketProps {url: string; // websocket地址heartTime?: number; // 心跳时间间隔,默认为 50000 msheartMsg?: string; // 心跳信息,默认为pingr…

利用 NRF24L01 无线收发模块实现传感器数据的无线传输

NRF24L01 是一款常用的无线收发模块,适用于远程控制和数据传输应用。本文将介绍如何利用 NRF24L01 模块实现传感器数据的无线传输,包括硬件的连接和配置,以及相应的代码示例。 一、引言 NRF24L01 是一款基于 2.4GHz 射频通信的低功耗无线收发…

PG14归档失败解决办法archiver failed on wal_lsn

案例1:pg_wal下有wal_lsn文件 案例1适用于以下场景: pg_wal下有该wal_lsn文件而归档目录下无该wal_lsn文件pg_wal和归档目录下同时都有该wal_lsn文件 问题描述 昨晚RepmgrPG14主备主库因wal日志撑爆磁盘,删除主库过期wal文件重做备库后上午进行主备状…

Find My文件袋|苹果Find My技术与文件袋结合,智能防丢,全球定位

文件袋是指用于对自己的私人物品或身份文件保管等、资料袋、白卷宗。款式分为扣式和文件套式。文件袋通常具有足够的容量,可以容纳大量文件。当需要长期存档文件时,文件袋是一个方便的选择,可以将文件整理好放入文件袋中,便于存放…

[github全教程]github版本控制最全教学------- 大厂找工作面试必备!

作者:20岁爱吃必胜客(坤制作人),近十年开发经验, 跨域学习者,目前于新西兰奥克兰大学攻读IT硕士学位。荣誉:阿里云博客专家认证、腾讯开发者社区优质创作者,在CTF省赛校赛多次取得好成绩。跨领域…

【笔记】2023最新Python安装教程(Windows 11)

🎈欢迎加群交流(备注:csdn)🎈 ✨✨✨https://ling71.cn/hmf.jpg✨✨✨ 🤓前言 作为一名经验丰富的CV工程师,今天我将带大家在全新的Windows 11系统上安装Python。无论你是编程新手还是老手&…

C语言错误处理之 “<errno.h>与<error.h>”

目录 前言 错误号处理方式 errno.h头文件 error.h头文件 参数解释: 关于的”__attribute__“解释: 关于“属性”的解释: 实例一: 实例二: error.h与errno.h的区别 补充内容: 前言 在开始学习…

Redis——某马点评day01——短信登录

项目介绍 导入黑马点评项目 项目架构 基于Session实现登录 基本流程 实现发送短信验证码功能 controller层中 /*** 发送手机验证码*/PostMapping("code")public Result sendCode(RequestParam("phone") String phone, HttpSession session) {// 发送短信…

支持Upsert、Kafka Connector、集成Airbyte,Milvus助力高效数据流处理

Milvus 已支持 Upsert、 Kafka Connector、Airbyte! 在上周的文章中《登陆 Azure、发布新版本……Zilliz 昨夜今晨发生了什么?》,我们已经透露过 Milvus(Zilliz Cloud)为提高数据流处理效率, 先后支持了 Up…

JOSEF约瑟 DY-34 型电压继电器,15-30V 柜内安装,板前接线

DY-30系列电压继电器 DY-32电压继电器; DY-36电压继电器; DY-33电压继电器; DY-37电压继电器; DY-34电压继电器; DY-38电压继电器; DY-31电压继电器; DY-35电压继电器; DY-32/60C电压…

HarmonyOS——解决本地模拟器无法选择设备的问题

在使用deveco studio进行鸿蒙开发的时候,可能会遇到本地模拟器已经启动了,但是仍然无法选择本地模拟器中的设备,尤其在MAC环境中尤为常见。 解决办法: 先打开IDE启动本地模拟器,等模拟器启动后,退出IDE重新…

蓝桥杯每日一题2023.12.1

题目描述 蓝桥杯大赛历届真题 - C 语言 B 组 - 蓝桥云课 (lanqiao.cn) 题目分析 对于此题目而言思路较为重要&#xff0c;实际可以转化为求两个数字对应的操作&#xff0c;输出最前面的数字即可 #include<bits/stdc.h> using namespace std; int main() {for(int i 1…

ARM64版本的chrome浏览器安装

这一快比较玄学&#xff0c;花个半个小时左右才能安装好&#xff0c;也不知道是个什么情况。 sudo snap install chromium只需要以上这个命令&#xff0c;当然&#xff0c;也可以自己去找安装包进行安装&#xff0c;但是测试后发现并没有那么好装&#xff0c;主要是两个部分 一…

第九节HarmonyOS 常用基础组件-Text

一、组件介绍 组件&#xff08;Component&#xff09;是界面搭建与显示的最小单位&#xff0c;HarmonyOS ArkUI声名式为开发者提供了丰富多样的UI组件&#xff0c;我们可以使用这些组件轻松的编写出更加丰富、漂亮的界面。 组件根据功能可以分为以下五大类&#xff1a;基础组件…

Unity中Shader指令优化

文章目录 前言一、解析一下不同运算所需的指令数1、常数基本运算2、变量基本运算3、条件语句、循环 和 函数 前言 上一篇文章中&#xff0c;我们解析了Shader解析后的代码。我们在这篇文章中来看怎么实现Shader指令优化 Unity中Shader指令优化&#xff08;编译后指令解析&…
最新文章