【Pytorch】学习记录分享3——PyTorch 自动微分与线性回归

【【Pytorch】学习记录分享3——PyTorch 自动微分与线性回归

      • 1. autograd 包,自动微分
      • 2. 线性模型回归演示
      • 3. GPU进行模型训练

小结:只需要将前向传播设置好,调用反向传播接口,即可实现反向传播的链式求导

1. autograd 包,自动微分

自动微分是机器学习工具包必备的工具,它可以自动计算整个计算图的微分。

PyTorch内建了一个叫做torch.autograd的自动微分引擎,该引擎支持的数据类型为:浮点数Tensor类型 ( half, float, double and bfloat16) 和复数Tensor 类型(cfloat, cdouble)

PyTorch中与自动微分相关的常用的Tensor属性和函数:

属性requires_grad:
默认值为False,表明该Tensor不会被自动微分引擎计算微分。设置为True,表明让自动微分引擎计算该Tensor的微分
属性grad:存储自动微分的计算结果,即调用backward()方法后的计算结果
方法backward(): 计算微分,一般不带参数,等效于:backward(torch.tensor(1.0))。若backward()方法在DAG的root上调用,它会依据链式法则自动计算DAG所有枝叶上的微分。
方法no_grad():禁用自动微分上下文管理, 一般用于模型评估或推理计算这些不需要执行自动微分计算的地方,以减少内存和算力的消耗。另外禁止在模型参数上自动计算微分,即不允许更新该参数,即所谓的冻结参数(frozen parameters)。
zero_grad()方法:PyTorch的微分是自动积累的,需要用zero_grad()方法手动清零

# 模型:z = x@w + b;激活函数:Softmax
x = torch.ones(5)  # 输入张量,shape=(5,)
labels = torch.zeros(3) # 标签值,shape=(3,)
w = torch.randn(5,3,requires_grad=True) # 模型参数,需要计算微分, shape=(5,3)
b = torch.randn(3, requires_grad=True)  # 模型参数,需要计算微分, shape=(3,)
z = x@w + b # 模型前向计算
outputs = torch.nn.functional.softmax(z) # 激活函数
print("z: ",z)
print("outputs: ",outputs)
loss = torch.nn.functional.binary_cross_entropy(outputs, labels)
# 查看loss函数的微分计算函数
print('Gradient function for loss =', loss.grad_fn)
# 调用loss函数的backward()方法计算模型参数的微分
loss.backward()
# 查看模型参数的微分值
print("w: ",w.grad)
print("b.grad: ",b.grad)

在这里插入图片描述

小姐:

方法描述
.requires_grad 设置为True会开始跟踪针对 tensor 的所有操作
.backward()张量的梯度将累积到 .grad 属性
import torch

x=torch.rand(1)
b=torch.rand(1,requires_grad=True)
w=torch.rand(1,requires_grad=True)
y = w * x
z = y + b

x.requires_grad, w.requires_grad,b.requires_grad,y.requires_grad,z.requires_grad

print("x: ",x, end="\n"),print("b: ",b ,end="\n"),print("w: ",w ,end="\n")
print("y: ",y, end="\n"),print("z: ",z, end="\n")

# 反向传播计算
z.backward(retain_graph=True) #注意:如果不清空,b每一次更新,都会自我累加起来,依次为1 2 3 4 。。。

w.grad
b.grad

运行结果:
在这里插入图片描述
反向传播求导原理:
在这里插入图片描述

2. 线性模型回归演示

import torch
import torch.nn as nn

## 线性回归模型: 本质上就是一个不加 激活函数的 全连接层
class LinearRegressionModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
        
    def forward(self, x):
        out = self.linear(x)
        return out
input_size = 1
output_size = 1

model = LinearRegressionModel(input_size, output_size)
model

# 指定号参数和损失函数
epochs = 500
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# train model
for epoch in range(epochs):
    epochs+=1
    #注意 将numpy格式的输入数据转换成 tensor
    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)
    
    #每次迭代梯度清零
    optimizer.zero_grad()
    
    #前向传播
    outputs = model(inputs)
    
    #计算损失
    loss = criterion(outputs, labels)
    
    #反向传播
    loss.backward()
    
    #updates weight and parameters
    optimizer.step()
    if epoch % 50 == 0:
        print("Epoch: {}, Loss: {}".format(epoch, loss.item()))

# predict model test,预测结果并且奖结果转换成np格式
predicted =model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
predicted

#model save
torch.save(model.state_dict(),'model.pkl')

#model 读取
model.load_state_dict(torch.load('model.pkl'))

在这里插入图片描述

3. GPU进行模型训练

