手写数字识别Minst(CNN)

文章目录

  • 手写数字识别
    • 网络结构
    • 加载数据集
    • 数据集可视化
    • CNN网络结构
    • 训练模型
    • 保存模型和加载模型
    • 测试模型

手写数字识别

网络结构

网上给出的基本网络结构:
在这里插入图片描述
然而在本数据集中,输入图不是1*32*32,是1*28*28。所以正确的网络结构应该是

levelinputstrideoutput
11*28*286*5*516*24*24
MaxPool6*24*24MaxPool26*12*12
26*12*1216*5*5116*8*8
MaxPool16*8*8MaxPool216*4*4
Flatten16*4*4Flatten256
3FC256FC120
4FC120FC84
5FC84FC10

加载数据集

# -*-coding =utf-8 -*-
import torch
import matplotlib.pyplot as plt
import torchvision

# 定义数据转换
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

# 加载数据集
batch_size=32
path = r'05data'
train_dataset = torchvision.datasets.MNIST(root=path, train=True,transform=transform,download =False)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = torchvision.datasets.MNIST(root=path, train=True,transform=transform,download =False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
# loader.shape=1875*[32*1*28*28,32]

最后loader.shape是1875*[32*1*28*28,32],即 number*[batch(data)*height*width, batch(label)]

数据集可视化


from sklearn.preprocessing import MinMaxScaler
# 归一化转为[0,255]
transfer=MinMaxScaler(feature_range=(0, 255)) 
def visualize_loader(batch,predicted=''): 
    # batch=[32*1*28*28,32]
    imgs=batch[0].squeeze().numpy() # 消squeeze()一维
    fig, axes = plt.subplots(4, 8, figsize=(12, 6))
    labels=batch[1].numpy()
    if str(predicted)=='':
        predicted=labels
    for i, ax in enumerate(axes.flat):
        ax.imshow(imgs[i])
        ax.set_title(predicted[i],color='black' if predicted[i]==labels[i] else 'red')
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# loader.shape=1875*[32*1*28*28,32]
for batch in train_loader:       
    break
visualize_loader(batch)

在这里插入图片描述
上图是对数据集的可视化。

CNN网络结构

在PyTorch的torch.nn模块中,卷积函数Conv2d的输入张量的形状应为[batch_size, channels, height, width]对应数据集,无需修改(在一些架构中,可能是[batch_size, height, width, channels])。

# 创建模型
import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.flatten=nn.Flatten()
        self.fc3 = nn.Linear(256, 120)
        self.fc4 = nn.Linear(120, 84)
        self.fc5 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.fc4(x)
        x = self.relu(x)
        x = self.fc5(x)
        return x

打印模型结构

model = CNN()
print(model)
CNN(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc3): Linear(in_features=256, out_features=120, bias=True)
  (fc4): Linear(in_features=120, out_features=84, bias=True)
  (fc5): Linear(in_features=84, out_features=10, bias=True)
)

训练模型

import torch.optim as optim

num_epochs=1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 统计准确率
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        running_loss += loss.item()
    
    train_loss = running_loss / len(train_loader)
    train_accuracy = correct / total
    
    # 在测试集上评估模型
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            test_loss += loss.item()
    
    test_loss = test_loss / len(test_loader)
    test_accuracy = correct / total
    
    # 打印训练过程中的损失和准确率
    print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
Epoch [1/1] - Train Loss: 0.0154, Train Accuracy: 0.9951, Test Loss: 0.0109, Test Accuracy: 0.9964

保存模型和加载模型


#torch.save(model.state_dict(), '05model.pth')

# 创建一个新的模型实例
model = CNN()
# 加载模型的参数
model.load_state_dict(torch.load('05model.pth'))

测试模型


for batch in test_loader:       
    break
imgs=batch[0]
outputs = model(imgs)
_, predicted = torch.max(outputs.data, 1)
predicted=predicted.numpy()

print(predicted)

visualize_loader(batch,predicted)

在这里插入图片描述

