PyTorch 神经网络搭建模板

1. Dataset & DataLoader🍁

PyTorch 中,DatasetDataLoader 是用来处理数据的重要工具。它们的作用分别如下:

Dataset: Dataset 用于存储数据样本及其对应的标签。在使用神经网络训练时,通常需要将原始数据集转换为 Dataset 对象,以便能够通过 DataLoader 进行批量读取数据,同时也可以方便地进行数据增强、数据预处理等操作。

DataLoader: DataLoader 用于将 Dataset 封装成一个可迭代对象,以便轻松地访问数据集中的样本。通过设置 batch_size 参数,DataLoader 可以将数据集分成若干个批次,每个批次包含指定数量的样本。此外,DataLoader 还支持对数据进行 shuffle、多线程读取等操作,使得训练过程更加高效。

使用 Dataset 和 DataLoader 可以使得数据处理过程更加模块化和可维护,同时也可以提高训练效率。分别封装在 torch.utils.data.Datasettorch.utils.data.DataLoader

class MyDataset(Dataset):
    def __init__(self):
  
    def __len__(self):  
        
    def __getitem__(self):
        

这是一个定义了自定义数据集类 MyDataset 的模板代码,它继承了 PyTorch 中的 Dataset 类,其中包含了三个必要的函数:

__init__:用于初始化数据集,可以在这个函数中读取数据、进行预处理等操作。

__len__:用于返回数据集中样本的数量。

__getitem__:用于根据给定的索引 index 返回对应的样本及其标签。在这个函数中,需要根据索引从数据集中读取相应的样本和标签,并进行相应的预处理和转换。

需要在这个模板代码中添加具体的代码实现,以实现自定义数据集的功能。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

使用 DataLoaders 准备训练和测试数据。在训练模型时,我们通常希望以“小批量(minibatches)”方式传递样本,每个 epoch 重新洗牌数据以减少模型过拟合,DataLoader 是一个可迭代对象。

next(iter(train_dataloader))

iter(train_dataloader) 将 train_dataloader 转换为一个迭代器对象,可以通过 next 函数逐一获取 DataLoader 中的数据。因此,next(iter(train_dataloader)) 将返回一个包含一个 batch 数据的元组。

具体来说,next 函数会从 train_dataloader 中获取下一个 batch 的数据,并将其转换为一个元组 (batch_data, batch_labels),其中 batch_data 是一个张量(tensor),形状为 [batch_size, input_size],表示一个 batch 中所有样本的输入特征;batch_labels 也是一个张量,形状为 [batch_size, output_size],表示一个 batch 中所有样本的输出标签,下面再举个例子吧。

my_list = [1, 2, 3, 4, 5] 
my_iterator = iter(my_list) 
print(next(my_iterator)) # 输出 1 
print(next(my_iterator)) # 输出 2 
print(next(my_iterator)) # 输出 3

在上面的例子中,my_list 是一个列表对象,通过 iter() 函数将其转换为迭代器 my_iterator。然后通过 next() 函数依次获取 my_iterator 中的每一个元素。

DataLoader 在创建时可以指定多个参数来控制数据的加载方式,常用的参数如下:

dataset:指定要加载的数据集。

batch_size:指定每个 batch 中样本的数量。

shuffle:指定是否在每个 epoch 开始时洗牌数据集。

sampler:指定一个自定义的数据采样器,用于控制每个 batch 中的样本顺序。

batch_sampler:指定一个自定义的 batch 采样器,用于控制 batch 的顺序和样本数量。

num_workers:指定数据加载时的线程数,用于加速数据读取。

collate_fn:指定一个自定义的函数,用于将一个 batch 中的多个样本拼接为一个张量(tensor)。

pin_memory:指定是否将数据加载到 GPU 的显存中,以加速数据读取。

drop_last:指定在数据集大小不是 batch_size 的倍数时,是否丢弃最后一个不足 batch_size 的 batch。

2. Build Model🌺

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