只需要 将模型和数据传入到“cuda”中运行即可,详细实现见截图

import torch
import torch.nn as nn
import numpy as np

# #构建一个回归方程 y = 2*x+1

#构建输如数据,将输入numpy格式转成tensor格式
x_values = [i for i in range(11)]
x_train = np.array(x_values,dtype=np.float32)
x_train = x_train.reshape(-1,1)

y_values = [2*i + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1,1)

## 线性回归模型: 本质上就是一个不加 激活函数的 全连接层
class LinearRegressionModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
        
    def forward(self, x):
        out = self.linear(x)
        return out
    
input_size = 1
output_size = 1

model = LinearRegressionModel(input_size, output_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 指定号参数和损失函数
epochs = 500
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# train model
for epoch in range(epochs):
    epochs+=1
    #注意 将numpy格式的输入数据转换成 tensor
    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)
    
    #每次迭代梯度清零
    optimizer.zero_grad()
    
    #前向传播
    outputs = model(inputs)
    
    #计算损失
    loss = criterion(outputs, labels)
    
    #反向传播
    loss.backward()
    
    #updates weight and parameters
    optimizer.step()
    if epoch % 50 == 0:
        print("Epoch: {}, Loss: {}".format(epoch, loss.item()))

# predict model test,预测结果并且奖结果转换成np格式
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
predicted

#model save
torch.save(model.state_dict(),'model.pkl')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/245651.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

WPF仿网易云搭建笔记(6):Style进阶详解

文章目录 专栏和Gitee仓库前言Style简单使用样式字典全局样式局部全局样式全局样式穿透 专栏和Gitee仓库 WPF仿网易云 Gitee仓库 WPF仿网易云 CSDN博客专栏 前言 WPF想要批量设置样式属性,一共有3个方法 Style样式Template控件模板DataTemplate数据模板 WPF 零基础…

word四级目录序号不随上级目录序号变化问题解决方法

一、word中的几个元素简介 1、word中的列表 如下图所示,代表word的列表: 2、word中的标题 如下图所示,代表word的标题: 3、word中的编号/序号 如下图所示,代表word的编号/序号: 4、word中的目录 如下图…

【Python】人工智能-机器学习——不调库手撕深度网络分类问题

1. 作业内容描述 1.1 背景 数据集大小150该数据有4个属性,分别如下 Sepal.Length:花萼长度(cm)Sepal.Width:花萼宽度单位(cm)Petal.Length:花瓣长度(cm)Petal.Width:花瓣宽度(cm)category:类别&#xff0…

【一起学Rust | 框架篇 | Tauri2.0框架】Tauri App开启远程调试功能

文章目录 前言一、搭建PageSpy环境二、接入SDK三、进行远程调试调试控制台网络抓包审查元素 四、延伸 前言 Tauri在Rust圈内成名已久,凭借Rust的可靠性,使用系统原生的Webview构建更小的App 以及开发人员可以灵活的使用各种前端框架而一战成名。 然而&…

软考机考考试第一批经验分享

由于机考的特殊性,考试环境与传统笔试环境有所不同。下面是与考试环境相关的总结: 草稿纸:考场提供足够数量的草稿纸,每位考生都会分发一张白纸作为草稿纸。在草稿纸上需要写上准考证号。如果不够用,可以向监考老师再次…

RISCV中的寄存器操作

控制状态寄存器指令 (csrrc、csrrs、csrrw、csrrci、csrrsi、csrrwi), 使我们可以轻松地访问一些程序性能计数器。对于这些 64 位计数器, 我们一次可以读取 32 位。这些计数器包括了系统时间, 时钟周期以及执行的指令数目。 CSRRW 先读取寄存器的值:tCS…

ES6学习(三):Set和Map容器的使用

