Pytorch入门实战 P1-实现手写数字识别

目录

一、前期准备(环境+数据)

1、首先查看我们电脑的配置;

2、使用datasets导入MNIST数据集

3、使用dataloader加载数据集

4、数据可视化

二、构建简单的CNN网络

三、训练模型

1、设置超参数

2、编写训练函数

3、编写测试函数

4、正式训练

四、结果可视化


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

一、前期准备(环境+数据)

编辑器:Pycharm

环境语言:python、pytorch

1、首先查看我们电脑的配置;

即:看看我们电脑是CPU版本还是GPU的。

import torch
torch.cuda.is_available()

# 返回 False,则是CPU版本;反之是GPU版本

查看自己的电脑配置后,一般写代码的时候,只需要看是CPU或GPU,然后根据不同的版本运行代码。一般我们会选择使用判断语句这样写:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

不懂得torch.device() 的,我把官网附在下面了。 

2、使用datasets导入MNIST数据集

本文中,我们主要用来实现手写数字的识别,因此,我们得先得到手写数字数据集,即MNIST数据集。

【MNIST数据集背景介绍】:

        手写数据集,如MNIST,是一个经典的机器学习数据集,主要用于手写数字识别。

        这个数据集包含了来自250个不同的人手写数字图片,其中50%是学生,50%来自人口普查局的工作人员。训练集一共包含了60,000张图像和标签,而测试集一共包含了10,000张图像和标签。测试集中前5000个来自最初NIST项目的训练集,后5000个来自最初MNIST项目的测试集。前5000个比后5000个要规整,这是因为前5000个数据来自于美国人口普查局的员工,他们的书写相对更标准,而后5000个来自于大学生,书写风格可能更多样。

        该数据集的收集目的是希望通过算法,实现对手写数字的识别。在手写数字识别分类中,每个样本都是一个28x28像素的灰度图像,表示手写数字0到9。

总的来说,手写数据集为机器学习领域的研究者提供了一个标准化的、大规模的、有挑战性的数据集,有助于推动手写数字识别等相关技术的发展。


【导入MNIST数据集】:

首先,使用datasets下载MNIST数据,并划分好训练集和测试集。

import torchvision 

# 训练集数据

train_ds = torchvision.datasets.MNIST('data',
                                       train=True,
                                       transform=torchvision.transforms.ToTensor(),
                                        download=True)
# 测试集数据
test_ds = torchvisio.datasets.MNIST('data',
                               train=False,
                               transform=torchvision.transforms.ToTensor(),
                               download=True)

我们先来看下原型:

torchvision.datastes.MNIST( root,train=True, transform=None, download=False)

其中:

        root是要把下载的数据集存入的文件夹的名字。

        train:True表示是训练集;False表示是测试集;

        transform: 这里的参数,选择一个你想要的数据转换函数,直接完成数据转化。

        download:True 从互联网上下载数据集,并把数据集放在root目录下。