我们通过继承 nn.Module 来定义神经网络,并在 __init__ 中初始化神经网络的层。每个 nn.Module 子类在 forward 方法中实现对输入数据的操作。

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

这段代码定义了一个名为 NeuralNetwork 的神经网络类,它继承自 nn.Module

这个神经网络包含一个 Flatten 层和一个由3个线性层和2个 ReLU 激活函数组成的神经网络层。

__init__ 方法:在 Python 中,当一个类继承自另一个类时,它会继承该类的所有属性和方法。在 PyTorch 中,当你定义一个自己的神经网络类时,你通常会继承 nn.Module 这个基类,因为 nn.Module 已经定义好了很多用于搭建神经网络的基本组件和方法。

当你定义自己的神经网络类时,你需要调用基类的构造函数来继承基类的属性和方法。super().__init__() 就是调用基类(nn.Module)的构造函数,并返回一个代表基类实例的对象,这样你的神经网络类就可以使用 nn.Module 的所有属性和方法了。

forward 方法:就是神经网络的前向传播过程

model = NeuralNetwork().to(device)

这行代码创建了一个名为 model 的神经网络模型实例,使用了前面定义的 NeuralNetwork 类,并将其移动到了特定的设备(CPU 或 GPU)上。使用 to() 方法可以将模型移动到特定的设备上,从而利用 GPU 加速模型的训练和推理。如果设备是 GPU,则模型的所有参数和缓存都会复制到 GPU 上,如果设备是 CPU,则会复制到系统内存中。

3. Optimization🌹

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()

使用 FashionMNIST 数据集,和之前描述的 Datasets & DataLoadersBuild Model

learning_rate = 1e-3
batch_size = 64
epochs = 5

learning_rate :在每个 batch/epoch 更新模型参数的量。较小的值会导致较慢的学习速度,而较大的值可能会在训练过程中产生不可预测的行为。

batch_size:在更新参数之前,通过网络传播的数据样本数量。

epochs:迭代数据集的次数。

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

定义模型训练函数

for 循环中,我们使用 enumerate 函数遍历 dataloader 中的每个批次(batch),并将批次索引(batch index)和包含输入数据和标签的元组解压缩为 X 和 y。

然后计算出当前批次中的预测(prediction)和损失(loss),以便我们可以通过优化器(optimizer)调整模型的参数以最小化损失。

其次的三行执行反向传播(backpropagation)并使用优化器更新模型的参数。optimizer.zero_grad() 将优化器的梯度归零,否则梯度会出现累加现象。然后使用 backward 函数计算损失相对于模型参数的梯度,最后使用 step 函数将优化器的梯度更新应用到模型的参数上。

这个 if 语句在每100个批次之后打印出当前的损失和训练样本数量,以便我们可以了解模型的训练进度。

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

定义模型测试函数

在前三行中,我们计算出数据集的大小和批次数量,并初始化测试损失(test loss)和正确分类数量(correct)。

这个 with 语句在上下文中禁用梯度计算,因为测试阶段不需要计算梯度,以便我们可以仅使用模型的前向传递(forward pass)进行测试。在这个 for 循环中,我们遍历 dataloader 中的每个批次,使用模型计算出预测,计算当前批次的测试损失,并使用 argmax 函数找到每个样本的预测标签,然后将正确分类的数量累加到 correct 变量中。

计算出平均测试损失和正确分类的比例,并打印出测试结果。我们将测试损失除以批次数量来得到平均测试损失,并将正确分类的数量除以数据集大小来得到正确分类的比例。最后,我们打印出测试结果,其中包括正确分类的百分比和平均测试损失。

correct += (pred.argmax(1) == y).type(torch.float).sum().item()

这行代码有点抽象

这行代码的作用是计算当前批次中正确分类的数量,它可以分为几个步骤来理解:

首先,pred.argmax(1) 用来计算模型预测的最大概率值对应的类别,其中1表示按行计算最大值,即计算每个样本最有可能属于哪个类别。

接下来,pred.argmax(1) == y 用于将预测类别与真实类别进行比较,生成一个大小为批次大小的布尔张量,表示哪些样本被正确分类了。

