学习pytorch18 pytorch完整的模型训练流程

pytorch完整的模型训练流程

  • 1. 流程
    • 1. 整理训练数据 使用CIFAR10数据集
    • 2. 搭建网络结构
    • 3. 构建损失函数
    • 4. 使用优化器
    • 5. 训练模型
    • 6. 测试数据 计算模型预测正确率
    • 7. 保存模型
  • 2. 代码
    • 1. model.py
    • 2. train.py
  • 3. 结果
    • tensorboard结果
      • 以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值
      • 训练loss
      • 测试loss
      • 测试集正确率
  • 4. 需要注意的细节

1. 流程

1. 整理训练数据 使用CIFAR10数据集

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)

2. 搭建网络结构

在这里插入图片描述
model.py

3. 构建损失函数

loss_fn = nn.CrossEntropyLoss()

4. 使用优化器

learing_rate = 1e-2 # 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

5. 训练模型

output = net(imgs)    # 数据输入模型
loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少
# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化
optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0
loss.backward()        # 设置计算的损失值的钩子,调用损失的反向传播,计算每个参数结点的参数
optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化  

6. 测试数据 计算模型预测正确率

output = net(imags)
# 计算测试集的正确率
preds = (output.argmax(1)==targets).sum()
accuracy += preds 
rate = accuracy/len(test_data)

调用模型输出tensor 数据类型的 argmax方法, argmax或获取一行或者一列数值中最大数值的下标位置,argmax(0) 是从列的维度取一列数值的最大值的下标,argmax(1) 是从行的维度取一行数值的最大值的下标
output.argmax(1)==targets 会输出如下图最后一行 [false, ture], 对应位置相同则为true,对应位置不同则为false;
调用sum()方法,计算求和,false值为0,true值为1.
最后计算得出测试集整体正确率: rate = accuracy/len(test_data)
在这里插入图片描述

7. 保存模型

torch.save(net, './net_epoch{}.pth'.format(i))

2. 代码

1. model.py

import torch
from torch import nn

# 2. 搭建模型网络结构--神经网络
class Cifar10Net(nn.Module):
    def __init__(self):
        super(Cifar10Net, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.net(x)
        return x


if __name__ == '__main__':
    net = Cifar10Net()
    input = torch.ones((64, 3, 32, 32))
    output = net(input)
    print(output.shape)

2. train.py

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from p24_model import *

# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoader

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)

# 2. 导入模型结构 创建模型
net = Cifar10Net()

# 3. 创建损失函数  分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()

# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

# 设置训练网络的一些参数
epoch = 10   # 记录训练的轮数
total_train_step = 0  # 记录训练的次数
total_test_step = 0   # 记录测试的次数

# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')

for i in range(epoch):
    # 训练步骤开始
    net.train()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用
    for data in train_loader:
        imgs, targets = data  # 获取数据
        output = net(imgs)    # 数据输入模型
        loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少
        # 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化
        optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0
        loss.backward()        # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数
        optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化
        # 优化一次 认为训练了一次
        total_train_step += 1
        if total_train_step % 100 == 0:
            print('训练次数: {}   loss: {}'.format(total_train_step, loss))
        # 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】
        # print('训练次数: {}   loss: {}'.format(total_train_step, loss.item()))
        writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)

    # 利用现有模型做模型测试
    # 测试步骤开始
    total_test_loss = 0
    accuracy = 0
    net.eval()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用
    with torch.no_grad():
        for data in test_loader:
            imags, targets = data
            output = net(imags)
            loss = loss_fn(output, targets)
            total_test_loss += loss.item()
            # 计算测试集的正确率
            preds = (output.argmax(1)==targets).sum()
            accuracy += preds
    # writer.add_scalar('test-loss', total_test_loss, global_step=i+1)
    writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)
    writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)
    total_test_step += 1
    print("---------test loss: {}--------------".format(total_test_loss))
    print("---------test accuracy: {}--------------".format(accuracy))
    # 保存每一个epoch训练得到的模型
    torch.save(net, './net_epoch{}.pth'.format(i))

