[Pytorch]手写数字识别——真·手写!

Github网址:https://github.com/diaoquesang/pytorchTutorials/tree/main

本教程创建于2023/7/31,几乎所有代码都有对应的注释,帮助初学者理解dataset、dataloader、transform的封装,初步体验调参的过程,初步掌握opencv、pandas、os等库的使用,😋纯手撸手写数字识别项目(为减少代码量简化了部分数据集相关操作),全流程跑通Pytorch!❤️❤️❤️
This tutorial was created on 2023/7/31. Almost all the code has corresponding comments, to help beginners understand dataset, dataloader, transform packaging, preliminary experience of the process of tuning the parameters, the initial grasp of the use of libraries such as opencv, pandas, os, etc., 😋 and get involved in this handwritten digit recognition project (we simplified some dataset-related operations in order to reduce the amount of code). Enjoy the whole process of running Pytorch!❤️❤️❤️

如果喜欢本项目的话,留下你的⭐吧!
Give me a ⭐ if you like this project!

一、train.py

import torch
import torchvision

from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import os
import cv2 as cv
import pandas as pd


class myDataset(Dataset):  # 定义数据集类
    def __init__(self, annotations_file, img_dir, transform=None,
                 target_transform=None):  # 传入参数(标签路径,图像路径,图像预处理方式,标签预处理方式)
        self.img_labels = pd.read_csv(annotations_file, sep=" ", header=None)
        # 从标签路径中读取标签,sep为划分间隔符,header为列标题的行位置
        self.img_dir = img_dir  # 读取图像路径
        self.transform = transform  # 读取图像预处理方式
        self.target_transform = target_transform  # 读取标签预处理方式

    def __len__(self):
        return len(self.img_labels)  # 读取标签数量作为数据集长度

    def __getitem__(self, idx):  # 从数据集中取出数据
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[
            idx, 0])
        # 从标签对象中取出第idx行第0列(第0列为图像位置所在列)的值(numberImages\5.bmp),并与图像路径(numberImages)进行拼接
        image = cv.imread(img_path)  # 用openCV的imread函数读取图像
        label = self.img_labels.iloc[idx, 1]  # 从标签对象中取出第idx行第1列(第1列为图像标签所在列)的值(5)
        if self.transform:
            image = self.transform(image)  # 图像预处理
        if self.target_transform:
            label = self.target_transform(label)  # 标签预处理
        return image, label  # 返回图像和标签


class myTransformMethod1():  # Python3默认继承object类
    def __call__(self, img):  # __call___,让类实例变成一个可以被调用的对象,像函数
        img = cv.resize(img, (28, 28))  # 改变图像大小
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # 将BGR(openCV默认读取为BGR)改为RGB
        return img  # 返回预处理后的图像

# 测试函数
# print(pd.read_csv("annotations.txt", sep=" ", header=None))
# print(os.path.join("numberImages", pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 0]))
# print(pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 1])
# cv.imshow("1",cv.imread(os.path.join("numberImages", pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 0])))
# cv.waitKey(0)


