PyTorch数据结构

前言:因为最近开始读深度学习代码,主要都是用PyTorch框架,所以来补一些PyTorch基础,先从数据结构入手。

PyTorch数据结构

  • PyTorch
  • PyTorch数据结构
    • 张量
      • 属性:维度、轴、形状
      • 常见的操作
    • 数据集
      • 构造代码
      • DataLoader
    • 模块
  • 参考

PyTorch

PyTorch:PyTorch是一个开源深度学习框架,有很多好用的深度学习工具,提供了丰富的库,可以很方便构建和训练神经网络模型。

  • GPU加持:PyTorch提供了GPU优化操作和管理,使得在GPU上运行模型更高效。
  • 提供预训练模型和模型库:PyTorch提供了很多预训练模型和模型库,能很方便进行深度学习模型的开发。
  • 支持分布式训练:PyTorch支持分布式训练,可以在多个GPU和多台机器上加速训练。
  • 动态计算图:PyTorch使用动态计算图来计算图,在运行时动态生成而不是编译时静态生成,可以观察动态生成的数据流向。
  • 自动求导:PyTorch内置了自动求导功能,避免手动去计算非常复杂的导数,极大地减少了构建模型的时间。

PyTorch数据结构

张量

Tensor(张量):Tensor是PyTorch中最基本的数据结构,类似于多维数组。它可以表示标量、向量、矩阵或任意维度的数组。

属性:维度、轴、形状

维度(Dimensions):维度又可以叫做阶(Rank),理解为数组的维度。只有标量就是0维度,一维数组就是1维度,其余以此类推。

轴(Axis):轴数和维数、阶数相同,多维张量需要索引才能引用到里面的内容,这个不同维度索引就是轴。例如形状3×4的张量,需要访问张量[0][2]位置的内容,轴指的就是“[]”里面的索引,第一个轴的长度是3,第二的轴的长度是4。

形状(Shape):张量的形状由每个轴的长度决定,轴的长度就是对应维度能索引的大小。

相关的代码:
常用的函数有Tensor.size()和Tensor.dim()。

import torch
# 创建3维张量
tensor_3d = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
print("张量形状:", tensor_3d.size())# torch.Size([1,2,3]) 1个2*3的数组
print("轴数:", tensor_3d.dim())# 3
print(tensor_3d.size(1)==tensor_3d.size(-2))# True x.size(index)表示取出某个维度上的大小,正的表示从左到右,负的表示从右到左

常见的操作

重构(reshape):torch.reshape()可以重构张量的形状。过程是先把所有的内容按行排列,然后先分高纬再分低维度。

重构举例:

x = torch.tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
x.reshape(4,3) # 变成3行4列矩阵
x # tensor([[ 0,  1,  2],[ 3,  4,  5],[ 6,  7, 8],[ 9,  10,  11]])
x.reshape(2,3,2) 
# 先变一维 [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]
# 再变二维 [[ 0,  1,  2,  3,  4,  5],[ 6,  7,  8,  9, 10, 11]]
# 再变三维 [[[ 0,  1],[ 2,  3],[ 4,  5]],[[ 6,  7],[ 8,  9],[10, 11]]]
x # tensor([[[ 0,  1],[ 2,  3],[ 4,  5]],[[ 6,  7],[ 8,  9],[10, 11]]])

张量索引:

x = torch.tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]]])
print(x[0])# tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])
print(x[0][0])# tensor([0, 1, 2, 3])
print(A[0, 0:2, :])# tensor([[0, 1, 2, 3],[4, 5, 6, 7]])

拼接(cat和stack):

  • stack:在新创建的维度上进行拼接,会扩宽维度。
  • cat:按张量维度进行拼接。
