【神经网络】09 - 优化器 torch.optim

09 - 优化器 torch.optim

概念

优化器是用于更新和管理模型参数以改进模型性能的一种工具。在机器学习和深度学习中,优化器的主要任务是通过调整模型的参数来最小化或最大化特定的目标函数。在许多情况下,这个目标函数是一个损失函数,我们的目标是将其最小化。

在 PyTorch 中,torch.optim 是一个包含了各种优化算法的模块,用于帮助开发者在训练神经网络时更方便地更新和管理模型参数。

当我们训练神经网络时,我们的目标是找到一组模型参数,可以最小化或最大化某个损失函数或者目标函数。优化算法或者优化器(Optimizer)是用于实现这个过程的工具。

torch.optim 中包括了很多常用的优化算法,如随机梯度下降(SGD)、Adam、Adagrad、RMSprop等。这些优化器通常都会实现一个 step() 方法来更新模型的参数。

一些常用的优化器:

  • 随机梯度下降(SGD):SGD是最基础和最经常使用的优化器之一。它通过计算每一个训练样本或训练批次的损失函数的梯度,并沿着梯度的反方向更新模型参数以最小化损失。
  • Momentum:融入动量的随机梯度下降,意味着在更新参数时,不仅考虑当前的梯度,还会考虑前一次的梯度方向,因此具有“保持动量”并更快速地收敛到最小值的效果。
  • Adagrad, RMSprop, Adam:这些是更高级的优化器,它们可以自动调节学习率——每个参数都有自己的学习率。有的优化器(如Adam)甚至综合了动量和自适应学习率的优点。

关键的通用参数:

  1. params(必须):是一个可迭代的对象,通常你会传入模型的参数,通常可以通过 model.parameters() 获取。这个参数告诉优化器应该更新哪些张量。
为什么 model.parameters() 可以获取 params?

GPT4的回答:

在 PyTorch 中,model.parameters() 方法被用来获取模型中所有可训练的参数。这些参数,例如权重和偏置,是在训练过程中通过优化算法进行更新的值。

首先,我们需要明白,在 PyTorch 中,模型(通常是 nn.Module 的子类)是由一系列的层构成的,而每一层又有它自己的参数。例如,线性层 nn.Linear 有权重和偏置作为参数,卷积层 nn.Conv2d 也有权重和偏置。

在定义模型时,这些层通常被定义为模型的属性。因此,当你调用 model.parameters() 时,PyTorch 会去遍历每一层,并收集所有的参数。

在优化器中,我们需要告诉优化器,应该对哪些张量进行优化(即应用梯度更新)。model.parameters() 提供了一种简便的方式来收集所有这些张量,而无需手动指定。

具体来说,model.parameters() 返回一个生成器,生成包含模型所有参数(权重和偏置)的张量,这些参数就是我们希望优化器能够更新的值。

  1. lr(必须):学习率,它决定了参数更新的步长。学习率太小可能导致训练过程慢而无法收敛,太大可能导致收敛不稳定。

示例

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True,
                                       transform=torchvision.transforms.ToTensor())

dataloader = DataLoader(dataset, batch_size=32)


