8. 损失函数与反向传播

8.1 损失函数

① Loss损失函数一方面计算实际输出和目标之间的差距。

② Loss损失函数另一方面为我们更新输出提供一定的依据。

8.2 L1loss损失函数 

 ① L1loss数学公式如下图所示,例子如下下图所示。

import torch
from torch.nn import L1Loss
inputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)
inputs = torch.reshape(inputs,(1,1,1,3))
targets = torch.reshape(targets,(1,1,1,3))
loss = L1Loss()  # 默认为 maen
result = loss(inputs,targets)
print(result)

结果:

tensor(0.6667)
import torch
from torch.nn import L1Loss
inputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)
inputs = torch.reshape(inputs,(1,1,1,3))
targets = torch.reshape(targets,(1,1,1,3))
loss = L1Loss(reduction='sum') # 修改为sum,三个值的差值,然后取和
result = loss(inputs,targets)
print(result)

结果:

tensor(2.)

8.3  MSE损失函数

 ① MSE损失函数数学公式如下图所示。

 

import torch
from torch.nn import L1Loss
from torch import nn
inputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)
inputs = torch.reshape(inputs,(1,1,1,3))
targets = torch.reshape(targets,(1,1,1,3))
loss_mse = nn.MSELoss()
result_mse = loss_mse(inputs,targets)
print(result_mse)

结果:

tensor(1.3333)

 8.4 交叉熵损失函数

① 交叉熵损失函数数学公式如下图所示。

 

 

import torch
from torch.nn import L1Loss
from torch import nn

x = torch.tensor([0.1,0.2,0.3])
y = torch.tensor([1])
x = torch.reshape(x,(1,3)) # 1的 batch_size,有三类
loss_cross = nn.CrossEntropyLoss()
result_cross = loss_cross(x,y)
print(result_cross)

结果:

tensor(1.1019)

 8.5 搭建神经网络

import torch
import torchvision
from torch import nn 
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)       
dataloader = DataLoader(dataset, batch_size=1,drop_last=True)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x
    
tudui = Tudui()
for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    print(outputs)
    print(targets)

结果:

 8.6 数据集计算损失函数

 

import torch
import torchvision
from torch import nn 
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)       
dataloader = DataLoader(dataset, batch_size=64,drop_last=True)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x
    
loss = nn.CrossEntropyLoss() # 交叉熵    
tudui = Tudui()
for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    result_loss = loss(outputs, targets) # 计算实际输出与目标输出的差距
    print(result_loss)

结果:

 8.7 损失函数反向传播

① 反向传播通过梯度来更新参数,使得loss损失最小,如下图所示。

 

import torch
import torchvision
from torch import nn 
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)       
dataloader = DataLoader(dataset, batch_size=64,drop_last=True)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x
    
loss = nn.CrossEntropyLoss() # 交叉熵    
tudui = Tudui()
for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    result_loss = loss(outputs, targets) # 计算实际输出与目标输出的差距
    result_loss.backward()  # 计算出来的 loss 值有 backward 方法属性,反向传播来计算每个节点的更新的参数。这里查看网络的属性 grad 梯度属性刚开始没有,反向传播计算出来后才有,后面优化器会利用梯度优化网络参数。      
    print("ok")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/99069.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

visual studio编写DLL,python调用

选择第一个c DLL&#xff0c; 然后项目源文件下右击新建项&#xff0c;这里名字随便取&#xff0c;在代码中输入一下内容&#xff1a; #include <iostream>#define EXPORT extern "C" __declspec(dllexport)EXPORT int sub(int a, int b) {return a - b; } 在…

【GO】LGTM_Grafana_Tempo(1)_架构

最近在尝试用 LGTM 来实现 Go 微服务的可观测性&#xff0c;就顺便整理一下文档。 Tempo 会分为 4 篇文章&#xff1a; Tempo 的架构官网测试实操跑通gin 框架发送 trace 数据到 tempogo-zero 微服务框架使用发送数据到 tempo 第一篇是关于&#xff0c;tempo 的架构&#xff…

R语言对综合社会调查GSS数据进行自举法bootstrap统计推断、假设检验、探索性数据分析可视化|数据分享...

全文链接&#xff1a;https://tecdat.cn/?p33514 综合社会调查&#xff08;GSS&#xff09;是由国家舆论研究中心开展的一项观察性研究。自 1972 年以来&#xff0c;GSS 一直通过收集当代社会的数据来监测社会学和态度趋势。其目的是解释态度、行为和属性的趋势和常量。从 197…

Docsify + Gitalk详细配置过程讲解

&#x1f496; 作者简介&#xff1a;大家好&#xff0c;我是Zeeland&#xff0c;开源建设者与全栈领域优质创作者。&#x1f4dd; CSDN主页&#xff1a;Zeeland&#x1f525;&#x1f4e3; 我的博客&#xff1a;Zeeland&#x1f4da; Github主页: Undertone0809 (Zeeland)&…

前端:html实现页面切换、顶部标签栏(可删、可切换,点击左侧超链接出现标签栏)

一、在一个页面&#xff08;不跨页面&#xff09; 效果&#xff1a; 代码 <!DOCTYPE html> <html><head><style>/* 设置标签页外层容器样式 */.tab-container {width: 100%;background-color: #f1f1f1;overflow: hidden;}/* 设置标签页选项卡的样式 …

Linux socket网络编程实战(tcp)实现双方聊天

在上节已经系统介绍了大致的流程和相关的API&#xff0c;这节就开始写代码&#xff01; 回顾上节的流程&#xff1a; 创建一个NET文件夹 来存放网络编程相关的代码&#xff1a; tcp服务端代码初步实现--上 这部分先实现服务器的连接部分的代码并进行验证 server1.c&#xff…

