(三)Pytorch快速搭建卷积神经网络模型实现手写数字识别(代码+详细注解)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
    • Q1:卷积网络和传统网络的区别
    • Q2:卷积神经网络的架构
    • Q3:卷积神经网络中的参数共享,也是比传统网络的优势所在
    • 4、 具体的实现代码+网络搭建


前言

深度学习pytorch系列第三篇啦,之前更了FC,NN,这篇是卷积神经网络(cNN)模型实现手写数字识别,依然是重在理解哈,具体的理解内容我都以注释的形式放在了代码中,我就直接放代码了,因为我把一些知识点和理解的东西用注释的形式写了


首先是关于卷积神经网络的一些点

Q1:卷积网络和传统网络的区别

传统网络只适合结构化数据,不适合图像数据,由于图像数据的数据量大(表现为像素点多),传统网络需要使用的参数量太大

Q2:卷积神经网络的架构

卷积神经网络包括:输入层,卷积层,池化层,全连接层
重点介绍卷积层!!
卷积就是针对每个区域去计算特征。可以这样做的原因是:图片是有像素点构成的,针对每个像素点进行处理,需要的参数量过于庞大,并且相邻的像素点之间是存在联系的
特征图的个数与卷积核的个数一致。每个卷积核通过对输入特征图进行卷积操作,生成一个输出特征图。因此,卷积核的个数决定了输出的特征图的个数。
使用不同的卷积核学习同一个位置,可以得到不同的特征图,从而使特征多样化
卷积核的大小一般使用3*3
卷积核的大小规格一般是固定的,卷积核的数量理论上是越多越好
卷积层涉及的参数有:滑动窗口步长,卷积核尺寸,边缘填充,卷积核个数
卷积结果计算公式:长:h2=(h1-Fh+2p)/s +1 宽:w2=(w1-Fw+2p)/s +1
其中:w1,h1表示输入的宽度,长度;w2和h2表示输出特征图的宽度、长度,F表示卷积核的长和宽,s表示滑动窗口的补偿,p表示边界填充
经过卷积操作后,特征图的长和宽也可以保持不变
池化层的作用就是筛选好的特征,pool是只筛选位置的,channel是全部使用的
池化也称为下采样,(一次只能下采样原来的一半,不能直接224-16)
卷积神经网络由多个block组成,重点就在于怎么设计这个block的组成
关于卷积神经网络的层数,带权重参数的就算是一层,6个conn+1个fc,就可以说是7层网络结构

Q3:卷积神经网络中的参数共享,也是比传统网络的优势所在

同一个卷积核在各个位置上的参数都是一致的
权重参数的个数与输入数据的大小无关

4、 具体的实现代码+网络搭建

# 读取数据
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
# transforms  进行预处理,比如进行tensor转换
import matplotlib.pyplot as plt
import numpy as np
#全连接:batch*28*28,全连接各个像素点之间无关
# cnn:batch*1*28*28  ,多了一个参数channel,卷积会综合考虑一个窗口之间的关系,因此各个像素点并不是独立的,卷积网络更适合处理图像数据
# 定义超参数
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片
# 训练集
train_dataset = datasets.MNIST(root='./data',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)
# 测试集
test_dataset = datasets.MNIST(root='./data',
                           train=False,
                           transform=transforms.ToTensor())

# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务
# 定义一个网络
class CNN(nn.Module):
    def __init__(self):
        #         构造函数
        # 卷积网络一般是组合进行的:conv pool relu可以当一个组合
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(  # 输入大小 (1, 28, 28)
            nn.Conv2d(  # 2d卷积做任务
                in_channels=1,  # 灰度图
                out_channels=16,  # 要得到几多少个特征图,就是卷积核的个数,相当于有16个卷积核
                kernel_size=5,  # 卷积核大小 5*5的
                stride=1,  # 步长
                padding=2,  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1,一般是这么希望的
                #                                             如果不能整除pytorch采用向下取整
            ),  # 输出的特征图为 (16, 28, 28)
            nn.ReLU(),  # relu层
            nn.MaxPool2d(kernel_size=2),  # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14),一般是pooling后是之前的一半
        )
        self.conv2 = nn.Sequential(  # 下一个套餐的输入 (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),  # 输出 (32, 14, 14)
            nn.ReLU(),  # relu层
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 输出 (32, 7, 7)
        )

        self.conv3 = nn.Sequential(  # 下一个套餐的输入 (32, 7, 7)
            nn.Conv2d(32, 64, 5, 1, 2),  # 输出 (64, 7, 7)
            nn.ReLU(),  # 输出 (64, 7, 7)
        )
        # 只有pool的时候才会筛选特征

        self.out = nn.Linear(64 * 7 * 7, 10)  # 全连接层得到的结果,最后的任务是10分类任务,进行一个wx+b的操作去做分类

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)  # flatten操作,结果为:(batch_size, 32 * 7 * 7),和reshape操作一样
        # reshape操作:总的大小是不变的,提供一个维度后,后边的维度自动计算
        # 比如当前的x:64*7*7,x.size:64,也就是要从三维转成两维,总的大小不变,就变为64*49这样,-1可以简单的看成一个占位符号
        # 变换维度,开始是64*7*7,转成batchsize*特征个数,比如64*49
        output = self.out(x)
        return output