class MyNet(nn.Module):
    def __init__(self):
        # nn.Module.__init__(self)
        super(MyNet, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

# 定义损失函数为交叉熵损失,它通常用于分类任务
loss = nn.CrossEntropyLoss()
net = MyNet()
# 定义优化器为随机梯度下降(SGD),学习率设置为0.01。
# 优化器会更新传入的 net.parameters() 中的参数
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
for data in dataloader:
    imgs, targets = data
    outputs = net(imgs)
    # 计算此次预测的损失值,预测结果outputs与真实targets的交叉熵损失。
    res_loss = loss(outputs, targets)
    # 清零之前反向传播过程中累积的梯度
    optimizer.zero_grad()
    # 对损失进行反向传播,计算每个参数的梯度
    res_loss.backward()
    # 根据计算得到的梯度更新网络参数,更新步骤由前面定义的优化器optimizer完成
    optimizer.step()

该训练循环在每一个大小为32的批次(batch)上进行了完整的前向传播(获取预测和损失)、反向传播(计算梯度)以及更新权重的操作。这个过程会持续进行,直到遍历完 dataloader 中所有的数据。

这个过程完成一次被称为一个训练周期(epoch)。如果要进行多个训练周期(epoch),可以将这个训练循环放在一个外部循环中,每个训练周期都遍历一次所有的数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/605337.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

javax.net.ssl.SSLException: Received fatal alert: protocol_version已经解决

起因: 在帮别人讲解项目时,将项目的tomcat配置完,点击运行后,报错,信息如标题。 解决办法: 在csdn百度问题,得到的方法主要有几个: 1.jdk要配置在1.8以上; 2.数据库地…

【MySQL】ON WHERE 和 ON AND 的区别

1. 查询语句语法规则 “[ ]” 包含的内容可以省略; “{ }” 包含的内容必须存在; DISTINCT: 设定 **distinct** 可以去掉重复记录; AS: 表明或字段名过长时,可以用 **AS** 关键字起别名,也可…

06.配置邮件报警

配置邮件报警 我的授权码:HCHNVOAENURLOACG 1.定义发件人 密码是163邮箱的授权码 2.配置收件人 我就配置收件人是qq邮箱了 3.启动动作 验证邮件发送成功

Redis如何避免数据丢失?——AOF

目录 AOF日志 1. 持久化——命令写入到AOF文件 写到用户缓冲区 AOF的触发入口函数——propagate 具体的实现逻辑——feedAppendOnlyFile 从用户缓冲区写入到AOF文件(磁盘) 函数write、fsync、fdatasync Redis的线程池 AOF文件的同步策略 触发的入口函数——…

特斯拉擎天柱机器人:工厂自动化的未来

随着技术的进步,工业自动化已经逐步进入了一个新的纪元。特斯拉最近公布的擎天柱机器人Optimus的演示,不仅仅展示了一个高科技机器人的能力,更是向我们揭示了未来工厂的可能性。 特斯拉擎天柱机器人的功能展示 马斯克在最新的演示中向我们展…

使用Nuxt.js实现服务端渲染(SSR)

Nuxt.js 是一个基于 Vue.js 的框架,它提供了服务器端渲染(SSR)和静态站点生成(SSG)的能力,使开发者能够轻松地构建高效、优雅的前端应用。Nuxt.js 集成了许多开箱即用的功能和工具,帮助开发者快…

C语言—深入理解指针(2)

1.数组名的理解 不难发现,数组名就是数组首元素的地址。 但是有两个例外: 1.sizeof(数组名) 这里的数组名表示整个数组,计算的是整个数组的大小,单位是字节。 2.&数组名 这里的数组名也表示整个数…

MacOS miniconda安装方法

打开macos “终端” 应用 执行命令 mkdir -p ~/miniconda3curl https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.shbash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3rm -rf ~/miniconda3/mini…

CPU基本知识点

目录 1.概念 2.分类 3.运作原理 4.指令系统 1.概念 CPU:英文Central Processing Unit,即中央处理器。 解释和执行指令的功能单元,它是计算机的中枢神经系统(即核心)。 是计算机最核心的部件,主要是运算…

嵌入式数据库SQLite 3配置使用详细笔记教程

0、惨痛教训 随着管理开发的项目体积越来越庞大,产品系统涉及的数据量也越来越多,并且伴随着项目不久就要交付给甲方了。如果项目的数据信息没有被妥善管理,后期设备的运行状态、操作状况等数据流信息不能被溯源,当出现了一些特殊…

【35分钟掌握金融风控策略16】贷前风控策略详解-1

目录 贷前风控策略详解 贷前风控目标 精准审核申请贷款客户资质 对申请贷款客户进行合理定额 对申请贷款客户进行合理定价 推动实现利润最大化 贷前风控数据源 客户贷款时提供的数据 贷前风控策略详解 俗话说,良好的开端是成功的一半,而贷前是风…

C++新手村指南:入门基础

目录 C概念 C发展史 C关键字(C98) 命名空间 命名空间的定义 命名空间的使用 C中的输入&&输出 缺省参数 缺省参数的概念 缺省参数的分类 函数重载 函数重载概念 函数重载实现 引用 引用的概念 引用的特性 常引用 引用的使用场景…

基于单片机的小型自动浇灌系统设计

摘 要:以单片机为主控芯片,结合传感器和计算机,搭建了一套智能化的浇灌系统;利用LabVIEW 设计并编写了基于状态机程序架构的上位机软件,实现了友好的用户交互界面,实时测量、显示与记录等功能,并由主控芯片进行浇灌。经测试,本系统具有结构简单,研制成本低,运…

详细介绍一下PointPillars算法的网络结构

PointPillars是一种用于3D目标检测的算法,它主要使用了点云数据和深度学习模型。 PointPillars算法的网络结构主要可以分为三个主要阶段: Pillar Feature Net(点云特征处理网络):此阶段的主要任务是将输入的点云数据转…

回答篇:测试开发高频面试题目

引用之前文章:《测试开发高频面试题目》 https://blog.csdn.net/qq_41214208/article/details/138193469?spm1001.2014.3001.5502 本篇文章是回答篇(持续更新中) 1. 什么是测试开发以及其在软件开发流程中的作用。 a. 测试开发是指测试人员或…

Java:Servlet详解

目录 一、什么是Servlet 二、Servlet原理 Servlet的生命周期 三、 Servlet注释 WebServlet 一、什么是Servlet Servlet是JavaWeb开发的一种技术,Servlet程序需要部署在Servlet容器(服务端)中才能运行,常见的Servlet容器有Tom…

【C++】环境搭建CentOS Clion报错Unsupported git Version 1.8.3.1

【C】环境搭建Clion-Unsupported git Version 1.8.3.1 Git升级步骤1.卸载旧版本2.安装依赖3.下载git最新版本包4.解压git文件包5.编译文件5.将git加入环境变量6.验证git版本 如上图所示,报错Unsupported git Version 1.8.3.1 At least 2.17.0 is required 报错意思…

windows驱动开发-inf文件(一)

驱动总是和inf文件相关,在WinDDK的时候,许多inf文件都需要开发工程师手动编写,不过,现在已经可以使用inx文件来生成inf文件了,它经常用于驱动的安装和卸载;不过,并不是所有的驱动都需要使用inf文…

小白修复msvcp140.dll丢失的解决方法,一键修复丢失的dll文件

在我们使用电脑时,常常会碰到各种烦人的状况。比方说,当我们期待畅玩游戏时,可能会突然遭遇一则令人沮丧的提示:“打开游戏缺少msvcp140.dll文件”。这个问题会给我们带来困扰和不愉快,但庆幸的是,有多种解…

UE4_Water插件_Buoyancy组件使用

water插件提供了一个浮力Actor蓝图类。 需要注意的几个问题: 1、StaticMesh需要替换根组件。 2、需要模拟物理设置质量。 3、需要添加浮力组件,设置浮力点,应用水中牵引力。 4、最重要的是需要激活——自动启用。 5、调水波长的地方 双击图片…
最新文章