PyTorch: 基于VGG16处理MNIST数据集的图像分类任务

引言

在本博客中,小编将向大家介绍如何使用VGG16处理MNIST数据集的图像分类任务。MNIST数据集是一个常用的手写数字分类数据集,包含60,000个训练样本和10,000个测试样本。我们将使用Python编程语言和PyTorch深度学习框架来实现这个任务。

在Conda虚拟环境下安装pytorch

# CUDA 11.6
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
# CUDA 11.3
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# CUDA 10.2
pip install torch==1.12.1+cu102 torchvision==0.13.1+cu102 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu102
# CPU only
pip install torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cpu

步骤一:利用代码自动下载mnist数据集

import torchvision.datasets as datasets  
import torchvision.transforms as transforms  
  
# 定义数据预处理操作  
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为(224, 224)
    transforms.ToTensor(),  # 将图像转换为PyTorch张量
    transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])
  
# 下载并加载MNIST数据集  
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)  
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

步骤二:搭建基于VGG16的图像分类模型

class VGGClassifier(nn.Module):
    def __init__(self, num_classes):
        super(VGGClassifier, self).__init__()
        self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器
        # 重构VGG16网络的第一层卷积层,适配mnist数据的灰度图像格式
        self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096
            nn.ReLU(True),
            nn.Dropout(), # 随机将一些神经元“关闭”,这样可以有效地防止过拟合。
            nn.Linear(4096, 4096),  # 添加一个全连接层,输入和输出维度都为4096
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10)
        )
        self._initialize_weights()  # 初始化权重参数

    def forward(self, x):
        x = self.features(x)  # 通过特征提取器提取特征
        x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量
        x = self.classifier(x)  # 通过分类器进行分类预测
        return x

    def _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

步骤三:训练模型

import torch.optim as optim  
from torch.utils.data import DataLoader  
  
# 定义超参数和训练参数  
batch_size = 64  # 批处理大小  
num_epochs = 5  # 训练轮数
learning_rate = 0.01  # 学习率
num_classes = 10  # 类别数(MNIST数据集有10个类别)  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用GPU进行训练,否则使用CPU。

# 定义训练集和测试集的数据加载器  
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)  
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)  
  
# 初始化模型和优化器  
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)  
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数  
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)  
  
# 训练模型  
for epoch in range(num_epochs):  
    for i, (images, labels) in enumerate(train_loader):  
        images = images.to(device)  # 将图像数据移动到指定设备  
        labels = labels.to(device)  # 将标签数据移动到指定设备  
          
        # 前向传播  
        outputs = model(images)  
        loss = criterion(outputs, labels)  
          
        # 反向传播和优化  
        optimizer.zero_grad()  # 清空梯度缓存  
        loss.backward()  # 计算梯度  
        optimizer.step()  # 更新权重参数  
          
        if (i+1) % 100 == 0:  # 每100个batch打印一次训练信息  
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))  
              
# 保存模型参数  
torch.save(model.state_dict(), './model.pth')

步骤四:测试模型

# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作

# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量

# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间
    for images, labels in test_loader:
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备
        outputs = model(images)  # 模型前向传播,得到预测结果
        _, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别
        total += labels.size(0)  # 更新总样本数量
        correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量

# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

运行结果

在这里插入图片描述

后续模型的优化和改进建议

  1. 数据增强:通过旋转、缩放、平移等方式来增加训练数据,从而让模型拥有更好的泛化能力。
  2. 调整模型参数:可以尝试调整模型的参数,比如学习率、批次大小、迭代次数等,来提高模型的性能。
  3. 更换网络结构:可以尝试使用更深的网络结构,如ResNet、DenseNet等,来提高模型的性能。
  4. 调整优化器:本次代码采用SGD优化器,但仍可以尝试使用不同的优化器,如Adam、RMSprop等,来找到最适合我们模型的优化器。
  5. 添加正则化操作:为了防止过拟合,可以添加一些正则化项,如L1正则化、L2正则化等。
  6. 代码目前只有等训练完全结束后才能进入测试阶段,后续可以在每个epoch结束,甚至是指定的迭代次数完成后便进入测试阶段。因为训练完全结束的模型很可能已经过拟合,在测试集上不能表现较强的泛化能力。

完整代码

import torch
import torch.nn as nn

import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import warnings
warnings.filterwarnings("ignore")