下载完成后的目录是这样的:(如果已经下载一次了,后续就可以把download改为False,不然每次运行都会下载

3、使用dataloader加载数据集

使用dataloader加载数据集,并设置好基本的batch_size。

import torch 

batch_size = 32   # 每批加载样本的大小

# 加载训练集数据
train_dl = torch.utils.data.DataLoader(train_ds,
                                       batch_size=batch_size,
                                       shuffle=True)   # 每个epoch重新排列数据

# 加载测试集数据
test_dl = torch.utils.data.DataLoader(test_ds,
                                       batch_size=batch_size)

我们可以取一个批次查看下数据的格式:

imgs,labels = next(iter(train_dl))

print(imgs.shape)   # 得到结果是   torch.Size([32,1,28,28])

其中:我们得到的数据的shape位:[batch_size , channel, height, weight]

                batch_size :是我们自己设置的(上面的代码中有设置过)

                channel: 通道数  (黑白图像一般的通道数为1;RGB格式图像的通道数为:3)

                height: 图片的高度

                weight: 图片的宽度

train_dl 就是我们上面的使用dataloader加载的训练集的数据。

iter(train_dl) 将数据加载器转换为一个迭代器(iterator),使得我们可以使用Python的next()函数来逐个访问数据加载器中的元素。

next()  函数用于获取迭代器中的下一个元素。这里,它被用来获取train_dl中的下一批量数据。


4、数据可视化

数据可视化,就是使用代码展示下,我们上面获取的数据(获取20个数字的图片)。

plt.figure(figsize=(20,5))
for i,imgs in enumerate(imgs[:20]):
    # 维度缩减
    npimg = np.squeeze(imgs.numpy())
    plt.subplot(2,10,i+1)  # 指定划分的行数、列数及子图的索引。
    plt.imshow(npimg,cmap=plt.cn.binary)  # 展示图片,以cmap给的色彩展示
    plt.axis('off')  # 关闭坐标轴
plt.show()  # 展示图片

 运行结果展示:


至此,我们的前期准备工作准备结束。我们即将进入第二部分!!!

二、构建简单的CNN网络

我们现在简单看下上面这个图,上图里面是一个简单的CNN网络图。

依次包括:输入层、卷积层1、池化层1、卷积层2、池化层2、全连接层1、全连接层2、全连接层3、输出层。 

对于一般的CNN网络来说,都是由【特征网络】和【分类网络】构成的。

①nn.Conv2d  为卷积层,用于提取图片的特征,传入参数为:input_channel、out_channel、kernal_size;

②nn.MaxPool2d 为池化层,进行下采样,更高层的抽象表示图像特征,传入参数为:kernal_size

③nn.ReLU 为激活函数,使得模型可以拟合非线性数据。

④nn.Squential 可以按构造顺序连接网络,在初始化阶段就设定好网络结构,不需要在前向传播中重新写一遍。


下面的代码,我们以这个图为例,两层卷积、两层池化、全连接层。

num_classes = 10   # 图片的类别数


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # 特征提取网络
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # 输入图像的通道数、输出图像的通道数、卷积核大小  (RGB图像的输入通道数为3)
        self.pool1 = nn.MaxPool2d(2)                    # 设置池化层,池化核大小为2*2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)   # 第二层卷积,卷积核大小为3*3
        self.pool2 = nn.MaxPool2d(2)

        # 分类网络
        self.fc1 = nn.Linear(1600,64)
        self.fc2 = nn.Linear(64,num_classes)

    # 前向传播
    def forward(self,x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))

        x = torch.flatten(x,start_dim=1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x
# 打印并加载模型
model = Model().to(device)
print(model)
print("查看模型信息:")
summary(model)

​​​​​​​

这个步骤很重要,这块就是我们一般修改模型的地方。如果是写论文的话,这里是很重要的。因为是刚开始学习,就先能大概了解就行,后续我还会继续学习的。也会多多更新这里的。

三、训练模型

我们现在已经构建好了CNN的网络模型,那么就开始设置一些参数训练模型吧。

1、设置超参数

loss_fn = nn.CrossEntropyLoss()  # 创建损失函数
learn_rate = 1e-1  # 学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

2、编写训练函数

'''
    1、optimizer.zero_grad() 函数会遍历模型的所有参数,通过内置方法截断反向传播的梯度流,再将每个参数的梯度值设为0,即上次的梯度记录会被清空。
    2、loss.backward() Pytorch的反向传播(即:tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。
    3、optimizer.step()  step() 函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,所以在执行optimizer.step() 函数前应先指向那个loss.backward()函数来计算梯度。
'''

# 训练循环
print('准备进入----训练集里面')


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片
    num_batches = len(dataloader)  # 批次数目,1875(60000/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率

    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值,即为损失。

        # 反向传播   (以下三个基本上是固定的)
        optimizer.zero_grad()  # grade属性归零
        loss.backward()  # 反向传播
        optimizer.step()  # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(
            torch.float).sum().item()  # 表示计算预测正确的样本数量,并将其作为一个标量值返回。这通常用于评估分类模型的准确率或计算分类问题的正确预测数量。
        '''
            pred.argmax(1)返回数组pred在第一个轴(即行)上最大值所在的索引。这通常用于多分类问题中,其中pred是一个包含预测概率的二维数组,每行表示一个样本的预测概率分布。
            pred.argmax(1) == y是一个布尔值,其中等号是否成立代表对应样本的预测是否正确。(True表示正确,False表示错误)

            .type(torch.float)是将布尔数组的数据类型转换为浮点数类型,即将True转换为1.0;将False转换为0.0
            .sum() 是对数组中的元素进行求和,计算出预测正确的样本数量。
            .item() 将求和结果转换为标量值,以便在Python中使用或打印。
        '''
        train_loss += loss.item()
    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

3、编写测试函数

测试函数、训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器。
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)  # 批次数目313 (10000/32=321.5 向上取整)
    test_loss, test_acc = 0, 0

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

4、正式训练

epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):  # epoch 索引值
    model.train()  # 启用Batch Normalization和Dropout
    '''
        如果模型中有BN(Batch Normalization)和Dropout ,需要在训练时添加model.train() 。 model.train() 是保证BN层能够用到每一批数据的均值和方差。
        对于Dropout ,model.train() 是随机取一部分网络连接来训练更新参数。
    '''
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

    model.eval()  # 不启用Batch Normalization 和Dropout
    '''
        如果模型中有BN(Batch Normalization)和Dropout ,需要在测试时添加model.eval() . model.eval() 是保证BN层能够用全部训练数据的均值和方差,
        即:测试过程中要保证BN层的均值和方差不变。对于Dropout, model.eval() 是利用到了所有网络连接,即:不进行随机舍弃神经元。

        训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。
        这是model中还有BN层和Dropout所带来的性质。
    '''
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f}'
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print('Done')

四、结果可视化

结果可视化,主要使用的是import matplotlib.pyplot as plt 的绘图。

warnings.filterwarnings('ignore')  # 忽略警告信息
# plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rc('font', family='PingFang HK')
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Train Loss')
plt.plot(epochs_range, test_loss, label="Test Loss")
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

上述我们已经先开始训练模型了,正式训练模型,我们会得到很多数据,因此,我们要用一些可视化工具来清晰的展示,我们得到的数据,看看训练数据和测试数据会有哪些差异呢。

由于我本地电脑跑起来,风扇呼呼响。这里我用的是谷歌提供的免费的 工具跑的代码,训练数据共花费3分钟左右。


至此,我们使用Pytorch完成了手写数字识别,也算是一个简单的基础入门实战啦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/437166.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

双碳目标下DNDC模型建模方法及在土壤碳储量、温室气体排放、农田减排、土地变化、气候变化中的应用

由于全球变暖、大气中温室气体浓度逐年增加等问题的出现,“双碳”行动特别是碳中和已经在世界范围形成广泛影响。国家领导人在多次重要会议上讲到,要把“双碳”纳入经济社会发展和生态文明建设整体布局。同时,提到要把减污降碳协同增效作为促…

浅析extern关键字

C中extern关键字的使用 文章目录 C中extern关键字的使用前言正文1. C与C编译区别2. C调用C函数3. C中调用C函数 总结 前言 ​ C 是一种支持多范式的编程语言,它既可以实现面向对象的编程,也可以实现泛型编程和函数式编程。C 还具有与C语言的兼容性&…

大数据最佳实践

本文主要收录一些大数据不错的实践文章 1、数禾云上数据湖最佳实践 https://blog.51cto.com/u_15089766/2601706 该文章介绍了数禾云的数据胡实践,包含presto以及数据湖等组件的一些部署架构,文章听不错的,里面提到了为了避免presto与yarn计…

【Kaggle】练习赛《肥胖风险的多类别预测》

前言 作为机器学习的初学者,Kaggle提供了一个很好的练习和学习平台,其中有一个栏目《PLAYGROUND》,可以理解为游乐场系列赛,提供有趣、平易近人的数据集,以练习他们的机器学习技能,并每个月都会有一场比赛…

【开源】SpringBoot框架开发快乐贩卖馆管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 搞笑视频模块2.3 视频收藏模块2.4 视频评分模块2.5 视频交易模块2.6 视频好友模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 搞笑视频表3.2.2 视频收藏表3.2.3 视频评分表3.2.4 视频交易表 四、系…

【剑指offer--C/C++】JZ6 从尾到头打印链表

一、题目 二、本人思路及代码 直接在链表里进行翻转不太方便操作,但是数组就可以通过下标进行操作,于是, 思路1、 先遍历链表,以此存到vector中,然后再从后往前遍历这vector,存入到一个新的vector,就完成…

OPC UA 学习笔记:状态机/有限状态机

有限状态机 有限状态机 (FSM) 是程序员、数学家、工程师和其他专业人士用来描述具有有限数量条件状态的系统的数学模型。 有限状态机的构成包括以下内容: 一组潜在的输入事件。与潜在输入事件相对应的一组可能的输出事件。系统可以显示的一…

dubbo3适配springboot2.7.3

版本详细 <dependency><groupId>org.apache.dubbo</groupId><artifactId>dubbo</artifactId><version>3.0.3</version> </dependency><parent><groupId>org.springframework.boot</groupId><artifactId&…

13年测试老鸟,接口性能测试-压测总结汇总,一文概全...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、概述 性能测试…

