基于pytorch实现手写数字识别

1,先安装pytorch,在pytorch环境中安装库:

1)进入所安装的pytorch环境,我的是pytorch

所以激活它:

conda activate pytorch

2)使用pip安装numpy,torch,torchvision,matplotlib库 

pip install numpy torch torchvision matplotlib

回车安装4个库

2,再将test.py文件用vscode打开,pycharm也行(主要我不怎么会用),这里用vscode展示。

 注意右下角环境要选好。

这里我已经测试了两次,最高在0.96左右。

献上源码:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt


class Net(torch.nn.Module):#定义一个NET类,它就是神经网络的主体

    def __init__(self):
        super().__init__()  #四个全连接层
        self.fc1 = torch.nn.Linear(28*28, 64)#输入为28*28的像素尺寸图像
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)#中间三层放了64个节点
        self.fc4 = torch.nn.Linear(64, 10)#输出为10个数字类别
    def forward(self, x):#forward函数定义了前向传播过程,参数x是图像输入
        x = torch.nn.functional.relu(self.fc1(x))#每层连接中我们先做全连接线性计算
        x = torch.nn.functional.relu(self.fc2(x))#再套上一个激活函数torch.nn.functional.relu
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)#输出层通过sodtmax归一化,这里的log_softmax是为了提高计算的稳定性。
        return x#在softmax之外又套上了torch.nn.functional.log_softmax对数运算


def get_data_loader(is_train):#导入数据
    to_tensor = transforms.Compose([transforms.ToTensor()])#定义一个tensor,是一个多维数组,中文叫张量
    data_set = MNIST("", is_train, transform=to_tensor, download=True)#下载MNIST数据集,""是下载目录,空表示当前目录,is_train用来决定是导入训练集还是测试集
    return DataLoader(data_set, batch_size=15, shuffle=True)#batch_size=15表示一个批次含15张图片,shuffle=True表示数据是随机打乱的,最后返回数据加载器


def evaluate(test_data, net):#用来评估神经网络的正确率,
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28*28))#计算神经网络预测值
            for i, output in enumerate(outputs):#对批次结果进行比较,累加正确预测的数量
                if torch.argmax(output) == y[i]:#argmax函数计算数据中最大值的序号也就是预测的手写数字结果
                    n_correct += 1
                n_total += 1
    return n_correct / n_total#返回正确率


def main():

    train_data = get_data_loader(is_train=True)#导入训练集
    test_data = get_data_loader(is_train=False)#导入测试集
    net = Net()#初始化神经网络
    
    print("initial accuracy:", evaluate(test_data, net))#打印初始网络的正确率
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)#一下几行代码训练神经网络,都是pytorch的固定写法
    for epoch in range(2):#epoch反复训练,提高数据集的利用率,每一个轮次就是一个epoch
        for (x, y) in train_data:
            net.zero_grad()#初始化
            output = net.forward(x.view(-1, 28*28))#正向传播
            loss = torch.nn.functional.nll_loss(output, y)#计算差值,nll_loss是对数损失函数,是为了匹配前面的log_softmax中的对数运算
            loss.backward()#反向误差传播
            optimizer.step()#优化网络参数
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))#每个轮次后打印当前网络的正确率

    for (n, (x, _)) in enumerate(test_data):#训练完成后随机抽取3张图像显示网络预测结果
        if n > 3:
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28*28)))
        plt.figure(n)
        plt.imshow(x[0].view(28, 28))
        plt.title("prediction: " + str(int(predict)))
    plt.show()


if __name__ == "__main__":
    main()

 1,讲解

1)使用MNIST数据集:手写数字图片7万张(训练6万张,测试1万张)。

2)什么是神经网络?

通过softmax归一化得到了看起来像概率的数值(概率分布),但它还不是真的概率,

调整a,b的值,如梯度下降算法,ADAM算法将神经网络问题转为最优化问题,重复过程几万次。

神经网络的本质是一个数学函数,训练的过程就是调整函数中的参数。

观察公式是线性的,但不是每个都是线性的,所以再套上一个非线性函数(也叫激活函数),f()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/425675.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【MDVRP多站点物流配送车辆路径规划问题(带容量限制)】基于遗传算法GA求解

课题名称:基于遗传算法求解带容量限制的多站点的物流配送路径问题MDVRP 版本时间:2023-03-12 代码获取方式:QQ:491052175 或者 私聊博主获取 模型描述: 15个城市中,其中北京,长沙和杭州三座…

Android 多桌面图标启动, 爬坑点击打开不同页面

备注 : MainActivity 正常带界面的UI MainActivityBt 和 MainActivityUsb 是透明的,即 android:theme"style/TranslucentTheme" ###场景1:只有MainActivity 设置成:android:launchMode"singleTask" 点击顺序&#xff1…

02-pycharm详细安装教程(大妈看了都会)

目录 1.官方下载pycharm 2.开始安装pycharm 3.开始运行pycharm 1.官方下载pycharm 官方:https://www.jetbrains.com/zh-cn/pycharm/download/?sectionwindows 提示:安装pycharm之前建议先安装python,因为pycharm要基于python环境才能运…

azure devops工具实践分析

