卷积神经网络实现彩色图像分类 - P2

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P2周:彩色识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 包引用
      • 硬件设备
    • 数据准备
      • 数据集下载与加载
      • 数据集预览
      • 数据集准备
    • 模型设计
    • 模型训练
      • 超参数设置
      • helper函数
      • 正式训练
    • 结果呈现
  • 总结与心得体会

上周使用Pytorch构建卷积神经网络,实现了MNIST手写数字的识别,这周的目标是CIFAR10中复杂的彩色图像分类。


环境

  • 系统:Linux
  • 语言: Python 3.8.10
  • 深度学习框架:PyTorch 2.0.0+cu118

步骤

环境设置

包引用

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
from torchinfo import summary # 方便像tensorflow一样打印模型

硬件设备

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据准备

数据集下载与加载

train_dataset = datasets.CIFAR10(root='data', train=True, 
					download=True, transform=transforms.ToTensor()) # 不要忘记这个transform
test_dataset = datasets.CIFAR10(root='data', train=False, 
					download=True, transform=transforms.ToTensor())

数据集预览

image, label = train_dataset[0]
print(image.shape)
plt.figure(figsize=(20,4))
for i in range(20):
	image, label = train_dataset[i]
	plt.subplot(2, 10, i+1)
	plt.imshow(image.numpy().transpose(1,2,0)
	plt.axis('off')
	plt.title(label) # 加载的数据集没有对应的名称,暂时展示它们的id

数据集预览

数据集准备

batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

模型设计

class Model(nn.Module):
	def __init__(self, num_classes):
		super().__init__()
		# 3x3的卷积无padding每次宽高-2
		# 2x2的最大池化,每次宽高缩短为原来的一半
		# 32x32 -> conv1 -> 30x30 -> maxpool -> 15x15
		self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
		# 15x15 -> conv2 -> 13x13 -> maxpool -> 6x6
		self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
		# 6x6 -> conv3 -> 4x4 -> maxpool -> 2x2
		self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
		self.maxpool = nn.MaxPool2d(2),
		self.flatten = nn.Flatten(),
		self.fc1 = nn.Linear(2*2*128, 256)
		self.fc2 = nn.Linear(256, num_classes)

	def forward(self, x):
		x = F.relu(self.conv1(x))
		x = self.maxpool(x)

		x = F.relu(self.conv2(x))
		x = self.maxpool(x)

		x = F.relu(self.conv3(x))
		x = self.maxpool(x)

		x = self.flatten(x)

		x = F.relu(self.fc1(x))
		x = self.fc2(x)
		return x

model = Model(10).to(device)
summary(model, input_size=(1, 3, 32, 32))

模型结构图

模型训练

超参数设置

learning_rate = 1e-2
epochs = 10
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

helper函数

def train(train_loader, model, loss_fn, optimizer):
	size = len(train_loader.dataset)
	num_batches = len(train_loader)

	train_loss, train_acc = 0, 0
	for x, y in train_loader:
		x, y = x.to(device), y.to(device)

		preds = model(x)
		loss = loss_fn(preds, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		train_loss += loss.item()
		train_acc += (preds.argmax(1) == y).type(torch.float).sum().item()

	train_loss /= num_batches
	train_acc /= size

	return train_loss, train_acc

def test(test_loader, model, loss_fn):
	size = len(test_loader.dataset)
	num_batches = len(test_loader)

	test_loss, test_acc = 0, 0
	with torch.no_grad():
		for x, y in test_loader:
			x, y = x.to(device), y.to(device)
			
			preds = model(x)
			loss = loss_fn(preds, y)
			
			test_loss += loss.item()
			test_acc += (preds.argmax(1) == y).type(torch.float).sum().item()

	test_loss /= num_batches
	test_acc /= size

	return test_loss, test_acc

def fit(train_loader, test_loader, model, loss_fn, optimizer, epochs):
	train_loss, train_acc = [], []
	test_loss, test_acc = [], []
	for epoch in range(epochs):
		model.train()
		epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)
		model.eval()
		epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)

		train_loss.append(epoch_train_loss)
		train_acc.append(epoch_train_acc)
		test_loss.append(epoch_test_loss)
		test_acc.append(epoch_test_acc)
	return train_loss, train_acc, test_loss, test_acc

正式训练

train_loss, train_acc, test_loss, test_acc = 
				fit(train_loader, test_loader, model, loss_fn, optimizer, 20)