然后,(pred.argmax(1) == y).type(torch.float) 将布尔张量转换为浮点数张量,其中正确分类的样本对应的元素值为1,错误分类的样本对应的元素值为0。

最后,.sum().item() 用于将正确分类的样本的元素值求和,并将结果转换为 Python 数值类型。

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

定义了一个交叉熵损失函数和一个随机梯度下降(SGD)优化器。交叉熵损失通常用于多类别分类问题,而 SGD 优化器是一种基本的梯度下降算法,用于更新模型的参数,使其逐渐逼近最优值。

这里定义了一个循环,用于多次训练和测试模型。具体来说,循环会运行 epochs 次,其中每次循环代表一个“训练周期”(epoch),在每个训练周期中,代码会先调用 train_loop() 函数来训练模型,然后调用 test_loop() 函数来测试模型在测试集上的性能。

4. Save & Load Model🌸

# Additional information
# 记录模型的相关训练信息
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

下面是模型的加载。

model = Net()  # 自己定义的网络
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/990.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式软件开发之Linux下C编程

目录 前沿 Hello World! 编写代码 编译代码 GCC编译器 gcc 命令 编译错误警告 编译流程 Makefile 基础 何为 Makefile Makefile 的引入 前沿 在 Windows 下我们可以使用各种各样的 IDE 进行编程,比如强大的 Visual Studio。但是在Ubuntu 下如何进…

【Java版oj】day10 井字棋、密码强度等级

目录 一、井字棋 (1)原题再现 (2)问题分析 (3)完整代码 二、密码强度等级 (1)原题再现 (2)问题分析 (3)完整代码 一、井字棋 &a…

CAT8网线测试仪使用中:线缆的抗干扰参数解读以及线缆工艺改进注意事项

FLUKE Agent platform -深圳维信,带你更深入的了解铜缆测试,详细为您讲解什么是TCL/ELTCL,他们对数据的传输到底有什么影响呢? 前情分析:为什么用双绞线传输信号?(一图就懂) TCL&a…

【深度解刨C语言】符号篇(全)

文章目录一.注释二.续行符与转义符1.续行符2.转义符三.回车与换行四.逻辑操作符五.位操作符和移位操作符六.前置与后置七.字符与字符串八./和%1.四种取整方式2.取模与取余的区别和联系3./两边异号的情况1.左正右负2.左负右正九.运算符的优先级一.注释 注释的两种符号&#xff…

Sentinel

SentinelSentinel介绍什么是Sentinel?为什么需要流量控制?为什么需要熔断降级?一些普遍的使用场景本文介绍参考:Sentinel官网《Spring Cloud Alibaba 从入门到实战.pdf》Sentinel下载/安装项目演示构建项目控制台概览演示之前需先明确&#…

【webrtc】ICE 到VCMPacket的视频内存分配

ice的数据会在DataPacket 构造是进行内存分配和拷贝而后DataPacket 会传递给rtc模块处理rtc模块使用DataPacket 构造rtp包最终会给到OnReceivedPayloadData 进行rtp组帧。吊炸天的是DataPacket 竟然没有声明析构方法。RtpVideoStreamReceiver::OnReceivedPayloadData 的内存是外…

3.网络爬虫——Requests模块get请求与实战