LVS负载均衡群集之NAT与DR模式

一 集群和分布式 企业群集应用概述 群集的含义 Cluster&#xff0c;集群、群集 由多台主机构成&#xff0c;但对外只表现为一个整体&#xff0c;只提供一个访问入口(域名或IP地址)&#xff0c;相当于一台大型计算机。 问题&#xff1f; 互联网应用中&#xff0c;随着站点对…

leetCode刷题 4.寻找两个正序数组的中位数

目录 1. 思路 2. 解题方法 3. 复杂度 4. Code 题目&#xff1a; 给定两个大小分别为 m 和 n 的正序&#xff08;从小到大&#xff09;数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。 算法的时间复杂度应该为 O(log (mn)) 。 示例 1&#xff1a; 输入&…

重磅!云智慧推出轻量智能化服务管理平台轻帆云

近日&#xff0c;云智慧推出智能服务管理平台轻帆云&#xff0c;通过构建服务体系、规范服务流程、保障服务质量、提升服务效能&#xff0c;为企业提供安全可靠的一站式服务管理解决方案。SaaS轻量化部署方式&#xff0c;仅需通过简单操作&#xff0c;即可轻松完成搭建&#xf…

Java EE之线程安全问题

一.啥是线程安全问题 有些代码&#xff0c;在单个线程执行时完全正确&#xff0c;但同样的代码让多个线程同时执行&#xff0c;就会出现bug。例如以下代码&#xff1a; 给定一个变量count&#xff0c;让线程t1 t2分别自增5000次&#xff0c;然后进行打印&#xff0c;按理说co…

libftdi库编译

目录 1. 下载源码 2. Ubuntu下编译 2.1 配置编译环境 2.2 编译 3. Android NDK下编译 3.1 编译libconfuse 3.2 编译libusb 3.3 编译libudev 3.3 编译libftdi 分2部分&#xff0c;先在Ubuntu中编译&#xff0c;然后在Android NDK中编译。 1. 下载源码 下载地址&#…

企业财务分析该怎么做?重点分析哪些财务指标?

在企业经营管理的过程中&#xff0c;财务分析是评估当前企业或特定部门财务状况和绩效的过程&#xff0c;这一过程通常涉及对财务报表&#xff08;如资产负债表、利润表和现金流量表&#xff09;进行定量和定性的评估&#xff0c;以便为盈利能力、偿债能力、现金流动性和资金稳…

VMware虚拟机安装Linux教程(超详细)

目录 一、安装VMware VMware下载&#xff08;16 pro&#xff09;&#xff1a; 镜像文件&#xff08;不一定选择CentOS&#xff0c;只是为了有图形界面更好的操作)​ 安装VMware 安装虚拟机 第一步&#xff1a;点击创建新的虚拟机。​ 第二步&#xff1a;选择自定义 &…

HTML结构及常见标签

1.HTML结构 认识 HTML 标签 HTML 代码是由 " 标签 " 构成的 . 形如 : <body> hello </body> <body id "myId" > hello </body> 标签名 (body) 放到 < > 中 大部分标签成对出现 . <body> 为开始标签 , …

ant-desgin charts双轴图DualAxes,柱状图无法立即显示,并且只有在调整页面大小(放大或缩小)后才开始显示

摘要 双轴图表中&#xff0c;柱状图无法立即显示&#xff0c;并且只有在调整页面大小&#xff08;放大或缩小&#xff09;后才开始显示 官方示例代码 在直接复制&#xff0c;替换为个人数据时&#xff0c;出现柱状图无法显示问题 const config {data: [data, data],xFiel…

Kubernetes-3

Kubernetes学习第3天 Kubernetes-31、查看实时的cpu和内存消耗1.1、kubectl top node 2、卷的使用2.1、什么是卷&#xff1f;1. 解决数据持久性问题2. Kubernetes 中的卷抽象概念3. 共享数据示例4. Kubernetes 中的卷使用5. 不同类型的卷6. 灵活、可靠的数据管理 2.2、联想到do…

CVE-2024-27198 JetBrains TeamCity 身份验证绕过漏洞分析

漏洞简介 JetBrains TeamCity 是一款由 JetBrains 公司开发的持续集成和持续交付服务器。它提供了强大的功能和工具&#xff0c;旨在帮助开发团队构建、测试和部署他们的软件项目 JetBrains TeamCity发布新版本修复了两个高危漏洞JetBrains TeamCity 身份验证绕过漏洞(CVE-20…