# 定义数据预处理操作
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为(224, 224)
    transforms.ToTensor(),  # 将图像转换为PyTorch张量
    transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])

# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)


class VGGClassifier(nn.Module):
    def __init__(self, num_classes):
        super(VGGClassifier, self).__init__()
        self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器
        # 重构网络的第一层卷积层,适配mnist数据的灰度图像格式
        self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096
            nn.ReLU(True),
            nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。
            nn.Linear(4096, 4096),  # 添加一个全连接层,输入和输出维度都为4096
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10)
        )
        self._initialize_weights()  # 初始化权重参数

    def forward(self, x):
        x = self.features(x)  # 通过特征提取器提取特征
        x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量
        x = self.classifier(x)  # 通过分类器进行分类预测
        return x

    def _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)



# 定义超参数和训练参数
batch_size = 64  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.01  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。

# 定义数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 初始化模型和优化器
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)

# 训练模型
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()  # 清空梯度缓存
        loss.backward()  # 计算梯度
        optimizer.step()  # 更新权重参数

        if (i + 1) % 100 == 0:  # 每100个batch打印一次训练信息
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader),
                                                                     loss.item()))

# 训练结束,保存模型参数
torch.save(model.state_dict(), './model.pth')

# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作

# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量

# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间
    for images, labels in test_loader:
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备
        outputs = model(images)  # 模型前向传播,得到预测结果
        _, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别
        total += labels.size(0)  # 更新总样本数量
        correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量

# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

结束语

如果本博文对你有所帮助/启发,可以点个赞/收藏支持一下,如果能够持续关注,小编感激不尽~
如果有相关需求/问题需要小编帮助,欢迎私信~
小编会坚持创作,持续优化博文质量,给读者带来更好de阅读体验~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/231502.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【剑指offer|图解|数组】寻找文件副本 + 螺旋遍历二维数组

🌈个人主页:聆风吟 🔥系列专栏:数据结构、剑指offer每日一练 🔖少年有梦不应止于心动,更要付诸行动。 文章目录 一. ⛳️寻找文件副本(题目难度:简单)1.1 题目1.2 示例1.3 限制1.4 解题思路一c代…

【链表OJ—分割链表】

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 前言 1、分割链表题目: 方法讲解: 图文解析: 代码实现: 总结 前言 世上有两种耀眼的光芒,一种是正在升起的太…

坚鹏:中国工商银行浙江大学金融业务转型与场景营销策略培训

中国工商银行打造D-ICBC数字化转型战略,围绕“数字生态、数字资产、数字技术、数字基建、数字基因”五维布局,深入推进数字化转型,加快形成体系化、生态化实施路径,促进科技与业务加速融合,以“数字工行”建设推动“GB…

【父子进程/AES/XTEA/SMC】赛后复盘

官方wp: 进程重影技术: 进程重映像利用了Windows内核中的缓存同步问题,它会导致可执行文件的路径与从该可执行文件创建的映像节区所报告的路径不匹配。通过在一个诱饵路径上加载DLL,然 后卸载它,然后从一个新路径加载它&#x…

SQL语言重温

