pytorch --反向传播和优化器

1. 反向传播

计算当前张量的梯度

Tensor.backward(gradient=None, retain_graph=None, create_graph=False, inputs=None)

计算当前张量相对于图中叶子节点的梯度。

使用反向传播,每个节点的梯度,根据梯度进行参数优化,最后使得损失最小化

代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)



class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        # 另一种写法
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10)
        )

    def forward(self,x):
        # sequential方式
        x = self.model1(x)
        return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:
    imgs,target = data
    outputs= tudui(imgs)
    result_loss = loss(outputs,target)
    result_loss.backward() # 梯度
    print(result_loss)

2.优化器 (以随机梯度下降算法为例)
将上一步的梯度清零
params ,lr(学习率)
随机梯度下降SGD
torch.optim.SGD(params,
lr=,
momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)

代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)



class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        # 另一种写法
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10)
        )

    def forward(self,x):
        # sequential方式
        x = self.model1(x)
        return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(),lr=0.01) # params,lr
for epoch in range(5):# 在整个数据集上训练5次
    running_loss = 0
    #对数据进行一轮学习
    for data in dataloader:
        imgs,target = data
        outputs= tudui(imgs)
        result_loss = loss(outputs,target)
        optim.zero_grad() # 将上一步的梯度清零
        result_loss.backward() # 梯度
        optim.step() # 根据梯度修改参数
        # print(result_loss)
        running_loss = running_loss + result_loss
    print(running_loss)

输出
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/413295.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

IDEA的LeetCode插件的设置

一、下载插件 选择点击File->Setting->Plugins:搜索LeetCode 二、打开这个插件 选择View —>Tool Windows—>leetcode 三、登陆自己的账号 关于下面几个参数的定义,官方给的是: Custom code template: 开启使用自定义模板&…

搜维尔科技:OptiTrack 提供了性能最佳的动作捕捉平台

OptiTrack 动画 我们的 Prime 系列相机和 Motive 软件相结合,产生了世界上最大的捕获量、最精确的 3D 数据和有史以来最高的相机数量。OptiTrack 提供了性能最佳的动作捕捉平台,具有易于使用的制作工作流程以及运行世界上最大舞台所需的深度。 无与伦比…

【k8s配置与存储--持久化存储(PV、PVC、存储类)】

1、PV与PVC 介绍 持久卷(PersistentVolume,PV) 是集群中的一块存储,可以由管理员事先制备, 或者使用存储类(Storage Class)来动态制备。 持久卷是集群资源,就像节点也是集群资源一样…

SpringCloud微服务-Eureka注册中心

Eureka注册中心 文章目录 Eureka注册中心前言1、Eureka的作用2、搭建EurekaServer3、服务注册4、启动多个实例5、服务拉取 -实现负载均衡 前言 在服务调用时产生的问题: //2. 利用RestTemplate发起HTTP请求,查询user String url "http://localho…

【Unity自制手册】Unity—Camera相机跟随的方法大全

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:Uni…

小甲鱼Python06 序列字典集合

一、序列 1.id函数 is运算符 我们首先思考下字符串、元组、列表的共同点: 都有很多共同的运算符。都可以通过索引来获取元素,第一个元素索引都是0,都可以通过切片的方法获取某个范围内元素的集合。 以上三种统称为序列。序列分为可变序列…

安装极狐GitLab Runner并测试使用

本文继【新版极狐安装配置详细版】之后继续 1. 添加官方极狐GitLab 仓库: 对于 RHEL/CentOS/Fedora: curl -L "https://packages.gitlab.com/install/repositories/runner/gitlab-runner/script.rpm.sh" | sudo bash2. 安装最新版本的极狐G…

多进程服务端进程框架

对父进程来说,不需要客户端连接的socket,对子进程来说,不需要监听的socket。代码中标出了两行关闭对应socket的函数 上图可见,对应的关闭了以后(3是服务端监听的socket,4是客户端连上来的socket) 通过上图…

【GPTs分享】每日GPTs分享之Image Generator Tool