Set容器 set的结构类似于数组,但是成员是唯一且不会重复的。 创建的时候需要使用new Set([])的方法 创建Set格式数据 let set1 new Set([])console.log(set1, set1)let set2 new Set([1, 2, 3, 4, 5])console.log(set2, set2) 对比看看Set中唯一 let set3 new Set([1, 1,…

Lists.partition是如何实现懒加载的?

前言&#xff1a; 最近看到一篇文章&#xff0c;里面提及了google的common包下Lists.partition方法为懒加载&#xff0c;只有在遍历时才会真正分区。平时使用时并未感觉到,感觉有点好奇。特此将自己寻找的答案的过程整理记录下来。 源码&#xff1a; public static <T>…

云原生之深入解析K8s中的微服务项目设计与实现

一、微服务项目的设计 ① 微服务设计的思想 一个单片应用程序将被构建、测试并顺利地通过这些环境。事实证明&#xff0c;一旦投资于将生产路径自动化&#xff0c;那么部署更多的应用程序似乎就不再那么可怕了。请记住&#xff0c;CD的目标之一就是让部署变得无聊&#xff0c…

idea中定时+多数据源配置

因项目要求,需要定时从达梦数据库中取数据,并插入或更新到ORACLE数据库中 1.pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-…

多架构容器镜像构建实战

最近在一个国产化项目中遇到了这样一个场景&#xff0c;在同一个 Kubernetes 集群中的节点是混合架构的&#xff0c;也就是说&#xff0c;其中某些节点的 CPU 架构是 x86 的&#xff0c;而另一些节点是 ARM 的。为了让我们的镜像在这样的环境下运行&#xff0c;一种最简单的做法…

react+datav+echarts实现可视化数据大屏

&#x1f4d3;最近有点闲&#xff0c;就学习了下react&#xff0c;没想到就把react学完了&#xff0c;觉得还不错&#xff0c;就打算出一把reactdatav的简易版可视化数据大屏供大家做个参考。 &#x1f4d3;效果如下 1下载必要的框架 &#x1f4d3; react路由 npm install re…

0-50KHz频率响应模拟量高速信号隔离变送器

0-50KHz频率响应模拟量高速信号隔离变送器 型号&#xff1a;JSD TA-2322F系列 高速响应时间&#xff0c;频率响应时间快 特点&#xff1a; ◆小体积,低成本,标准 DIN35mm 导轨安装方式 ◆六端隔离(输入、输出、工作电源和通道间相互隔离) ◆高速信号采集 (-3dB,Min≤ 3.5 uS,订…

【Qt5】ui文件最后会变成头文件

2023年12月14日&#xff0c;周四下午 我也是今天下午偶然间发现这个的 在使用Qt的uic&#xff08;User Interface Compiler&#xff09;工具编译ui文件时&#xff0c;会生成对应的头文件。 在Qt中&#xff0c;ui文件是用于描述用户界面的XML文件&#xff0c;而头文件是用于在…

【JUC】二十九、synchronized锁升级之轻量锁与重量锁

文章目录 1、轻量锁2、轻量锁的作用3、轻量锁的加锁和释放4、轻量级锁的代码演示5、重量级锁6、重量级锁的原理7、锁升级和hashcode的关系8、锁升级和hashcode关系的代码证明9、synchronized锁升级的总结10、JIT编译器对锁的优化&#xff1a;锁消除和锁粗化11、结语 &#x1f4…

保障网络安全:了解威胁检测和风险评分的重要性

在当今数字时代&#xff0c;网络安全问题变得愈发突出&#xff0c;而及时发现和迅速应对潜在威胁成为保障组织信息安全的首要任务。令人震惊的是&#xff0c;根据2023年的数据&#xff0c;平均而言&#xff0c;检测到一次网络入侵的时间竟然长达207天。这引起了对安全策略和技术…

MPLS专线和互联网专线有什么区别?如何选择?

MPLS和互联网专线是什么&#xff1f; MPLS专线和互联网专线是企业网络连接的常见方式。MPLS专线基于多协议标签交换&#xff08;MPLS&#xff09;该技术利用专线连接两个或多个分支机构&#xff0c;提供高质量的数据传输服务。互联网专线是基于公共知识产权基础设施的连接方式…

Python实现高效摸鱼,批量识别银行卡号并自动写入Excel表格

前言 每当有新员工入职&#xff0c;人事小姐姐都要收集大量的工资卡信息&#xff0c;并且生成Excel文档&#xff0c;看到小姐姐这么辛苦&#xff0c;我就忍不住要去帮她了… 于是我用1行代码就实现了自动识别银行卡信息并且自动生成Excel文件&#xff0c;小姐姐当场就亮眼汪汪…

ChatGPT一周年,一图总结2023生成式AI里程碑大事件时间线

带你探索AI的无限可能&#xff01;AI一日&#xff0c;人间一年&#xff0c;这句话绝非空谈&#xff01; AI技术在不断地发展&#xff0c;让我们一起期待它未来更多的可能性吧&#xff01; 2022 年 11 月 30 日&#xff0c;OpenAI 宣布正式推出 ChatGPT。365 天过去&#xff0c;…

羊大师解读提高免疫力,能从羊奶开始吗?

羊大师解读提高免疫力&#xff0c;能从羊奶开始吗&#xff1f; 在当今充满挑战的世界中&#xff0c;拥有强大的免疫力是保持健康的关键。免疫系统是我们身体的守护者&#xff0c;能够抵御病菌和疾病&#xff0c;使我们远离健康问题。而如何提高免疫力一直是人们关注的焦点。近…