writer.close()

3. 结果

训练数据集大小: 50000
测试数据集大小: 10000
0.01
训练次数: 100   loss: 2.2905373573303223
训练次数: 200   loss: 2.2878968715667725
训练次数: 300   loss: 2.258394718170166
训练次数: 400   loss: 2.1968581676483154
训练次数: 500   loss: 2.0476632118225098
训练次数: 600   loss: 2.002145767211914
训练次数: 700   loss: 2.016021728515625
---------test loss: 316.382279753685--------------
训练次数: 800   loss: 1.8957302570343018
训练次数: 900   loss: 1.8659226894378662
训练次数: 1000   loss: 1.9004186391830444
训练次数: 1100   loss: 1.9708642959594727
......

tensorboard结果

安装tensorboard运行环境

pip install tensorboard
pip install opencv-python
pip install six
tensorboard --logdir=train_logs

以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值

训练loss

在这里插入图片描述

测试loss

在这里插入图片描述

测试集正确率

在这里插入图片描述

4. 需要注意的细节

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module

所有网络层继承于torch.nn.Module, net.train() net.eval() 在模型训练或测试之初 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用,当模型有这两个网络层的时候,两个代码需要加上。
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/230326.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

idea使用maven的package打包时提示“找不到符号”或“找不到包”

介绍:由于我们的项目是多模块开发项目,在打包时有些模块内容更新导致其他模块在引用该模块时不能正确引入。 情况一:找不到符号 情况一:找不到包 错误代码部分展示: Failure to find com.xxx.xxxx:xxx:pom:0.5 in …

3D渲染和动画制作软件KeyShot Pro mac附加功能

KeyShot 11 mac是一款专业化实时3D渲染工具,使用它可以简化3d渲染和动画制作流程,并且提供最准确的材质及光线,渲染效果更加真实,KeyShot为您提供了使用 CPU 或 NVIDIA GPU 进行渲染的能力和选择,并能够线性扩展以获得…

HarmonyOS4.0从零开始的开发教程10管理组件状态

HarmonyOS(八)管理组件状态 概述 在应用中,界面通常都是动态的。如图1所示,在子目标列表中,当用户点击目标一,目标一会呈现展开状态,再次点击目标一,目标一呈现收起状态。界面会根…

Django的logging-日志模块的简单使用方法

扩展阅读: Python-Django的“日志功能-日志模块(logging模块)-日志输出”的功能详解 现在有下面的Python代码: # -*- coding: utf-8 -*-def log_out_test(content_out):print(content_out)content1 "i love you01" log_out_test(content1)现…

<Linux>(极简关键、省时省力)《Linux操作系统原理分析之Linux 设备管理》(29)

《Linux操作系统原理分析之Linux 设备管理》(29) 10 Linux 设备管理10.1 Linux 设备分类与识别10.1.1 Linux 设备的分类10.1.2 设备文件 10.2 设备驱动程序与设备注册10.2.1 设备驱动程序10.2.2 设备注册 10.3Linux 的 I/O 控制方式10.3.1 查…

Docker, Docker-compose部署Sonarqube

参考文档 镜像地址: https://hub.docker.com/_/sonarqube/tags Docker部署文档地址 Installing from Docker | SonarQube Docs Docker-compose文档部署地址: Installing from Docker | SonarQube Docs 部署镜像 2.1 docker部署 # 宿主机执行 $. vi /etc/sysctl.conf…

探索CSS:从入门到精通Web开发(二)

前言 当我们谈论网页设计和开发时,CSS(层叠样式表)无疑是其中的重要一环。作为HTML的伴侣,CSS赋予网页以丰富的样式和布局,使得网站看起来更加吸引人并且具备更好的可读性。本书将通过一系列深入浅出的方式&#xff0…

kafka学习笔记--安装部署、简单操作