对azure devops此工具的功能深挖,结合jira的使用经验的分析 1、在backlog的功能描述,可理解为需求项,这里包括了bug,从开发的角度修复bug也是个工作项,所以需求的范围是真正的需求(开发接收到的已经确认的…

二叉树的前序后序中序层序

文章目录 一、二叉树的前序遍历二、二叉树的中序三、后序四、层序 引言:首先我们讲一下什么是二叉树的前序中序后序层序 前序:从 根 左子树 右子树访问 中序:从 左子树 根 右子树访问 后序:从 左子树 右子树 根访问 等到根为空的…

java面试题(spring框架篇)(黑马 )

树形图: 一、Spring框架种的单例bean是线程安全吗? Service Scope("singleton") public class UserServiceImpl implements UserService{ } singleton:bean在每个Spring IOC容器中只有一个实例 protype:一个bean的定义可以有多个…

使用 Grafana 使用JSON API 请求本地接口 报错 bad gateway(502)解决

一 . 问题: 在用docker部署Grafana 来实现仪表盘的展示,使用到比较多的就是使用JAON API插件调用本地部署的API,比如访问localhost下的 /test_data 接口,一般我们使用的是http://localhost:8080/test_data, 但是在访…

Python测试框架pytest介绍用法

1、介绍 pytest是python的一种单元测试框架,同自带的unittest测试框架类似,相比于unittest框架使用起来更简洁、效率更高 pip install -U pytest 特点: 1.非常容易上手,入门简单,文档丰富,文档中有很多实例可以参考 2.支持简单的单…

计算机网络-第2章 物理层

本章内容:物理层和数据通信的概念、传输媒体特点(不属于物理层)、信道复用、数字传输系统、宽带接入 2.1-2.2 物理层和数据通信的概念 物理层解决的问题:如何在传输媒体上传输数据比特流,屏蔽掉传输媒体和通信手段的差…

Java学习27--IDEA常用快捷键

智能显示相关提示:altenter,用来快速生成Scanner,或者new object等等,也可以爆红线求提示 代码模板大全ctrlj 可以快速生成try catch finally模块的surround with:ctrlaltt(我换成了altc) 生成getter/setter/构造器等结构-genera…

武器大师——操作符详解(下)

目录 六、单目操作符 七、逗号表达式 八、下标引用以及函数调用 8.1.下标引用 8.2.函数调用 九、结构体 9.1.结构体 9.1.1结构的声明 9.1.2结构体的定义和初始化 9.2.结构成员访问操作符 9.2.1直接访问 9.2.2间接访问 十、操作符的属性 10.1.优先性 10.2.结合性 …

【MySQL】SQL 优化

MySQL - SQL 优化 1. 在 MySQL 中,如何定位慢查询? 1.1 发现慢查询 现象:页面加载过慢、接口压力测试响应时间过长(超过 1s) 可能出现慢查询的场景: 聚合查询多表查询表数据过大查询深度分页查询 1.2 通…

2023 版王道单科书勘误汇总(3.30)

注:因2023版对题目编号做了优化“历年真题全部放最后、且按年份排序”,以方便大家根据需要保留某些年份的真题作为最后的模拟。所以造成了一些题目和解析的编号错误。 数据结构: P11 P20 P56 P278 P326 “2.”中第 3 行”题 5改成”9”,第6行”题 8”改成…

线性表——单链表的增删查改

本节复习链表的增删查改 首先, 链表不是连续的, 而是通过指针联系起来的。 如图: 这四个节点不是连续的内存空间, 但是彼此之间使用了一个指针来连接。 这就是链表。 现在我们来实现链表的增删查改。 目录 单链表的全部接口…

【EAI 027】Learning Interactive Real-World Simulators

Paper Card 论文标题:Learning Interactive Real-World Simulators 论文作者:Mengjiao Yang, Yilun Du, Kamyar Ghasemipour, Jonathan Tompson, Leslie Kaelbling, Dale Schuurmans, Pieter Abbeel 作者单位:UC Berkeley, Google DeepMind, …

探索设计模式的魅力:备忘录模式揭秘-实现时光回溯、一键还原、后悔药、历史的守护者和穿越时空隧道

​🌈 个人主页:danci_ 🔥 系列专栏:《设计模式》 💪🏻 制定明确可量化的目标,并且坚持默默的做事。 备忘录模式揭秘-实现时光回溯、一键还原、后悔药和穿越时空隧道 文章目录 一、案例场景&…

19.1 SpringBoot入门

19.1 SpringBoot入门 1. SpringBoot1.1 简介1.2 核心特点1.3 SpringBoot演变1.4 SpringBoot版本1. SpringBoot 1.1 简介 1.2 核心特点

【系统分析师】-计算机组成结构

1、计算机结构 2、存储系统 Cache是访问最快 DRAM是存取最快 先来先服务 FCFS:按照磁道号访问顺序 最短寻道时间优先SSTF:查找下一个最少的磁道数。柱面相同找磁头、磁头相同找扇区 3、数据传输控制方式 4、总线 总线: 分 时 传 输 &#…

十四、计算机视觉-形态学梯度

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 一、梯度的概念二、梯度的应用三、梯度如何实现 一、梯度的概念 形态学梯度(Morphological Gradient)是数字图像处理中的一种基本操作&…

C++学习笔记:二叉搜索树

二叉搜索树 什么是二叉搜索树?搜索二叉树的操作查找插入删除 二叉搜索树的应用二叉搜索树的代码实现K模型:KV模型 二叉搜索树的性能怎么样? 什么是二叉搜索树? 二叉搜索树又称二叉排序树,它或者是一棵空树,或者是具有以下性质的二叉树: 若它的左子树…