上图中可视化了其中的32次预测,只有第三行第四列的“8”被预测为“5”,其余均是正确。
在测试集的总体预测准确度为99.64%,正确率挺高的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/41263.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

好用的思维导图软件有哪些?这几款简单好用

好用的思维导图软件有哪些?思维导图是一种非常有用的思维工具,可以帮助我们组织和理清复杂的信息。在如今的数字时代,有很多软件可以帮助我们创建和编辑思维导图。下面介绍几款简单好用的思维导图软件。 第一款:迅捷画图 这是一款…

生成式AI时代,亚马逊云科技致力推动技术的普惠,让更多企业受益

当谈及AIGC时, 我们该谈些什么? 生成式AI技术与应用的不断发展,为各个行业都注入了全新的机会与活力。AIGC成为了今年最为激动人心的技术话题。亚马逊云科技也一马当先,在6月27-28日,2023亚马逊云科技中国峰会上分享…

babel兼容低版本游览器

文章目录 1. webpack项目的搭建2. babel 命令行使用3. babel的预设与编译器流程4. babel项目中配置4.1 babel-loader与插件的使用4.2 babel-preset使用 5. 游览器兼容性使用5.1 browserslist工具与编写规则5.2 browserslist配置5.3 优化babel的配置文件 6. polyfill6.1 useBuil…

Redis基础 进阶项目实战总结笔记

文章目录 一、启动的三种方式1.默认启动2.指定配置启动3.开机自启动 二、数据类型1.string:字符串2. hash:哈希3. list:列表4. set:集合5. sorted set:有序集合 三、黑马课程的进阶项目实战总结博文笔记Redis实现短信登…

配置spark

配置spark Yarn 模式Standalone 模式Local 模式 Yarn 模式 tar -zxvf spark-3.0.0-bin-hadoop3.2.tgz -C /opt/module cd /opt/module mv spark-3.0.0-bin-hadoop3.2 spark-yarn修改 hadoop 配置文件/opt/module/hadoop/etc/hadoop/yarn-site.xml, 并分发 <!--是否启动一…

Vue3统计数值(Statistic)

可自定义设置以下属性&#xff1a; 数值的标题&#xff08;title&#xff09;&#xff0c;类型&#xff1a;string | slot&#xff0c;默认&#xff1a;‘’数值的内容&#xff08;value&#xff09;&#xff0c;类型&#xff1a;string | number&#xff0c;默认&#xff1a;…

【Python】数据可视化利器PyCharts在测试工作中的应用

点击跳转原文&#xff1a;【Python】数据可视化利器PyCharts在测试工作中的应用 实际应用&#xff1a;常态化性能压测数据统计 import random from pyecharts.charts import Line, Bar, Grid, Pie, Page from pyecharts import options as opts # 查询过去 8 次数据 time_rang…

Low-Light Image Enhancement via Self-Reinforced Retinex Projection Model 论文阅读笔记

这是马龙博士2022年在TMM期刊发表的基于改进的retinex方法去做暗图增强&#xff08;非深度学习&#xff09;的一篇论文 文章用一张图展示了其动机&#xff0c;第一行是估计的亮度层&#xff0c;第二列是通常的retinex方法会对估计的亮度层进行RTV约束优化&#xff0c;从而产生…

NFTScan 与 Decert 达成合作伙伴,双方在 NFT 数据方面展开合作

近日&#xff0c;NFT 数据基础设施 NFTScan 与 Decert 达成合作伙伴关系&#xff0c;双方在多链 NFT 数据层面展开合作。在 Decert 产品中&#xff0c;由 NFTScan 为其提供专业的多链 NFT 数据支持&#xff0c;为用户带来优质的 NFT 搜索查询等相关交互功能&#xff0c;提升用户…

事务@transactional执行产生重复数据

背景 系统设计之初&#xff0c;每次来新请求&#xff0c;业务层会先查询数据库&#xff0c;判断是否存在相同的id数据&#xff08;id是唯一标识产品的&#xff09;&#xff0c;有则返回当前数据库查到的数据&#xff0c;根据数据决定下一步动作&#xff0c;没有则认为是初次请…