class myNetwork(nn.Module):  # 定义神经网络
    def __init__(self):
        super().__init__()  # 继承nn.Module的构造器
        self.flatten = nn.Flatten(-3, -1)
        # 继承nn.Module的Flatten函数并改为flatten,考虑到推理时没有batch(CHW),若使用默认值(1,-1)会导致C没有被flatten,故使用(-3,-1)
        self.linear_relu_stack = nn.Sequential(  # 定义前向传播序列
            nn.Linear(3 * 28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):  # 定义前向传播方法
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


# 设置运行环境,默认为cuda,若cuda不可用则改为mps,若mps也不可用则改为cpu
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")  # 输出运行环境

model = myNetwork().to(device)  # 创建神经网络模型实例

# 设置超参数
learning_rate = 1e-5  # 学习率
batch_size = 8  # 每批数据数量
epochs = 3000  # 总轮数

img_path = "./numberImages"  # 设置图像路径
label_path = "./annotations.txt"  # 设置标签路径

myTransform = transforms.Compose([myTransformMethod1(), transforms.ToTensor()])
# 定义图像预处理组合,ToTensor()中Pytorch将HWC(openCV默认读取为height,width,channel)改为CHW,并将值[0,255]除以255进行归一化[0,1]

myDataset = myDataset(label_path, img_path, myTransform)  # 创建数据集实例

myDataLoader = DataLoader(myDataset, batch_size=batch_size,
                          shuffle=True)
# 创建数据读取器(可对训练集和测试集分别创建),batch_size为每批数据数量(一般为2的n次幂以提高运行速度),shuffle为随机打乱数据

def train():
    # 根据epochs(总轮数)训练
    for epoch in range(epochs):
        totalLoss = 0
        # 分批读取数据
        for batch, (images, labels) in enumerate(myDataLoader):
            # 数据转换到对应运行环境
            images = images.to(device)
            labels = labels.to(device)

            pred = model(images)  # 前向传播

            myLoss = nn.CrossEntropyLoss()  # 定义损失函数(交叉熵)
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # 定义优化器

            loss = myLoss(pred, labels)  # 计算损失函数

            totalLoss += loss  # 计入总损失函数

            loss.backward()  # 反向传播
            optimizer.step()  # 更新权重
            optimizer.zero_grad()  # 清空梯度

            if batch % 1 == 0:  # 每隔1个batch输出1次loss
                loss, current = loss.item(), min((batch + 1) * batch_size,len(myDataset))
                print(f"epoch: {epoch:>5d} loss: {loss:>7f}  [{current:>5d}/{len(myDataset):>5d}]")

        if epoch == 0:
            minTotalLoss = totalLoss
        if totalLoss < minTotalLoss:
            print("······························模型已保存······························")
            minTotalLoss = totalLoss
            torch.save(model, "./myModel.pth")  # 保存性能最好的模型


if __name__ == "__main__":
    model.train()  # 设置训练模式
    train()

二、eval.py

import torch
import torchvision

from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import os
import cv2 as cv
import pandas as pd


class myTransformMethod1():  # Python3默认继承object类
    def __call__(self, img):  # __call___,让类实例变成一个可以被调用的对象,像函数
        img = cv.resize(img, (28, 28))  # 改变图像大小
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # 将BGR(openCV默认读取为BGR)改为RGB
        return img  # 返回预处理后的图像


class myNetwork(nn.Module):  # 定义神经网络
    def __init__(self):
        super().__init__()  # 继承nn.Module的构造器
        self.flatten = nn.Flatten(-3, -1)
        # 继承nn.Module的Flatten函数并改为flatten,考虑到推理时没有batch(CHW),若使用默认值(1,-1)会导致C没有被flatten,故使用(-3,-1)
        self.linear_relu_stack = nn.Sequential(  # 定义前向传播序列
            nn.Linear(3 * 28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):  # 定义前向传播方法
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


if __name__ == "__main__":
    model = torch.load("./myModel.pth").to("cuda")  # 载入模型
    model.eval()  # 设置推理模式
    myTransform = transforms.Compose([myTransformMethod1(), transforms.ToTensor()])
    # 定义图像预处理组合,ToTensor()中Pytorch将HWC(openCV默认读取为height,width,channel)改为CHW,并将值[0,255]除以255进行归一化[0,1]
    for i in range(10):
        img = cv.imread("./numberImages/"+str(i)+".bmp")  # 用openCV的imread函数读取图像
        img = myTransform(img).to("cuda")  # 图像预处理
        print(torch.argmax(model(img)))

三、其余资料详见Github

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/52933.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CTF:信息泄露.(CTFHub靶场环境)

CTF&#xff1a;信息泄露.&#xff08;CTFHub靶场环境&#xff09; “ 信息泄露 ” 是指网站无意间向用户泄露敏感信息&#xff0c;泄露了有关于其他用户的数据&#xff0c;例如&#xff1a;另一个用户名的财务信息&#xff0c;敏感的商业 或 商业数据 &#xff0c;还有一些有…

读取application-dev.properties的中文乱码【bug】

读取application-dev.properties的中文编码【bug】 2023-7-30 22:37:46 版权 禁止其他平台发布时删除以下此话 本文首次发布于CSDN平台 作者是CSDN日星月云 博客主页是https://blog.csdn.net/qq_51625007 禁止其他平台发布时删除以上此话 bug 读取application-dev.propert…

2023年的深度学习入门指南(20) - LLaMA 2模型解析

2023年的深度学习入门指南(20) - LLaMA 2模型解析 上一节我们把LLaMA 2的生成过程以及封装的过程的代码简单介绍了下。还差LLaMA 2的模型部分没有介绍。这一节我们就来介绍下LLaMA 2的模型部分。 这一部分需要一些深度神经网络的基础知识&#xff0c;不懂的话不用着急&#xf…

建木使用进阶-创建密钥管理

阿丹&#xff1a; 第一次我们进入建木&#xff0c;第一件事情就是配置我们相关的密钥。 解读&#xff1a; 在建木中我们可以进行创建密钥来对我们服务器等密码进行方便的管理。 注意&#xff1a; 登录的时候账号为&#xff1a;admin 密码为&#xff1a;123456 这是初始…

Windows环境下git客户端中的git-bash和MinGW64

我们在 Windows10 操作系统下&#xff0c;安装了 git 客户端之后&#xff0c;可以通过 git-bash.exe 打开一个 shell&#xff1a; 执行一些 linux 系统里的命令&#xff1a; 注意到上图紫色的 MINGW64. Mingw-w64 是原始 mingw.org 项目的改进版&#xff0c;旨在支持 Window…

【playbook】Ansible的脚本----playbook剧本

Ansible的脚本----playbook剧本 1.playbook剧本组成2.playbook剧本实战演练2.1 实战演练一&#xff1a;给被管理主机安装Apache服务2.2 实战演练二&#xff1a;使用sudo命令将远程主机的普通用户提权为root用户2.3 实战演练三&#xff1a;when条件判断指定的IP地址2.4 实战演练…

SpringBoot中ErrorPage(错误页面)的使用--【ErrorPage组件】

SpringBoot系列文章目录 SpringBoot知识范围-学习步骤–【思维导图知识范围】 文章目录 SpringBoot系列文章目录本系列校训 SpringBoot技术很多很多环境及工具&#xff1a;必要的知识深层一些的知识 上效果图在Spring Boot里使用ErrorPage还要注意的是 配套资源作业&#xff…

使用Windbg分析从系统应用程序日志中找到的系统自动生成的dump文件去排查问题

目录 1、尝试将Windbg附加到目标进程上进行动态调试&#xff0c;但Windbg并没有捕获到 2、在系统应用程序日志中找到了系统在程序发生异常时自动生成的dump文件 2.1、查看应用程序日志的入口 2.2、在应用程序日志中找到系统自动生成的dump文件 3、使用Windbg静态分析dump文…

Mysql的锁

加锁的目的 对数据加锁是为了解决事务的隔离性问题&#xff0c;让事务之前相互不影响&#xff0c;每个事务进行操作的时候都必须先加上一把锁&#xff0c;防止其他事务同时操作数据。 事务的属性 &#xff08;ACID&#xff09; 原子性 一致性 隔离性 持久性 事务的隔离级别 锁…

大数据课程D4——hadoop的YARN

文章作者邮箱&#xff1a;yugongshiyesina.cn 地址&#xff1a;广东惠州 ▲ 本章节目的 ⚪ 了解YARN的概念和结构&#xff1b; ⚪ 掌握YARN的资源调度流程&#xff1b; ⚪ 了解Hadoop支持的资源调度器&#xff1a;FIFO、Capacity、Fair&#xff1b; ⚪ 掌握YA…

jenkins自定义邮件发送人姓名

jenkins发送邮件的时候发送人姓名默认的&#xff0c;如果要自定义发件人姓名&#xff0c;只需要修改如下信息即可&#xff1a; 系统管理-system-Jenkins Location下的系统管理员邮件地址 格式为&#xff1a;自定义姓名<邮件地址>

三分钟白话RocketMQ系列—— 核心概念

目录 关键字摘要 Q1&#xff1a;RocketMQ是什么&#xff1f; Q2: 作为消息中间件&#xff0c;RocketMQ和kafka有什么区别&#xff1f; Q3: RocketMQ的基本架构是怎样的&#xff1f; Q4&#xff1a;RocketMQ有哪些核心概念&#xff1f; 总结 RocketMQ是一个开源的分布式消…

测试|测试分类

测试|测试分类 文章目录 测试|测试分类1.按照测试对象分类&#xff08;部分掌握&#xff09;2.是否查看代码&#xff1a;黑盒、白盒灰盒测试3.按开发阶段分&#xff1a;单元、集成、系统及验收测试4.按实施组织分&#xff1a;α、β、第三方测试5.按是否运行代码&#xff1a;静…

SpringMVC程序开发

1.什么是Spring MVC? Spring Web MVC是基于Servlet API构建的原始的Web框架&#xff0c;从一开始是就包含在Spring框架中。它的正式名称“Spring Web MVC"来自其源模板的名称&#xff08;Spring-webmvc)&#xff0c;但通常被称为“Spring MVC" 从上述的定义我们可…

Unity游戏源码分享-ARPG游戏Darklight.rar

Unity游戏源码分享-ARPG游戏Darklight.rar 玩法 项目地址&#xff1a;https://download.csdn.net/download/Highning0007/88105464

Android Studio 的版本控制Git

Android Studio 的版本控制Git。 Git 是最流行的版本控制工具&#xff0c;本文介绍其在安卓开发环境Android Studio下的使用。 本文参考链接是&#xff1a;https://learntodroid.com/how-to-use-git-and-github-in-android-studio/ 一&#xff1a;Android Studio 中设置Git …

Flowable-服务-微服务任务

目录 定义图形标记XML内容界面操作 定义 Sc 任务不是 BPMN 2.0 规范定义的官方任务&#xff0c;在 Flowable 中&#xff0c;Sc 任务是作为一种特殊的服务 任务来实现的&#xff0c;主要调用springcloud的微服务使用。 图形标记 由于 Sc 任务不是 BPMN 2.0 规范的“官方”任务…

在腾讯云服务器OpenCLoudOS系统中安装mysql(有图详解)

1. 创建MySQL安装目录 mkdir -p app/soft//mysql 2. 进入MySQL安装目录&#xff0c;下载&#xff0c;安装 cd /app/soft/mysql/ wget http://dev.mysql.com/get/mysql-5.7.26-1.el7.x86_64.rpm-bundle.tar 得到安装包&#xff1a; 解压安装包&#xff1a; 查看系统是否自带…

ts一些常用符号

非空断言操作符(!) 具体是指在上下文中当类型检查器无法断定类型时&#xff0c;一个新的后缀表达式操作符 ! 可以用于断言操作对象是非 null 和非 undefined 类型。具体而言&#xff0c;x! 将从 x 值域中排除 null 和 undefined 。 1. 赋值时忽略 undefined 和 null function…

JVM源码剖析之JIT工作流程

版本信息&#xff1a; jdk版本&#xff1a;jdk8u40思想至上 Hotspot中执行引擎分为解释器、JIT及时编译器&#xff0c;上篇文章描述到解释器过度到JIT的条件。JVM源码剖析之达到什么条件进行JIT优化 这篇文章大致讲述JIT的编译过程。在JDK中javac和JIT两部分跟编译原理挂钩&a…