如何在小红书进行学习直播

诸神缄默不语-个人CSDN博文目录 因为我是从B站开始的&#xff0c;所以一些直播常识型的东西请见我之前写的如何在B站进行学习直播这一篇。 本篇主要介绍一些小红书之与B站不同之处。 小红书在手机端是可以直接点击“”选择直播的。 文章目录 1. 电脑直播-小红书直播软件2. 电…

二叉树题目:二叉树的右视图

文章目录 题目标题和出处难度题目描述要求示例数据范围 解法一思路和算法代码复杂度分析 解法二思路和算法代码复杂度分析 题目 标题和出处 标题&#xff1a;二叉树的右视图 出处&#xff1a;199. 二叉树的右视图 难度 4 级 题目描述 要求 给定二叉树的根结点 root \t…

DHCP中继实验

文章目录 一、实验背景与目的二、实验拓扑三、实验需求四、实验解法1. 配置IP地址2.配置R1为DHCP服务器&#xff0c;能够跨网段为192.168.2.0/24网段自动分配IP地址3. 在PC3上Ping 192.168.1.1&#xff0c;确认可以Ping通 摘要&#xff1a; 本实验旨在通过配置DHCP中继实现跨网…

C++项目:网络版本在线五子棋对战

目录 1.项目介绍 2.开发环境 3.核心技术 4. 环境搭建 5.websocketpp 5.1原理解析 5.2报文格式 5.3websocketpp常用接口介绍 5.4websocket服务器 6.JsonCpp使用 6.1Json数据格式 6.2JsonCpp介绍 7.MySQL API 7.1MySQL API介绍 7.2MySQL API使用 7.3实现增删改查…

降噪音频转录 Krisp: v1.40.7 Crack

主打人工智能降噪服务的初创公司「Krisp」近期宣布推出音频转录功能&#xff0c;能对电话和视频会议进行实时设备转录。该软件还整合的ChatGPT&#xff0c;以便快速总结内容&#xff0c;开放测试版于今天上线。 随着线上会议越来越频繁&#xff0c;会议转录已成为团队工作的重…

微服务系统面经之二: 以秒杀系统为例

16 微服务与集群部署 16.1 一个微服务一般会采用集群部署吗&#xff1f; 对于一个微服务是否采用集群部署&#xff0c;这完全取决于具体的业务需求和系统规模。如果一个微服务的访问压力较大&#xff0c;或者需要提供高可用性&#xff0c;那么采用集群部署是一种常见的策略。…

C语言之函数题

目录 1.乘法口诀表 2.交换两个整数 3.函数判断闰年 4.函数判断素数 5.计算斐波那契数 6.递归实现n的k次方 7.计算一个数的每位之和&#xff08;递归&#xff09; 8.字符串逆序&#xff08;递归实现&#xff09; 9.strlen的模拟&#xff08;递归实现&#xff09; 10.求…

NoSQL MongoDB Redis E-R图 UML类图概述

NoSQL NoSQL(Not only SQL)是对不同于传统的关系数据库的数据库管理系统的统称&#xff0c;即广义地来说可以把所有不是关系型数据库的数据库统称为NoSQL。 NoSQL 数据库专门构建用于特定的数据模型&#xff0c;并且具有灵活的架构来构建现代应用程序。NoSQL 数据库使用各种数…

这是一条求助贴(postman测试的时候一直是404)

看到这个问题是404的时候总感觉不该求助大家&#xff0c;404多常见一看就是简单的路径问题&#xff0c;我的好像不是&#xff0c;我把我的问题奉上。 首先我先给出我的url http://10.3.22.195:8080/escloud/rest/escloud_contentws/permissionStatistics/jc-haojl/sz 这是我…

抖音电商,提前批offer!

南京夫子庙茶颜悦色店 摄于2023.8.27 小伙伴们大家好&#xff0c;我是阿秀。 互联网圈有个梗就是"两大码农工厂&#xff1a;南华科、北北邮"&#xff0c;就是说这两所高校的毕业生从事互联网工作的特别多&#xff0c;北邮虽然是211&#xff0c;但在互联网圈子里比很多…

Qt5升级到Qt6分步迁移教程

Qt框架的一个新的长期支持版本6.5最近发布。它为以前的版本引入了许多修复、改进和新功能。有些可能对您的应用程序有用&#xff08;如果不是现在&#xff0c;可能会在将来&#xff09;&#xff0c;因此最好将应用程序迁移到最新版本的框架。 仍然有许多应用程序仍在使用Qt 5&…

瑞芯微:基于RK3568得人脸朝向检测

驾驶员监控系统是基于驾驶员面部图像处理来研究驾驶员状态的实时系统。首先挖掘出人在疲劳状态下的表情特征&#xff0c;然后将这些定性的表情特征进行量化&#xff0c;提取出面部特征点及特征指标作为判断依据&#xff0c;再结合实验数据总结出基于这些参数的识别方法&#xf…

git clone 报SSL证书问题

git命令下运行 git config --global http.sslVerify false 然后再进行重新clone代码

Web安全——信息收集下篇

Web安全 一、网络空间搜索引擎二、扫描敏感目录/文件1、御剑2、7kbstorm3、bbscan4、dirmap5、dirsearch6、gobuster7、网站文件 三、扫描网页备份四、网站头信息收集五、敏感文件搜索1、GitHub搜索2、Google-hacking3、wooyun漏洞库4、网盘搜索5、社工库6、网站注册信息7、js敏…