【PyTorch】使用PyTorch创建卷积神经网络并在CIFAR-10数据集上进行分类

前言

在深度学习的世界中,图像分类任务是一个经典的问题,它涉及到识别给定图像中的对象类别。CIFAR-10数据集是一个常用的基准数据集,包含了10个类别的60000张32x32彩色图像。在本博客中,我们将探讨如何使用PyTorch框架创建一个简单的卷积神经网络(CNN)来对CIFAR-10数据集中的图像进行分类。

在下一篇博客中,我们将尝试不断优化模型结构和训练过程,以达到更高的准确率和性能。

引用

关于卷积神经网络的原理,感兴趣的请参阅我的另一篇博客,里面只使用numpy和基础函数组建了一个卷积神经网络模型,并完成训练和测试
【手搓深度学习算法】从头创建卷积神经网络

背景

卷积神经网络是深度学习中用于图像识别和分类的一种强大工具。它们能够自动从图像中提取特征,并通过一系列卷积层、池化层和全连接层来学习图像的复杂模式。

CIFAR-10数据集包含了飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车等10个类别的图像。每个类别有6000张图像,其中50000张用于训练,10000张用于测试。
请添加图片描述

代码解析

我们的目标是构建一个能够处理CIFAR-10数据集的CNN模型。以下是我们的模型结构和数据处理流程的简要概述:

数据预处理

我们首先定义了unpickle函数来加载CIFAR-10数据集的批次文件。read_data函数用于读取数据,将其转换为适合卷积网络输入的格式,并进行归一化处理。我们还提供了一个选项来将图像转换为灰度。

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def read_data(file_path, gray = False, percent = 0, normalize = True):
    data_src = unpickle(file_path)
    np_data = np.array(data_src["data".encode()]).astype("float32")
    np_labels = np.array(data_src["labels".encode()]).astype("float32").reshape(-1,1)
    single_data_length = 32*32 
    image_ret = None
    if (gray):
        np_data = (np_data[:, :single_data_length] + np_data[:, single_data_length:(2*single_data_length)] + np_data[:, 2*single_data_length : 3*single_data_length])/3
        image_ret = np_data.reshape(len(np_data),32,32)
    else:
        image_ret = np_data.reshape(len(np_data),32,32,3)
    
    if(normalize):
        mean = np.mean(np_data)
        std = np.std(np_data)
        np_data = (np_data - mean) / std
    
    if (percent != 0):
        np_data = np_data[:int(len(np_data)*percent)]
        np_labels = np_labels[:int(len(np_labels)*percent)]
        image_ret = image_ret[:int(len(image_ret)*percent)]
    num_classes = len(np.unique(np_labels))
    np_data, np_labels = convert_to_conv_input(np_data, np_labels)
    return np_data, np_labels, num_classes, image_ret 

网络结构

Conv类定义了我们的CNN模型,它包含一个卷积层、一个最大池化层、一个ReLU激活函数和一个全连接层。在forward方法中,我们指定了数据通过网络的流程。