# [2,3]->[2,9]
x=torch.ones((2, 3))
y = torch.cat([x, x], dim=0)
print(y)# tensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]])
# [2,3]->[2, 3, 2]
y = torch.stack([x, x], dim=2)
print(y)# tensor([[[1., 1.],[1., 1.],[1., 1.]],[[1., 1.],[1., 1.],[1., 1.]]])

升维(unsqueeze):指定维度插入新的维度。

x = torch.tensor([1,2,3,4])
y = x.unsqueeze(dim=0)
print(y)# tensor([[1],[2],[3],[4]])

降维(squeeze):移除制定或维度大小为1的维度。

x = torch.tensor([[[1],[2]],[[3],[4]]])
y = x.squeeze(2)
print(y)# tensor([[1, 2],[3, 4]])

升维和降维的好处,在于深度学习通常做运算会有维度要求。

数据集

Dataset(数据集):Dataset是一个抽象类,用于表示数据集。

使用:通过继承Dataset类,可以自定义数据集,并实现数据加载、预处理和获取样本等功能。PyTorch还提供了一些内置的数据集类,如MNIST、CIFAR-10等,用于方便地加载常用的数据集。

构造代码

代码:
通过继承该类来自定义自己的数据集类,在继承时要求必须重载__len__()和__getitem__()这两个方法。

  • __len__():返回的是数据集的大小。
  • __getitem__():实现索引数据集中的某一个数据。
import torch
from torch.utils.data import Dataset

class BasicDataset(Dataset):# 继承Dataset
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)

# 生成数据
data_tensor = torch.randn(4, 3)# 生成一个每个元素服从正态分布的4行3列随机张量
target_tensor = torch.rand(10)# 从区间[0,1)的均匀分布中随机抽取一个随机数生成一个张量

# 将数据封装成Dataset
tensor_dataset = BasicDataset(data_tensor, target_tensor)

print(tensor_dataset[1])# 调用__getitem__

print(len(tensor_dataset))# 调用__len__

DataLoader

DataLoader:DataLoader将Dataset对象或自定义数据类的对象封装成一个迭代器,通过迭代器可以输出Dataset的内容。

DataLoader参数:

  • dataset:表示Dataset类,数据从哪读取以及如何读取。
  • batch_size:表示批大小。
  • shuffle:表示每个epoch要不要重新打乱数据,默认false。
  • num_works:用多少个子进程读取数据。
  • drop_last:表示当样本数不能被batch_size整除时,是否舍弃最后一批数据。

batch和epoch的区别:

  • 一个epoch就是将所有训练样本训练一次的过程,神经网络的训练往往会需要很多次epoch才会loss收敛到合适的程度。
  • 将整个训练样本分成若干个Batch。

使用代码:

# batch_size设置为2,shuffle=False不打乱数据顺序,num_workers=1使用1个子进程
dataloader = BasicDataset(dataset, batch_size=2, shuffle=False, num_workers=1)

# 以for循环形式输出
for input, target in dataloader:
    print(input, target)

模块

Module(模块):Module是PyTorch中用于构建模型的基类。通过继承Module类,可以定义自己的模型,并实现前向传播和反向传播等方法。Module提供了参数管理、模型保存和加载等功能,方便模型的训练和部署。

实际去看深度学习代码的时候,会发现定义模型的类,都是继承nn.Module(模块)。

代码技巧:

  • 网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中。不具有参数的也可以放入(ReLU、dropout、BatchNormanation),如果不写在构造函数的话,可以在forward方法中用nn.functional来代替。
  • forward方法是必须要重写的,是实现模型的功能,实现各个层之间的连接关系的核心。在阅读模型的时候,阅读forward是搞懂模型运作流程最好的方式。
import torch
import torch.nn.functional as F
 
class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  # 第一句话,调用父类的构造函数
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        # self.relu1=torch.nn.ReLU()
        # self.max_pooling1=torch.nn.MaxPool2d(2,1)
        self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        # self.relu2=torch.nn.ReLU()
        # self.max_pooling2=torch.nn.MaxPool2d(2,1)
        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)# self.relu1(x)
        x = F.max_pool2d(x)# self.max_pooling1(x)
        x = self.conv2(x)
        x = F.relu(x)#  self.relu2(x)
        x = F.max_pool2d(x)# self.max_pooling2(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x