今日GPTs分享:Image Generator Tool。Image Generator Tool是一种基于人工智能的创意辅助工具,专门设计用于根据文字描述生成图像。这款工具结合了专业性与友好性,鼓励用户发挥创造力,同时提供高效且富有成效的交互体验。 主要功能…

10:00面试,10:05就出来了,问的问题过于变态了。。。

我从一家小公司转投到另一家公司,期待着新的工作环境和机会。然而,新公司的加班文化让我有些始料未及。虽然薪资相对较高,但长时间的工作和缺乏休息使我身心俱疲。 就在我逐渐适应这种高强度的工作节奏时,公司突然宣布了一则令人…

C++ Webserver从零开始:代码书写(十四)——http连接处理

前言 HTTP类是Webserver到目前为止最为庞大的类。其实最开始我是只想分析它的部分代码,但是最后我还咬咬牙将http连接处理的全代码分析写完了。因此,本文会特别的长,我相信没人可以把它一口气全部读完。不过我在本文中进行了细致的目录划分&a…

部署roop实现视频人脸替换

roop只需要一张人脸的图像,就可以替换视频中的脸。不需要数据集和模型训练。 下载对应版本的cudnn https://developer.nvidia.com/rdp/cudnn-archivehttps://developer.nvidia.com/rdp/cudnn-archive解压后的三个文件夹拷贝到cuda的目录 C:\Program Files\NVIDIA…

【非递归版】归并排序算法(2)

目录 MergeSortNonR归并排序 非递归&归并排序VS快速排序 整体思想 图解分析​ 代码实现 时间复杂度 归并排序在硬盘上的应用(外排序) MergeSortNonR归并排序 前面的快速排序的非递归实现,我们借助栈实现。这里我们能否也借助栈去…

uniapp实现单选框

采用uniapp-vue3实现的一款单选框组件,提供丝滑的动画选中效果,支持不同主题配置,适配web、H5、微信小程序(其他平台小程序未测试过,可自行尝试) 可到插件市场下载尝试: https://ext.dcloud.net…

SpringCloudAlibaba全家桶介绍

Spring Cloud Alibaba Spring Cloud Alibaba 是什么?微服务全景图核心特色 大家好,我叫阿明。下面我会为大家准备Spring Cloud Alibaba系列知识体系,结合实战输出案列,让大家一眼就能明白得技术原理,应用于各公司得各…

二次供水物联网:HiWoo Cloud助力城市水务管理升级

随着城市化的快速推进,二次供水系统作为城市基础设施的重要组成部分,其稳定运行和高效管理显得至关重要。然而,传统的二次供水管理方式在应对复杂多变的城市供水需求时,显得力不从心。为了破解这一难题,HiWoo Cloud平台…

VsCode的leetcode插件无法登录

前提 想使用VsCode的leetcode插件进行刷题,然后按照网上的教程进行安装下载,但是到了登录这一步,死活也登录不了,然后查看log一直报的错误是invalid password。 解决方法 首先确定在插件中设置的站点是Leetcode中国&#xff0c…

【Java EE初阶二十五】简单的表白墙(一)

1. 前端部分 1.1 前端代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"wid…

继电器测试中需要注意的安全事项有哪些?

继电器广泛应用于电气控制系统中的开关元件&#xff0c;其主要功能是在输入信号的控制下实现输出电路的断开或闭合。在继电器测试过程中&#xff0c;为了确保测试的准确性和安全性&#xff0c;需要遵循一定的安全事项。以下是在进行继电器测试时需要注意的安全事项&#xff1a;…

【代码随想录python笔记整理】第十三课 · 链表的基础操作 1

前言:本笔记仅仅只是对内容的整理和自行消化,并不是完整内容,如有侵权,联系立删。 一、链表 在之前的学习中,我们接触到了字符串和数组(列表)这两种结构,它们具有着以下的共同点:1、元素按照一定的顺序来排列。2、可以通过索引来访问数组中的元素和字符串中的字符。由此,…
最新文章