# 定义准确率
def accuracy(predictions, labels):
    pred = torch.max(predictions.data, 1)[1] # 最大值是多少,最大值的索引,只要索引就可以
    rights = pred.eq(labels.data.view_as(pred)).sum()
    return rights, len(labels)
# 训练网络模型
# 实例化
net = CNN()
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器,学习率是0.001
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器,普通的随机梯度下降算法
# 开始训练循环
for epoch in range(num_epochs):
    # 当前epoch的结果保存下来
    train_rights = []

    for batch_idx, (data, target) in enumerate(train_loader):  # 针对容器中的每一个批进行循环
        net.train()
        output = net(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        right = accuracy(output, target)
        train_rights.append(right)
        # 每一个batch都进行训练,每一百个batch进行一次评估
        if batch_idx % 100 == 0:

            net.eval()
            val_rights = []

            for (data, target) in test_loader:
                output = net(data)
                right = accuracy(output, target)
                val_rights.append(right)

            # 准确率计算
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))

            print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                       100. * batch_idx / len(train_loader),
                loss.data,
                       100. * train_r[0].numpy() / train_r[1],
                       100. * val_r[0].numpy() / val_r[1]))

实现结果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/198728.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++二分查找、离线算法:最近的房间

作者推荐 利用广度优先或模拟解决米诺骨牌 本文涉及的基础知识点 二分查找算法合集 题目 一个酒店里有 n 个房间,这些房间用二维整数数组 rooms 表示,其中 rooms[i] [roomIdi, sizei] 表示有一个房间号为 roomIdi 的房间且它的面积为 sizei 。每一…

linux设置主机名

查看主机名:hostname 临时修改主机名:hostname 新主机名 [rootlocalhost ~]#hostname centos [rootlocalhost ~]#hostname centos 永久修改主机名: [rootlocalhost ~]#cat /etc/hostname localhost.localdomain

ArrayList 和 HashMap 源码解析

1、ArrayList 1.1、ArrayList 构造方法 无参创建一个 ArrayList 数组默认为空数组 transient Object[] elementData; private static final Object[] DEFAULTCAPACITY_EMPTY_ELEMENTDATA {}; private int size; // 数组容量大小public ArrayList() {this.elementData DEFA…

基于springboot校园车辆管理系统

背景 伴随着社会经济的快速发展,机动车保有量不断增加。不断提高的大众生活水平以及人们不断增长的自主出行需求,人们对汽车的 依赖性在不断增强。汽车已经发展成为公众日常出行的一种重要的交通工具。在如此形势下,高校校园内的机动车数量也…

java设计模式学习之【原型模式】

文章目录 引言原型模式简介定义与用途实现方式UML 使用场景优势与劣势原型模式在spring中的应用员工记录示例代码地址 引言 原型模式是一种创建型设计模式,它允许对象能够复制自身,以此来创建一个新的对象。这种模式在需要重复地创建相似对象时非常有用…

近五年—中国十大科技进展(2018年—2022年)

近五年—中国十大科技进展(2018-2022) 2022年中国十大科技进展1. 中国天眼FAST取得系列重要进展2. 中国空间站完成在轨建造并取得一系列重大进展3. 我国科学家发现玉米和水稻增产关键基因4. 科学家首次发现并证实玻色子奇异金属5. 我国科学家将二氧化碳人…

Vue 定义只读数据 readonly 与 shallowReadonly

readonly 让一个响应式数据变为 **深层次的只读数据**。 shallowReadonly 让一个响应式数据变为 **浅层次的只读数据**,只读第一层。 isReadonly 判断一个数据是不是只读数据。 应用场景:不希望数据被修改时使用。 readonly深层次只读: …

读像火箭科学家一样思考笔记12_实践与测试(下)