class Conv(th.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super(Conv, self).__init__()
        self.conv = th.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.pool = th.nn.MaxPool2d(kernel_size=2,stride=2)
        self.relu = th.nn.ReLU()
        self.linear = th.nn.Linear(16*15*15, 10)
        self.softmax = th.nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.conv(x) #32,16,30,30
        x = self.pool(x) #32,16,15,15
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
    
    # 在predict函数中,额外调用了softmax,将线性层的10个特征值转化为概率,在前向传播中不用是因为pytorch中交叉熵函数自带了softmax
    def predict(self,x):
        x = self.forward(x)
        x = self.softmax(x)
        return x
卷积层、池化层、线性层的输入特征数量的计算方法

线性层的输入特征个数取决于前面层的输出。
具体来说,线性层的输入特征个数是卷积层和池化层处理后的输出特征图的总元素数量。

卷积层定义如下:

self.conv = th.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

这里,in_channels=3 表示输入图像有3个颜色通道(RGB),out_channels=16 表示卷积层将输出16个特征图。

接下来是池化层:

self.pool = th.nn.MaxPool2d(kernel_size=2, stride=2)

kernel_size=2,表示池化窗口的大小是2x2。stride=2 表示池化操作的步长是2。

为了计算线性层的输入特征个数,我们需要知道卷积层和池化层之后的输出特征图的大小。这可以通过计算公式得到,或者通过在实际数据上运行网络的前向传播来确定。

计算公式如下:

对于卷积层,输出特征图的大小可以通过以下公式计算:

H_out = (H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
W_out = (W_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1

对于池化层,输出特征图的大小也可以通过类似的公式计算。

由于没有指定paddingdilation,查看函数定义可知它们的默认值分别是0和1。因此,如果输入图像的大小是32x32,卷积层之后的大小将是:

H_out = (32 - 1 * (3 - 1) - 1) / 1 + 1 = 30
W_out = (32 - 1 * (3 - 1) - 1) / 1 + 1 = 30

因此,卷积层的输出将有16个30x30的特征图。

然后,池化层将这些特征图的大小减半(因为kernel_size=2stride=2),所以输出将是16个15x15的特征图。

最后,线性层的输入特征个数将是这些特征图的总元素数量:

num_features = out_channels * H_out_pool * W_out_pool = 16 * 15 * 15 = 3600

因此,线性层的正确定义应该是:

self.linear = th.nn.Linear(3600, num_classes)

训练过程

main函数中,我们初始化了模型、损失函数和优化器。我们使用随机梯度下降(SGD)作为优化算法,并设置了学习率。接着,我们进入了训练循环,其中包括前向传播、损失计算、反向传播和权重更新。

loss_function = th.nn.CrossEntropyLoss()
optimizer = th.optim.SGD(conv_model.parameters(), lr = lr)

测试和评估

训练完成后,我们使用训练好的模型对测试数据进行评估,并计算准确率。我们还提供了一个predict方法,它在给定输入数据后返回模型的预测概率。

def predict(self,x):
        x = self.forward(x)
        x = self.softmax(x)
        return x
softmax激活函数

Softmax 激活函数是一种广泛使用的函数,它将一个实数向量转换为概率分布。在深度学习中,它常常用于多类别分类问题的输出层。

Softmax 函数的定义如下:

softmax ( z ) i = e z i ∑ j e z j \text{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}} softmax(z)i=jezjezi

其中 z z z 是输入向量, z i z_i zi z z z 的第 i i i 个元素, softmax ( z ) i \text{softmax}(z)_i softmax(z)i 是输出向量的第 i i i 个元素。

Softmax 函数的主要特性是它的输出是一个概率分布,即所有输出元素的值都在 ( 0 , 1 ) (0, 1) (0,1) 区间内,且所有输出元素的值之和为 1。这使得 Softmax 函数非常适合用于表示概率。

Softmax 函数的一个重要性质是它是连续的,且其导数容易计算。这使得 Softmax 函数在深度学习中的反向传播过程中非常有用。

Softmax 函数的导数如下:

∂ ∂ z i softmax ( z ) i = softmax ( z ) i ( 1 − softmax ( z ) i ) \frac{\partial}{\partial z_i}\text{softmax}(z)_i = \text{softmax}(z)_i(1 - \text{softmax}(z)_i) zisoftmax(z)i=softmax(z)i(1softmax(z)i)

这个导数表达式表明,对于 Softmax 函数的输出 y i y_i yi,其对输入 z i z_i zi 的导数等于 y i ( 1 − y i ) y_i(1 - y_i) yi(1yi)。这个导数表达式在反向传播过程中非常有用,因为它可以直接用于计算梯度。

训练过程中没有使用softmax层,是应为torch的交叉熵损失函数已经包含了softmax的操作,如果叠加使用,可能得到错误的结果。

运行结果

作为一个简单的卷积模型,在测试集上得到了60%的准确率
请添加图片描述

完整代码

本文不提供完整代码,因为随着我的微调优化过程,已经没有这个版本的基线代码了,想要最终代码的欢迎阅读下一篇博客 “记一次卷积网络调优的过程”
在这里插入图片描述

注意点

  • 数据预处理:确保数据被正确地加载和归一化,这对模型的训练效果至关重要。
  • 模型结构:模型的层数和参数需要根据任务的复杂性来调整。过于简单的模型可能无法捕捉到数据中的复杂特征,而过于复杂的模型可能会导致过拟合。
  • 损失函数:我们使用交叉熵损失函数,它适用于多类别分类问题。
  • 优化器:在每次迭代前,记得清除累积的梯度,以避免错误的梯度更新。

可能的优化点

  • 学习率调整:可以尝试使用学习率调度器来在训练过程中调整学习率,以改善模型的收敛速度和性能。
  • 权重初始化:尝试不同的权重初始化方法,以帮助模型更快地收敛。
  • 正则化技术:使用如Dropout、L2正则化等技术来减少过拟合。
  • 数据增强:通过对训练图像进行随机变换(如旋转、缩放、裁剪等),可以增加模型的泛化能力。
  • 更深的网络:考虑增加更多的卷积层和池化层来提取更复杂的特征。
  • 批量归一化:在卷积层之后添加批量归一化层,以稳定训练过程并加速收敛。

结论

通过本博客,我们展示了如何使用PyTorch框架构建一个简单的CNN模型,并在CIFAR-10数据集上进行训练和测试。虽然我们的模型结构相对简单,但它为理解深度学习和图像分类提供了一个很好的起点。在下一篇博客中,我们将尝试不断优化模型结构和训练过程,以达到更高的准确率和性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/343981.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

app如何实现悬浮框滚动到那个模块定位到那。

如图&#xff1a; 使用uniapp内置方法 onPageScroll 获取到滚动了多少。 其实拿到屏幕滚动多少就很简单了&#xff0c;下面是思路。 tap栏切换效果代码就不贴了。直接贴如何到那个模块定位到哪&#xff0c;和点击定位到当前模块。 <view v-if"show" class&qu…

08-微服务Seata分布式事务使用

一、分布式事务简介 1.1 概念 事务ACID&#xff1a; A&#xff08;Atomic&#xff09;&#xff1a;原子性&#xff0c;构成事务的所有操作&#xff0c;要么都执行完成&#xff0c;要么全部不执行&#xff0c;不可能出现部分成功部分失 败的情况。 C&#xff08;Consistency&…

网络协议与攻击模拟_06攻击模拟SYN Flood

一、SYN Flood原理 在TCP三次握手过程中&#xff0c; 客户端发送一个SYN包给服务器服务端接收到SYN包后&#xff0c;会回复SYNACK包给客户端&#xff0c;然后等待客户端回复ACK包。但此时客户端并不会回复ACK包&#xff0c;所以服务端就只能一直等待直到超时。服务端超时后会…

麒麟系统—— openKylin 安装到虚拟机以及开放SSH通过工具连接

麒麟系统—— openKylin 安装到虚拟机以及开放SSH通过工具连接 1. 在VMware中安装openKylin麒麟系统步骤1&#xff1a;准备VMware环境步骤2&#xff1a;创建新的虚拟机步骤3&#xff1a;安装openKylin麒麟系统步骤4&#xff1a;调整分别率步骤5&#xff1a;安装SSH 2. 使用Open…

Textual Inversion、DreamBooth、LoRA、InstantID:从低成本进化到零成本实现IP专属的AI绘画模型

2023年7月份国内有一款定制写真AI工具爆火。一款名为妙鸭相机的AI写真小程序&#xff0c;成功在C端消费者群体中出圈&#xff0c;并在微信、微博和小红书等平台迅速走红&#xff0c;小红书上的话题Tag获得了330多万的浏览量&#xff0c;相关微信指数飙升到了1800万以上。 其他…

kali安装LAMP和DVWA

LANMP简介 LANMP是指一组通常用来搭建动态网站或者服务器的开源软件&#xff0c;本身都是各自独立的程序&#xff0c;但是因为常被放在一起使用&#xff0c;拥有了越来越高的兼容度&#xff0c;共同组成了一个强大的Web应用程序平台。 L:指Linux&#xff0c;一类Unix计算机操作…

java基础加强(1)

1.xml 1.1概述【理解】 万维网联盟(W3C) 万维网联盟(W3C)创建于1994年&#xff0c;又称W3C理事会。1994年10月在麻省理工学院计算机科学实验室成立。 建立者&#xff1a; Tim Berners-Lee (蒂姆伯纳斯李)。 是Web技术领域最具权威和影响力的国际中立性技术标准机构。 到目前为…

写Shell以交互方式变更Ubuntu的主机名

以下是一个简单的 Bash 脚本&#xff0c;用于以交互方式更改 Ubuntu 20 系统的主机名&#xff1a; 1#!/bin/bash 2 3# 提示用户输入新的主机名 4read -p "请输入新的系统名称&#xff08;主机名&#xff09;: " new_hostname 5 6# 检查是否输入了新的主机名 7if [ -…

猫用空气净化器哪些好?五款宠物空气净化推荐!

如今&#xff0c;养宠物的家庭越来越多了&#xff01;家里因此变得更加温馨&#xff0c;但同时也会带来一些问题&#xff0c;比如异味和空气中的毛发可能会对健康造成困扰。 为了避免家中弥漫着异味&#xff0c;特别是来自宠物便便的味道&#xff0c;一款能够处理家里异味的宠…

python基础学习-03 安装

python3 可应用于多平台包括 Windows、Linux 和 Mac OS X。 Unix (Solaris, Linux, FreeBSD, AIX, HP/UX, SunOS, IRIX, 等等。)Win 9x/NT/2000Macintosh (Intel, PPC, 68K)OS/2DOS (多个DOS版本)PalmOSNokia 移动手机Windows CEAcorn/RISC OSBeOSAmigaVMS/OpenVMSQNXVxWorksP…

npm install运行报错npm ERR! gyp ERR! not ok问题解决

执行npm install的时候报错&#xff1a; npm ERR! path D:..\node_modules\\**node-sass** npm ERR! command failed ...npm ERR! gyp ERR! node -v v20.11.0 npm ERR! gyp ERR! node-gyp -v v3.8.0 npm ERR! gyp ERR! not ok根据报错信息&#xff0c;看出时node-sass运行出现…

图像处理算法:白平衡、除法器、乘法器~笔记

参考&#xff1a; 基于FPGA的自动白平衡算法的实现 白平衡初探 (qq.com) FPGA自动白平衡实现步骤详解-CSDN博客 xilinx 除法ip核&#xff08;divider&#xff09; 不同模式结果和资源对比&#xff08;VHDL&ISE&#xff09;_ise除法器ip核-CSDN博客 数…

绝地求生:本周三停机维护更新4小时: RASH悲喜套装即将下线!

本周三将迎来停机维护更新四小时~&#xff0c;同时游戏商城内RASH悲喜联名套装即将下线&#xff0c;同时空投签到任务和荣都地图翻牌任务即将下线~ 预计维护时间: 2024年1月24日08:00~12:00 本周地图轮换情况 (1月24日 ~ 1月31日) 可自主选择地图的地区:艾伦格、泰戈、帝斯顿、…

前沿重器[41] | 综述-面向大模型的检索增强生成(RAG)

前沿重器 栏目主要给大家分享各种大厂、顶会的论文和分享&#xff0c;从中抽取关键精华的部分和大家分享&#xff0c;和大家一起把握前沿技术。具体介绍&#xff1a;仓颉专项&#xff1a;飞机大炮我都会&#xff0c;利器心法我还有。&#xff08;算起来&#xff0c;专项启动已经…

JavaScript入门分享

文章目录 一、JavaScript简介二、第一个JavaScript案例三、在浏览器中执行JavaScript代码四、JavaScript的输出方法五、JavaScript的语法六、JavaScript的数据类型七、JavaScript的定义变量/函数八、热门文章 一、JavaScript简介 JavaScript是一种高级编程语言&#xff0c;用于…

假期刷题打卡--Day12

1、MT1128骰子的反面 小码哥抛出一个六面骰子。每个面上都印有一个数字&#xff0c;数字在1到6之间。输入正面的数字&#xff0c;输出对面的数字。 其他情况输出-1。 格式 输入格式&#xff1a; 输入为整型 输出格式&#xff1a; 输出为整型 样例 1 输入&#xff1a; …

【CANoe使用大全】——工程新建

文章目录 1、硬件连接2、通道配置2.1通道协议选择2.2映射通道配置2.3.波特率采样点配置 1、硬件连接 前提条件&#xff1a;软件、驱动均已经安装完成 硬件通过UBS接入电脑&#xff0c;Status状态灯为黄色闪烁状态说明硬件设备与电脑连接正常 2、通道配置 2.1通道协议选择 …

npm i 报一堆版本问题

1&#xff0c;先npm cache clean --force 再下载 插件后缀加上 --legacy-peer-deps 2&#xff0c; npm ERR! code CERT_HAS_EXPIRED npm ERR! errno CERT_HAS_EXPIRED npm ERR! request to https://registry.npm.taobao.org/yorkie/download/yorkie-2.0.0.tgz failed, reason…

Linux之安装配置VCentOS7+换源

目录 一、安装 二、配置 三、安装工具XSHELL 3.1 使用XSHELL连接Linux 四、换源 前言 首先需要安装VMware虚拟机&#xff0c;在虚拟机里进行安装Linux 简介 Linux&#xff0c;一般指GNU/Linux&#xff08;单独的Linux内核并不可直接使用&#xff0c;一般搭配GNU套件&#…

SQL注入示例

例一、基础SQL注入&#xff1a;load_file读文件 CISP-PTE 认证考试 首先是有单引号和括号的&#xff0c;首要是要闭合&#xff0c;然后回显点是在-1的位置&#xff0c;读取文件上面的key的话使用的是load_file(/tmp/360/key) id-1)%09ununionion%09select%091,2,3,load_file…
最新文章