参考

nn.Module类详解
Dataset和DataLoader
PyTorch数据结构

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/499543.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据结构——lesson11排序之快速排序

💞💞 前言 hello hello~ ,这里是大耳朵土土垚~💖💖 ,欢迎大家点赞🥳🥳关注💥💥收藏🌹🌹🌹 💥个人主页&#x…

ASP.NET制作试卷(单选+多选)

需求: 1.包含单选题、多选题。 2.所有题做完再提交。 3.提示错误、统计分数(提交后)。 项目结构: 效果展示: 效果展示(视频): ASP.NET练习1效果 index.aspx代码: &l…

排序---数组和集合

1、数组排序 Arrays.sort(int[] a)这种形式是对一个数组的所有元素进行排序,并且是按照从小到大的排序。 public static void main(String[] args) {Integer []arr {1,2,3,4,5,6};//升序Arrays.sort(arr);for (int x:arr){System.out.print(x " ");}Sys…

大学生租房系统的设计与实现|Springboot+ Mysql+Java+ B/S结构(可运行源码+数据库+设计文档)

本项目包含可运行源码数据库LW,文末可获取本项目的所有资料。 推荐阅读100套最新项目持续更新中..... 2024年计算机毕业论文(设计)学生选题参考合集推荐收藏(包含Springboot、jsp、ssmvue等技术项目合集) 1. 系统功能…

ForkJoinPool、CAS原子操作

ForkJoinPool ForkJoinPool是由JDK1.7后提供多线程并行执行任务的框架。可以理解为一种特殊的线程池。 1.任务分割:Fork(分岔),先把大的任务分割成足够小的子任务,如果子任务比较大的话还要对子任务进行继续分割。 …

C#手麻系统源码,医院手术麻醉信息系统源码,前端框架:Vue,Ant-Design,后端框架:百小僧开源框架

手术麻醉管理系统覆盖了从患者入院,经过术前、术中、术后,直至出院的全过程。医院手术麻醉系统能够规范麻醉科和手术室的工作流程、实现麻醉手术过程中的信息数字化和网络化、自动生成麻醉手术中的各种医疗文书、完整共享HIS、LIS和PACS等手术患者信息&a…

RPA机器人:人人都会实现的机器人

在这个数字化飞速发展的时代,微信已经成为我们日常生活和工作中不可或缺的社交工具。然而,随着联系人数量的不断增加,如何高效管理这些社交关系成为了许多人面临的挑战。今天,我要为大家介绍的,是一款能够彻底改变你微…

PHP实现单列内容快速查重与去重

应用场景:excel一列内容比如身份证号&#xff0c;可能有重复的&#xff0c; 则用此工具快速查询那些重复及显示去重后内容。 使用&#xff1a;粘贴一列数据&#xff0c;然后提交发送。 <?php $tm "单列查重去重(粘贴Excel中1列内容查重)!";function tipx($str…

WEB embedded APP (javafx)

WEB embedded APP &#xff08;javafx&#xff09; &#xff08;BS 嵌入CS&#xff09; CS嵌入BS_哔哩哔哩_bilibili

生信软件14 - bcftools提取和注释VCF文件关键信息

bcftools可用于变异信息的描述性统计&#xff0c;计算&#xff0c;过滤和格式转换。 1. 显示VCF文件的头信息 bcftools view -h sample.vcf##fileformatVCFv4.2 ##FILTER<IDPASS,Description"All filters passed"> ##bcftoolsVersion1.5htslib-1.5 ##bcftool…

vmware,linux,centos7,NAT模式下的网络配置

centos7的NAT网络配置 NAT模式说明虚拟机网络配置工具本机配置net8网络&#xff08;NAT的网域&#xff09;本机的IP配置(用于net8局域网内解析主机IP和域名对应关系使用)&#xff08;可选&#xff09;虚拟机内的网络配置虚拟机ping不通www.baidu.com的情况下虚拟机ping可以ping…

我劝你不要买29.99万的小米SU7

文 | AUTO芯球 作者 | 雷歌 我在想我是不是贱啊&#xff1f;&#xff01; 我昨晚兴奋得头晕脸热的&#xff0c;身边一众关注车的朋友&#xff0c;也感觉到了车圈过年的气氛。 原因就是小米SU7的价格公布了。 21.59万元起售价格出来以后&#xff0c;就好比新年0点一过的那个…

C++:sizeof关键字(7)

sizeof用于统计数据所占用内存的大小 用法&#xff1a;sizeof( 变量名称 / 变量) 直接上代码&#xff0c;可以在让大家直观的感受到sizeof关键字的用法 #include<iostream> using namespace std;// 语法&#xff1a; sizeof&#xff08;数据类型|变量名&#xff09;// 用…

PS从入门到精通视频各类教程整理全集,包含素材、作业等(2)

PS从入门到精通视频各类教程整理全集&#xff0c;包含素材、作业等 最新PS以及插件合集&#xff0c;可在我以往文章中找到 由于阿里云盘有分享次受限制和文件大小限制&#xff0c;今天先分享到这里&#xff0c;后续持续更新 初级教程素材 等文件 https://www.alipan.com/s/fC…

从0到1利用express搭建后端服务

目录 1 架构的选择2 环境搭建3 安装express4 创建启动文件5 express的核心功能6 加入日志记录功能7 日志记录的好处本节代码总结 不知不觉学习低代码已经进入第四个年头了&#xff0c;既然低代码很好&#xff0c;为什么突然又自己架构起后端了呢&#xff1f;我有一句话叫低代码…

C++——vector类及其模拟实现

前言&#xff1a;前边我们进行的string类的方法及其模拟实现的讲解。这篇文章将继续进行C的另一个常用类——vector。 一.什么是vector vector和string一样&#xff0c;隶属于C中STL标准模板库中的一个自定义数据类型&#xff0c;实际上就是线性表。两者之间有着很多相似&…

安装docker 并搭建出一颗爱心树

1、docker介绍 Docker 是⼀个开源的容器运⾏时软件&#xff08;容器运⾏时是负责运⾏容器的软件&#xff09;&#xff0c;基于 Go 语 ⾔编写&#xff0c;并遵从 Apache2.0 协议开源。 Docker可以让开发者打包⾃⼰的应⽤以及依赖到⼀个轻量的容器中&#xff0c;然后发布到任何…

Python 垃圾回收和弱引用(Weakref)

Python中的赋值语句是建立变量名与对象的引用关系&#xff0c;多个变量可以引用同一个对象&#xff0c;当对象的引用数归零时&#xff0c;可能会被当作垃圾回收。而弱引用即可以引用对象&#xff0c;又不会阻止对象被当作垃圾回收&#xff0c;因此这个特性非常适合用在缓存场景…

值得收藏!2024年人工智能顶级会议投稿信息汇总(计算机视觉领域)

计算机视觉是人工智能领域的重要分支。它融合了图像处理、模式识别、机器学习和人工智能等多个领域的技术&#xff0c;旨在让计算机具备类似甚至超越人类视觉系统的能力。本文将精选介绍计算机视觉领域内的重要会议&#xff0c;包括会议主题、稿件提交的截止日期、会议的时间与…

SpringCloudConfig 使用git搭建配置中心

一 SpringCloudConfig 配置搭建步骤 1.引入 依赖pom文件 引入 spring-cloud-config-server 是因为已经配置了注册中心 <dependencies><dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-config-server</…
最新文章