本文内容来自尚硅谷B站公开教学视频,仅做个人总结、学习、复习使用,任何对此文章的引用,应当说明源出处为尚硅谷,不得用于商业用途。 如有侵权、联系速删 视频教程链接:【尚硅谷】Kafka3.x教程(从入门到调优…

深入理解 Promise:前端异步编程的核心概念

深入理解 Promise:前端异步编程的核心概念 本文将帮助您深入理解 Promise,这是前端异步编程的核心概念。通过详细介绍 Promise 的工作原理、常见用法和实际示例,您将学会如何优雅地处理异步操作,并解决回调地狱问题。 异步编程和…

python主流开发工具排名,python开发工具有哪些

本篇文章给大家谈谈python的开发工具软件有哪些,以及python主流开发工具排名,希望对各位有所帮助,不要忘了收藏本站喔。 python中用到哪些软件 一、Python代码编辑器1、sublime Textsublime Text是一款非常流行的代码编辑器,支持P…

基于单片机指纹考勤机控制系统设计

**单片机设计介绍,基于单片机指纹考勤机控制系统设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机的指纹考勤机控制系统是一种用于管理员工考勤和实现门禁控制的设计方案。它通过使用单片机作为主控制器…

Amazon CodeWhisperer 提供新的人工智能驱动型代码修复、IaC 支持以及与 Visual Studio 的集成...

Amazon CodeWhisperer 的人工智能(AI)驱动型代码修复和基础设施即代码(IaC)支持已正式推出。Amazon CodeWhisperer 是一款用于 IDE 和命令行的人工智能驱动型生产力工具,现已在 Visual Studio 中推出,提供预…

nodejs微信小程序+python+PHP的游戏测评网站设计与实现-计算机毕业设计推荐

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性:…

初识Matter——esp-box控制两盏灯

初识Matter 一、效果展示 二、准备 1.ubuntu系统/Mac系统电脑 2.安装esp-idf及esp-matter环境 3.esp-box设备 4.两块esp32 5.两个led灯或使用板载灯 三、烧录固件(esp-box) 下载esp-box例程 git地址:GitHub - espressif/esp-box: Th…

微信小程序 - PC端选择ZIP文件

微信小程序 - PC端选择文件 分享代码片段场景分析解决思路附魔脚本chooseMediaZip 选择附魔后的ZIP文件相关方法测试方法 参考资料 分享代码片段 不想听废话的,直接看代码。 https://developers.weixin.qq.com/s/UL9aojmn7iNU 场景分析 如果你的微信小程序需要选…

机器视觉相机镜头光源选型

镜头选型工具 - HiTools - 海康威视 Hikvisionhttps://www.hikvision.com/cn/support/tools/hitools/cl8a9de13648c56d7f/ 海康机器人-机器视觉产品页杭州海康机器人股份有限公司海康机器人HIKROBOT是面向全球的机器视觉和移动机器人产品及解决方案提供商,业务聚焦于…

dell服务器重启后显示器黑屏

1.硬件层面:观察主机的指示灯 (1)指示灯偏黄,硬件存在问题(内存条有静电,拔出后用橡皮擦擦拭;或GPU松动) a.电源指示灯黄,闪烁三下再闪烁一下,扣下主板上的纽…

response应用及重定向和request转发

请求和转发: response说明一、response文件下载二、response验证码实现1.前置知识:2.具体实现:3.知识总结 三、response重定向四、request转发五、重定向和转发的区别 response说明 response是指HttpServletResponse,该响应有很多的应用&…

从阻抗匹配看拥塞控制

先来理解阻抗匹配,但我不按传统方式解释,因为传统方案你要先理解如何定义阻抗,然后再学习什么是输入阻抗和输出阻抗,最后再看如何让它们匹配,而让它们匹配的目标仅仅是信号不反射,以最大能效被负载接收。 …

数据结构——二叉树的链式结构

个人主页:日刷百题 系列专栏:〖C语言小游戏〗〖Linux〗〖数据结构〗 〖C语言〗 🌎欢迎各位→点赞👍收藏⭐️留言📝 ​ 一、二叉树的创建 这里我们使用先序遍历的思想来创建二叉树,这里的内容对于刚接触二…