1. 舆论的火箭科学 1.1. 如果苹果违反了“即飞即测”原则,那苹果的iPhone就不会问世了 1.1.1. iPhone在其上市前的民意调查中相当失败 1.1.1.1. iPhone不可能获得太大市场份额,不可能。 1.1.1.1.1. 微软前CEO史蒂夫鲍尔默(Steve Ballmer&…

msng病毒分析

这是一个非常古老的文件夹病毒,使用XP系统的文件夹图标,采用VB语言开发,使用了一种自定义的壳来保护,会打开网址http://www.OpenClose.ir,通过软盘、U盘和共享目录进行传播,会在U盘所有的目录下生成自身的副本&#xf…

采集工具-免费采集器下载

在当今信息时代,互联网已成为人们获取信息的主要渠道之一。对于研究者和开发者来说,如何快速准确地采集整个网站数据是至关重要的一环。以下将从九个方面详细探讨这一问题。 确定采集目标 在着手采集之前,明确目标至关重要。这有助于确定采集…

三季度营收下滑16.3%,网易云音乐如何讲出新故事?

在选择重新回归音乐本身后,网易云音乐(09899.HK)业绩承压的困局写在最新的三季报里。 「不二研究」据网易云音乐三季报发现:今年三季度,网易云音乐净收入同比下滑16.3%。目前,网易云音乐主要面临营收下滑、商业化场景探索尚未形成…

MSB3541 Files 的值“<<<<<<< HEAD”无效。路径中具有非法字符。

MSB3541 Files 的值“<<<<<<< HEAD”无效。路径中具有非法字符。 一般来说出现这个问题是因为使用git版本控制工具合并代码出现了问题&#xff0c;想要解决也很简单。 如图点击错误后定位到文件&#xff0c;发现也没有什么问题。 根据错误后边的提示&a…

前后端分离开发出现的跨域问题

先说说什么是跨域。 请求的URL地址中的协议、域名、端口号中的任意一个与当前URL不同就是跨域。 比如&#xff1a; 当前页面的URL请求的URL是否跨域原因htttp://localhost:8080htttps://localhost:8080是协议不同htttp://localhostll:8080htttp://localhost:8080是域名不同htt…

JVM 内存结构

&#x1f680; 作者主页&#xff1a; 有来技术 &#x1f525; 开源项目&#xff1a; youlai-mall &#x1f343; vue3-element-admin &#x1f343; youlai-boot &#x1f33a; 仓库主页&#xff1a; Gitee &#x1f4ab; Github &#x1f4ab; GitCode &#x1f496; 欢迎点赞…

【赠书第9期】巧用ChatGPT高效搞定Excel数据分析

文章目录 前言 1 操作步骤 1.1 数据清理和整理 1.2 公式和函数的优化 1.3 图表和可视化 1.4 数据透视表的使用 1.5 条件格式化和筛选 1.6 数据分析技巧 1.7 自动化和宏的创建 2 推荐图书 3 粉丝福利 前言 ChatGPT 是一个强大的工具&#xff0c;可以为你提供在 Exce…

【Newman+Jenkins】实施接口自动化测试

一、是什么Newman Newman就是纽曼手机这个经典牌子&#xff0c;哈哈&#xff0c;开玩笑啦。。。别当真&#xff0c;简单地说Newman就是命令行版的Postman&#xff0c;查看官网地址。 Newman可以使用Postman导出的collection文件直接在命令行运行&#xff0c;把Postman界面化运…

链接2:静态链接、目标文件、符号和符号表

文章目录 静态链接符号解析 (symbolresolution)重定位 (relocation) 目标文件1.可重定位目标文件2.可执行目标文件3.共享目标文件 可重定位目标文件text:rodata:.data.bss.symtab.rel.text.rel.data:debug:line:strtab: 符号和符号表由m定义并能被其他模块引用的全局符号由其他…

【用unity实现100个游戏之17】从零开始制作一个类幸存者肉鸽(Roguelike)游戏3(附项目源码)

文章目录 本节最终效果前言近战武器控制近战武器生成升级增加武器伤害和数量查找离主角最近的敌人子弹预制体生成子弹发射子弹参考源码完结 本节最终效果 前言 本节紧跟着上一篇&#xff0c;主要实现武器功能。 近战武器 新增Bullet&#xff0c;子弹脚本 public class Bull…

REST-Assured--JAVA REST服务自动化测试的Swiss Army Knife

什么是REST-Assured REST Assured是一套基于 Java 语言实现的开源 REST API 测试框架 Testing and validation of REST services in Java is harder than in dynamic languages such as Ruby and Groovy. REST Assured brings the simplicity of using these languages into t…

Java第二十章多线程

一、线程简介 线程是操作系统能够进行运算调度的最小单位&#xff0c;它被包含在进程之中&#xff0c;是进程中的实际运作单位。一个进程可以包含多个线程&#xff0c;这些线程可以并发执行。线程拥有自己的栈和局部变量&#xff0c;但是它们共享进程的其他资源&#xff0c;如…