训练结果

结果呈现

series = range(len(train_loss))
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.plot(series, train_loss, label='train loss')
plt.plot(series, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.subplot(1,2,2)
plt.plot(series, train_acc, label='train accuracy')
plt.plot(series, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

实验结果
从结果图可以发现,模型应该还没收敛,将epoch设置为30,重新跑一遍模型。
实验结果2
可以看出20个epoch后,训练集上的正确率持续增长,在验证集上的正确率几乎就不再增长了,符合过拟合的特征。需要对模型进行改进才能提升正确率了。


总结与心得体会

通过本周的学习,掌握了使用pytorch编写一个完整深度学习的过程,包括环境的配置、数据的准备、模型定义与训练、结果分析呈现等步骤,并且掌握了通过pytorch的API组建一个简单的卷积神经网络的过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/71438.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【考研复习】24王道数据结构课后习题代码|第3章栈与队列

文章目录 3.1 栈3.2 队列3.3 栈和队列的应用 3.1 栈 int symmetry(linklist L,int n){char s[n/2];lnode *pL->next;int i;for(i0;i<n/2;i){s[i]p->data;pp->next;}i--;if(n%21) pp->next;while(p&&s[i]p->data){i--;pp->next;}if(i-1) return 1;…

Vue实现图片懒加载

图片懒加载&#xff08;Lazy Loading&#xff09;是一种前端优化技术&#xff0c;旨在改善网页加载性能和用户体验。它的基本原理是&#xff0c;将网页中的图片延迟加载&#xff0c;只有当用户滚动到图片所在的位置时&#xff0c;才会加载图片内容&#xff0c;从而减少初始页面…

QT生成Word PDF文档

需求&#xff1a;将软件处理的结果保存为一个报告文档&#xff0c;文档中包含表格、图片、文字&#xff0c;格式为word的.doc和.pdf。生成word是为了便于用户编辑。 开发环境&#xff1a;qt4.8.4vs2010 在qt的官网上对于pdf的操作介绍如下&#xff1a;http://qt-project.org/…

易服客工作室:Elementor AI简介 – 彻底改变您创建网站的方式

Elementor 作为领先的 WordPress 网站构建器&#xff0c;是第一个添加本机 AI 集成的。Elementor AI 的第一阶段将使您能够生成和改进文本和自定义代码&#xff08;HTML、自定义代码和自定义 CSS&#xff09;。我们还已经在进行以下阶段的工作&#xff0c;其中将包括基于人工智…

2023年游戏买量能怎么玩?

疫情过后&#xff0c;一地鸡毛。游戏行业的日子也不好过。来看看移动游戏收入&#xff1a;2022年&#xff0c;移动游戏收入达到920亿美元&#xff0c;同比下降6.4%。这告诉我们&#xff0c;2022年对移动游戏市场来说是一个小挫折。 但不管是下挫还是上升&#xff0c;移动游戏市…

pytest fixture 用于teardown工作

fixture通过scope参数控制setup级别&#xff0c;setup作为用例之前前的操作&#xff0c;用例执行完之后那肯定也有teardown操作。这里用到fixture的teardown操作并不是独立的函数&#xff0c;用yield关键字呼唤teardown操作。 举个例子&#xff1a; 输出&#xff1a; 说明&…

剑指 Offer 61. 扑克牌中的顺子

题目描述 从若干副扑克牌中随机抽 5 张牌&#xff0c;判断是不是一个顺子&#xff0c;即这5张牌是不是连续的。2&#xff5e;10为数字本身&#xff0c;A为1&#xff0c;J为11&#xff0c;Q为12&#xff0c;K为13&#xff0c;而大、小王为 0 &#xff0c;可以看成任意数字。A 不…

通讯协议038——全网独有的OPC HDA知识一之聚合(七)最小值

本文简单介绍OPC HDA规范的基本概念&#xff0c;更多通信资源请登录网信智汇(wangxinzhihui.com)。 本节旨在详细说明HDA聚合的要求和性能。其目的是使HDA聚合标准化&#xff0c;以便HDA客户端能够可靠地预测聚合计算的结果并理解其含义。如果用户需要聚合中的自定义功能&…

程序使用Microsoft.XMLHTTP对象请求https时出错解决

程序中使用Microsoft.XMLHTTP组件请求https时出现如下错误&#xff1a; 出错程序代码示例&#xff1a; strUrl "https://www.xxx.com/xxx.asp?id11" dim objXmlHttp set objXmlHttp Server.CreateObject("Microsoft.XMLHTTP") objXmlHttp.open "…

SpringCloudGateway配置跨域设置以及如何本地测试跨域

问题背景 有个服务A &#xff0c;自身对外提供服务&#xff0c;几个系统的前端页面也在调用&#xff0c;使用springboot 2.6.8开发的&#xff0c;自身因为有前端直接调用已经配置了跨域。 现在有网关服务&#xff0c;一部分前端通过网关访问服务A&#xff08;因为之前没有网关…

springboot 基础

巩固基础&#xff0c;砥砺前行 。 只有不断重复&#xff0c;才能做到超越自己。 能坚持把简单的事情做到极致&#xff0c;也是不容易的。 SpringBoot JavaEE 简介 JavaEE的局限性&#xff1a; 1、过于复杂&#xff0c;JavaEE正对的是复杂的分布式企业应用&#xff0c;然而现实…

桥接模式-java实现

桥接模式 桥接模式的本质&#xff0c;是解决一个基类&#xff0c;存在多个扩展维度的的问题。 比如一个图形基类&#xff0c;从颜色方面扩展和从形状上扩展&#xff0c;我们都需要这两个维度进行扩展&#xff0c;这就意味着&#xff0c;我们需要创建一个图形子类的同时&#x…

软件测试基础篇——MySQL

MySQL 1、数据库技术概述 数据库database&#xff1a;存放和管理各种数据的仓库&#xff0c;操作的对象主要是【数据data】&#xff0c;科学的组织和存储数据&#xff0c;高效的获取和处理数据SQL&#xff1a;结构化查询语言&#xff0c;专为**关系型数据库而建立的操作语言&…

科技云报道:一波未平一波又起?AI大模型再出邪恶攻击工具

AI大模型的快速向前奔跑&#xff0c;让我们见识到了AI的无限可能&#xff0c;但也展示了AI在虚假信息、深度伪造和网络攻击方面的潜在威胁。 据安全分析平台Netenrich报道&#xff0c;近日&#xff0c;一款名为FraudGPT的AI工具近期在暗网上流通&#xff0c;并被犯罪分子用于编…

基于微信小程序的应届大学生招聘平台的设计与实现

伴随着社会以及科学技术的发展&#xff0c;互联网已经渗透在人们的身边&#xff0c;网络慢慢的变成了人们的生活必不可少的一部分&#xff0c;紧接着众多智能手机飞速的发展&#xff0c;小程序这一名词已不陌生&#xff0c;越来越多的企业、公司、高校、医院等机构都会使用小程…

C++学习笔记——从面试题出发学习C++

C学习笔记——从面试题出发学习C C学习笔记——从面试题出发学习C1. 成员函数的重写、重载和隐藏的区别&#xff1f;2. 构造函数可以是虚函数吗&#xff1f;内联函数可以是虚函数吗&#xff1f;析构函数为什么一定要是虚函数&#xff1f;3. 解释左值/右值、左值/右值引用、std:…

Clickhouse学习系列——一条SQL完成gourp by分组与不分组数值计算

笔者在近一两年接触了Clickhouse数据库&#xff0c;在项目中也进行了一些实践&#xff0c;但一直都没有一些技术文章的沉淀&#xff0c;近期打算做个系列&#xff0c;通过一些具体的场景将Clickhouse的用法进行沉淀和分享&#xff0c;供大家参考。 首先我们假设一个Clickhouse数…

EXPLAIN使用分析

系列文章目录 文章目录 系列文章目录一、type说明二、MySQL中使用Show Profile1.查看当前profiling配置2.在会话级别修改profiling配置3.查看profile记录4.要深入查看某条查询执行时间的分布 一、type说明 我们只需要注意一个最重要的type 的信息很明显的提现是否用到索引&…

单参数构造函数的隐式类型转化

单参数构造函数的隐式类型转化 如果你不想发生隐式类型的转化&#xff0c;可以在默认构造函数前加上关键字&#xff1a;explicit 多参数的玩法和单参数的是不一样的 c98 不支持多参数隐式类型的转化 c11 支持多参数隐式类型的转化 举个例子&#xff1a; 多参数可以这样写&…

考虑分布式电源的配电网无功优化问题研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…
最新文章