一则 MySQL 参数设置不当导致复制中断的故障案例

本文分享了一个数据库参数错误配置导致复制中断的问题&#xff0c;以及对参数配置的建议。 作者&#xff1a;秦福朗 爱可生 DBA 团队成员&#xff0c;负责项目日常问题处理及公司平台问题排查。热爱互联网&#xff0c;会摄影、懂厨艺&#xff0c;不会厨艺的 DBA 不是好司机&…

wsl2中安装docker

1、安装docker 执行以下脚本&#xff1a; 这个脚本在执行之前需要先执行chmod x install-docker.sh这个命令 # install docker curl -fsSL get.docker.com -o get-docker.sh sh get-docker.shif [ ! $(getent group docker) ]; thensudo groupadd docker; elseecho "doc…

UE4/5AI制作基础AI跳跃(适合新手)

目录 制作 添加逻辑 添加导航链接代理 结果 在上一章中&#xff0c;我们讲解了简单的AI跟随玩家&#xff0c;制作了一个基础的ai。 UE4/5AI制作基础AI&#xff08;适合新手入门&#xff0c;运用黑板&#xff0c;行为树&#xff0c;ai控制器&#xff0c;角色类&#xff0c;任…

数据库应用:CentOS 7离线安装PostgreSQL

目录 一、理论 1.PostgreSQL 2.PostgreSQL离线安装 3.PostgreSQL初始化 4.PostgreSQL登录操作 二、实验 1.CentOS 7离线安装PostgreSQL 2.登录PostgreSQL 3.Navicat连接PostgreSQL 三、总结 一、理论 1.PostgreSQL &#xff08;1&#xff09;简介 PostgreSQL 是一个…

性能测试工具 jmeter 录制脚本,传递 cookie,循环执行接口

目录 前言&#xff1a; 代理录制脚本 循环重复添加接口 登录并传递 cookie 给新建产品接口 循环执行脚本 前言&#xff1a; 在使用JMeter进行性能测试时&#xff0c;录制脚本是一种常用的方法。录制脚本可以帮助你捕获和重放用户与应用程序之间的交互&#xff0c;以模拟真…

matlab中画有重影的机器人运动过程【给另一个机器人设置透明度】

1、前言如题 2、参考连接如下 How to plot two moving robot in the same figure and change one of them transparency&#xff1f; - MATLAB Answers - MATLAB Central (mathworks.cn)3、代码&#xff1a;【找到figure中对应对象并设置属性】 % Create two instances of a…

什么是70v转12v芯片?

问&#xff1a;什么是70v转12v芯片&#xff1f; 答&#xff1a;70v转12v芯片是一种电子器件&#xff0c;其功能是将输入电压范围在9v至100v之间的电源转换为稳定的12v输出电压。这种芯片通常被用于充电器、车载电池充电器和电源适配器等设备中。 问&#xff1a;这种芯片的最大…

微信小程序使用字体图标——链接引入

目录 1.下载字体图标 1.选择需要的图标加入购物车添加到项目 2.查看项目 3.生成在线链接 4.复制生成的链接 等下放到iconfont.json中​编辑 2.引入链接 1.下载 2.生成iconfont.json文件 3. 在iconfont.json中 放入生成的链接 4.需要重新编译小程序之后在终端执行 5…

03 QT对象树

Tips: QT通过对象树机制&#xff0c;能够自动、有效的组织和管理继承自QObject的Qt对象&#xff0c;不需要用户手动回收资源&#xff0c;系统自动调用析构函数。 验证对象树功能&#xff1a; 新建C文件 继承自QPushButton&#xff0c;但没有QPushButton&#xff0c;但有其父类…

【雕爷学编程】Arduino动手做(164)---Futaba S3003舵机模块3

37款传感器与模块的提法&#xff0c;在网络上广泛流传&#xff0c;其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块&#xff0c;依照实践出真知&#xff08;一定要动手做&#xff09;的理念&#xff0c;以学习和交流为目的&#x…
最新文章