数据库语言重温 笔记背景SQL教程一些最重要的 SQL 命令SQL WHERE 子句SQL AND & OR 运算符SQL ORDER BY 关键字 笔记背景 由于工作需要,现重温简单SQL语言,笔记记录如下。 SQL教程 SQL(Structured Query Language:结构化查询语言&…

网络安全渗透测试

针对网络的渗透测试项目一般包括:信息收集、端口扫描、指纹识别、漏洞扫描、绘制网络拓扑、识别代理、记录结果等。下面就一一介绍。 信息收集 DNS dns信息包含(A, MX, NS, SRV, PTR, SOA, CNAME) 记录,了解不同记录的含义至关重要。 A 记录列出特定…

【Java】构建表达式二叉树和表达式二叉树求值

问题背景 1. 实现一个简单的计算器。通过键盘输入一个包含圆括号、加减乘除等符号组成的算术表达式字符串,输出该算术表达式的值。要求: (1)系统至少能实现加、减、乘、除等运算; (2)利用二叉…

C++ 运算符重载与操作符重载

目录 运算符重载 运算符重载的特性 其他运算符重载的实现 默认成员函数——赋值运算符重载 默认成员函数——取地址操作符重载 const成员 附录 运算符重载 C为了增强代码的可读性引入了运算符重载,运算符重载是具有特殊函数名的函数,也具有其返回…

搞懂内存函数

引言 本文介绍memcpy的使用和模拟实现、memmove的使用和模拟实现、memcmp使用、memset使用 ✨ 猪巴戒:个人主页✨ 所属专栏:《C语言进阶》 🎈跟着猪巴戒,一起学习C语言🎈 目录 引言 memcpy memcpy的使用 memcpy的…

智能统计账户支出,掌控财务状况,轻松修改明细。

在这个快节奏的时代,我们的生活每天都在发生着变化。无论是工资收入、购物消费,还是房租支出、投资理财,我们的财务状况也因此变得日益复杂。那么,有没有一种方法可以让我们轻松掌握自己的财务状况,实现智慧理财呢&…

面向AOP(2)spring

我是南城余!阿里云开发者平台专家博士证书获得者! 欢迎关注我的博客!一同成长! 一名从事运维开发的worker,记录分享学习。 专注于AI,运维开发,windows Linux 系统领域的分享! 本…

VUE语法--toRefs与toRef用法

1、功能概述 ref和reactive能够定义响应式的数据,当我们通过reactive定义了一个对象或者数组数据的时候,如果我们只希望这个对象或者数组中指定的数据响应,其他的不响应。这个时候我们就可以使用toRefs和toRef实现局部数据的响应。 toRefs是…

c语言选择排序总结(详解)

选择排序cpp文件项目结构截图 项目cpp文件截图 项目具体代码截图 #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> #include <stdlib.h> #include <math.h> #include <iostream> #include <string.h> #include <time.h> #include &…

构建外卖系统:使用Django框架

在当今数字化的时代&#xff0c;外卖系统的搭建不再是什么复杂的任务。通过使用Django框架&#xff0c;我们可以迅速建立一个强大、灵活且易于扩展的外卖系统。本文将演示如何使用Django构建一个简单的外卖系统&#xff0c;并包含一些基本的技术代码。 步骤一&#xff1a;安装…

unity 2d 入门 飞翔小鸟 Cinemachine 记录分数(十二)

1、创建文本 右键->create->ui->leagcy->text 2、设置字体 3、设置默认值和数字 4、当切换分辨率&#xff0c;分数不见问题 拖拽这里调整 调整到如下图 5、编写得分脚本 using System.Collections; using System.Collections.Generic; using UnityEngine; …

机器人学习目标

学习目标&#xff1a; 若干年后&#xff0c;我们都将化为尘土&#xff0c;无人铭记我们的存在。那么&#xff0c;何不趁现在&#xff0c;尽己所能&#xff0c;在这个世界上留下一些痕迹&#xff0c;让未来的时光里&#xff0c;仍有人能感知到我们的存在。 机器人协会每届每个阶…

文件格式对齐、自定义快捷键、idea

文件格式对齐 Shift Alt F 自动格式化代码的快捷键&#xff08;如何配置自动格式化&#xff09; 日常编码必备idea快捷键 [VS Code] 入门-自定键盘快捷键 文件格式对齐 文件格式对齐通常是通过编辑器或IDE提供的快捷键或命令完成的。以下是一些常见编辑器和IDE中进行文件…

线边仓到底谁来管比较好

最近这段时间在客户现场出差&#xff0c;和客户聊到系统的边界时&#xff0c;客户IT希望将线边仓也纳入WMS进行管理。 给出的理由是WMS是管理实物的&#xff0c;线边仓也有实物存放&#xff0c;理所当然应该让WMS进行管理。 那线边仓能在WMS管理吗&#x…

12 RT1052的GPIO输入

文章目录 12.1 GPIO输入硬件12.1.1 GPIO初始化 12.1 GPIO输入硬件 RST 复位按键 连接至 RT1052 的 POR_B 引脚&#xff0c;当该引脚为低电平时会引起 RT1052芯片的复位 WAUP 按键 该按键在没有被按下的时候&#xff0c;引脚状态为高电平&#xff0c;当按键按下时&#xff0…

msvcr90.dll丢失的解决方法分享,5个快速修复dll文件丢失教程

在今天的电脑使用过程中&#xff0c;我们可能会遇到各种各样的问题。其中之一就是msvcr90.dll丢失的问题。那么&#xff0c;msvcr90.dll是什么&#xff1f;msvcr90.dll丢失对电脑有什么影响&#xff1f;又该如何解决这个问题呢&#xff1f;接下来&#xff0c;我将为大家详细介绍…
最新文章