Requests模块get请求与实战requests简介:检查数据请求数据保存数据前言: 前两章我们介绍了爬虫和HTML的组成,方便我们后续爬虫学习,今天就教大家怎么去爬取一个网站的源代码(后面学习中就能从源码中找到我们想要的数据…

普通Java工程师 VS 优秀架构师

1 核心能力 1.1 要成为一名优秀的Java架构师 只懂技术还远远不够,懂技术/懂业务/懂管理的综合型人才,才是技术团队中的绝对核心。 不仅仅是架构师,所有的技术高端岗位,对人才的综合能力都有较高的标准。 架构路线的总设计师 规…

安卓渐变的背景框实现

安卓渐变的背景框实现1.背景实现方法1.利用PorterDuffXfermode进行图层的混合,这是最推荐的方法,也是最有效的。2.利用canvas裁剪实现,这个方法有个缺陷,就是圆角会出现毛边,也就是锯齿。3.利用layer绘制边框1.背景 万…

多线程案例——阻塞队列

目录 一、阻塞队列 1. 生产者消费者模型 (1)解耦合 (2)“削峰填谷” 2. 标准库中的阻塞队列 3. 自己实现一个阻塞队列(代码) 4. 自己实现生产者消费者模型(代码) 一、阻塞队列…

【Pytorch】 理解张量Tensor

本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052 这是目录张量Tensor是什么?张量的创建为什么要用张量Tensor呢?总结张量Tensor是什么? 在深度学习中,我们经常会遇到一个概念&#xff…

更改Hive元数据发生的生产事故

今天同事想在hive里用中文做为分区字段。如果用中文做分区字段的话,就需要更改Hive元 数据库。结果发生了生产事故。导致无法删除表和删除分区。记一下。 修改hive元数据库的编码方式为utf后可以支持中文,执行以下语句: alter table PARTITI…

Vue初入,了解Vue的发展与优缺点

作者简介:一名计算机萌新、前来进行学习VUE,让我们一起进步吧。 座右铭:低头赶路,敬事如仪 个人主页:我叫于豆豆吖的主页 前言 从本章开始进行Vue前端的学习,了解Vue的发展,以及背后的故事。 一.vue介…

ASEMI代理瑞萨TW9992AT-NA1-GE汽车芯片

编辑-Z TW9992AT-NA1-GE是一款低功耗NTSC/PAL模拟视频解码器,专为汽车应用而设计。它支持单端、差分和伪差分复合视频输入。集成了对电池短路和对地短路检测,先进的图像增强功能,如可编程的自动对比度调整(ACA)和MIPI…

【Linux】网络编程套接字(下)

🎇Linux: 博客主页:一起去看日落吗分享博主的在Linux中学习到的知识和遇到的问题博主的能力有限,出现错误希望大家不吝赐教分享给大家一句我很喜欢的话: 看似不起波澜的日复一日,一定会在某一天让你看见坚持…

ASEMI代理MIMXRT1064CVJ5B原装现货NXP车规级MIMXRT1064CVJ5B

编辑:ll ASEMI代理MIMXRT1064CVJ5B原装现货NXP车规级MIMXRT1064CVJ5B 型号:MIMXRT1064CVJ5B 品牌:NXP /恩智浦 封装:LFGBA-196 批号:2023 安装类型:表面贴装型 引脚数量:196 类型&#…

【Hadoop-yarn-01】大白话讲讲资源调度器YARN,原来这么好理解

YARN作为Hadoop集群的御用调度器,在整个集群的资源管理上立下了汗马功劳。今天我们用大白话聊聊YARN存在意义。 有了机器就有了资源,有了资源就有了调度。举2个很鲜活的场景: 在单台机器上,你开了3个程序,分别是A、B…

Redis知识点汇总

前言 梳理知识 说一下项目中的Redis的应用场景 首先知道Redis的5大value类型: string,list,hash, set ,zset 2.基本上是缓存 3.为的是服务无状态, 4.无锁化 Redis是单线程还是多线程 1.无论什么版本,工作线程就一个 2.6.x高版本出现IO多线程

三天吃透操作系统面试八股文

本文已经收录到Github仓库,该仓库包含计算机基础、Java基础、多线程、JVM、数据库、Redis、Spring、Mybatis、SpringMVC、SpringBoot、分布式、微服务、设计模式、架构、校招社招分享等核心知识点,欢迎star~ Github地址:https://github.com/…

基于python的超市历年数据可视化分析

人生苦短 我用python Python其他实用资料:点击此处跳转文末名片获取 数据可视化分析目录人生苦短 我用python一、数据描述1、数据概览二、数据预处理0、导入包和数据1、列名重命名2、提取数据中时间,方便后续分析绘图三、数据可视化1、美国各个地区销售额的分布&…
最新文章