基于vgg16进行迁移学习服装分类

pytorch深度学习项目实战100例 的学习记录

我的环境:

白票大王: google colab

用其他的话,其实实现也行,但是让小白来重环境来开始安装的话,浪费时间

数据集

Clothing dataset
20 个不同类别的 5000 多张图片。

在这里插入图片描述

该数据集可免费用于任何目的,包括商业用途:

例如

创建教程或课程(免费或付费)
撰写书籍
Kaggle 竞赛(作为外部数据集)
训练任何公司的内部模型
数据
images.csv 文件包含

image - 图像的 ID(用它从 images/.jpg 中加载图像)
sender_id - 提供图片者的 ID
label - 图片的类别
kids - 标记,如果是儿童服装,则为 True

数据集下载

VGG网络,全称为Visual Geometry Group Network

是由牛津大学的Visual Geometry Group提出的一种深度学习网络架构。该网络在2014年的ImageNet挑战赛中取得了优异的成绩,主要以其简洁的架构和出色的性能著称。VGG网络的引入对于深度学习和计算机视觉领域的发展产生了重要影响,以下是VGG网络的一些主要特点和网络结构的详细介绍:
在这里插入图片描述

特点

  • 统一的卷积核尺寸:VGG网络使用了统一的
    3×3的卷积核,步长为1。这种小尺寸卷积核的使用是VGG网络的一个标志性特点,它允许网络通过堆叠多层小卷积核来增加网络深度,同时保持计算效率。

  • 增加网络深度:VGG网络通过重复堆叠卷积层和池化层来增加网络深度,最著名的版本是VGG-16和VGG-19,分别含有16层和19层权重层。这种深度的增加显著提高了网络的性能。

  • 使用ReLU激活函数:VGG网络在每个卷积层后使用了ReLU(Rectified Linear Unit)作为激活函数,以增加非线性特性并加速训练过程。

  • 池化层的使用:VGG网络在连续的几个卷积层之后使用最大池化层来进行特征下采样,这有助于减少计算量和防止过拟合。

网络结构

VGG网络的基本单元是连续的卷积层,后面跟着ReLU激活函数,然后是可选的池化层。这一基本单元重复多次,形成了VGG网络的主体。以下是VGG-16网络的一个典型结构:

  • 输入层:接受固定尺寸的图像,如224×224的RGB图像。

  • 卷积层:多个3×3卷积层,每个卷积层后面跟着ReLU激活函数。在初期层次中,卷积层的数量较少(通常是2-3层),在更深的层次中数量增多。

  • 池化层:每几个卷积层后会跟一个2×2的最大池化层,用于下采样。

  • 全连接层:网络末端有三个全连接层,前两个全连接层各有4096个节点,最后一个全连接层用于分类,节点数与分类任务的类别数相等。

  • 输出层:最后是一个softmax层,用于输出分类的概率分布。

VGG网络的这种结构设计简洁而有效,尤其是在处理图像分类和识别任务时表现出色。

vgg 迁移学习

首先,利用torchvision中的models,通过models调用vgg16模型,如果参数pretrainedtrue,就会自动下载

model = models.vgg16(pretrained=True)

查看网络结构

model._modules

在这里插入图片描述

VGG模型由于其出色的特征提取能力,常被用于迁移学习中,特别是在图像识别和分类任务上。迁移学习允许我们利用在大型数据集(如ImageNet)上预训练的模型,将其应用于数据量较小的特定任务上,通常可以显著提高性能。但是ImageNet里面有1000类别,我们数据集只有19个类

model.classifier[6]

在这里插入图片描述

model.classifier[6] = nn.Linear(in_features=4096, out_features=19)

阿光的代码,完整代码

import json
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
import torchvision.models as models
from tqdm import tqdm


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

data_transform = {
    'train': transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

image_path = './images'

train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'),
                                     transform=data_transform['train'])

batch_size = 32

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size,
                                           True)

print('using {} images for training.'.format(len(train_dataset)))

cloth_list = train_dataset.class_to_idx
class_dict = {}
for key, val in cloth_list.items():
    class_dict[val] = key
with open('class_dict.pk', 'wb') as f:
    pickle.dump(class_dict, f)

model_path = './checkpoints/vgg16-397923af.pth'

model = models.vgg16(pretrained=False)
model.load_state_dict(torch.load(model_path, 'cpu'))

for parma in model.parameters():  # 设置自动梯度为false
    parma.requires_grad = False

model.classifier[6] = nn.Linear(in_features=4096, out_features=19)

model.to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.03)

epochs = 10
best_acc = 0
best_model = None
save_path = './checkpoints/best_model.pkl'

for epoch in range(epochs):
    model.train()
    running_loss = 0
    epoch_acc = 0 # 每个epoch的准确率
    epoch_acc_count = 0 # 每个epoch训练的样本数
    train_count = 0 # 用于计算总的样本数,方便求准确率
    train_bar = tqdm(train_loader)
    for data in train_bar:
        images, labels = data
        optimizer.zero_grad()
        output = model(images.to(device))
        loss = loss_function(output, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                 epochs,
                                                                 loss)
        # 计算每个epoch正确的个数
        epoch_acc_count += (output.argmax(axis=1) == labels.view(-1)).sum()
        train_count += len(images)
        
    # 每个epoch对应的准确率
    epoch_acc = epoch_acc_count / train_count
    
    # 打印信息
    print("【EPOCH: 】%s" % str(epoch + 1))
    print("训练损失为%s" % str(running_loss))
    print("训练精度为%s" % (str(epoch_acc.item() * 100)[:5]) + '%')
    
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        best_model = model.state_dict()
        
     # 在训练结束保存最优的模型参数
    if epoch == epochs - 1:
        # 保存模型
        torch.save(best_model, save_path)
        
print('Finished Training')

with open('class_dict.pk', 'rb') as f:
    class_dict = pickle.load(f)
    
data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

img_path = r'./images/test/0f2ac7b8-c769-4f0b-a915-5680d0dad34c.jpg'

img = Image.open(img_path)

img = data_transform(img)

img = torch.unsqueeze(img, dim=0)

pred = class_dict[model(img).argmax(axis=1).item()]
print('【预测结果分类】:%s' % pred)

没有下载他弄好的数据集,当然不行拉阿 TT

然后我去直接下载github的

发现了skip里面的图片有几张是错误的

只能跳过

import pandas as pd
# CSV文件路径
csv_file_path = '/content/clothing-dataset/images.csv'
# 图片所在目录
images_directory = '/content/clothing-dataset/images/'

# 读取CSV文件
df = pd.read_csv(csv_file_path)

# 过滤掉标签为'skip'的行
df_filtered = df[df['label'] != 'Skip']

# print(df_filtered)

# 找出所有应该被删除的图片文件
images_to_delete = df[df['label'] == 'Skip']['image']

# images_to_delete

for image_id in images_to_delete:
  image_path  = os.path.join(images_directory,f"{image_id}.jpg")
  if os.path.exists(image_path):
    os.remove(image_path)
    print(f"Deleted{image_path}")

# 保存更新后的DataFrame到新的CSV文件
df_filtered.to_csv('/content/clothing-dataset/updated_file.csv', index=False)  # 指定新的CSV文件路径
# 再次加载更新后的CSV文件进行验证
df_updated = pd.read_csv('/content/clothing-dataset/updated_file.csv')
print("含'skip'标签的行数:", df_updated[df_updated['label'] == 'Skip'].shape[0])

优化器


loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.03)

训练和保存模型

model.to(device)

epochs = 10
best_acc = 0
best_model = None
save_path = './checkpoints/best_model.pkl'

for epoch in range(epochs):
    model.train()
    running_loss = 0
    epoch_acc = 0 # 每个epoch的准确率
    epoch_acc_count = 0 # 每个epoch训练的样本数
    train_count = 0 # 用于计算总的样本数,方便求准确率
    train_bar = tqdm(train_loader)
    for data in train_bar:
        images, labels = data
        optimizer.zero_grad()
        output = model(images.to(device))
        loss = loss_function(output, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                 epochs,
                                                                 loss)
        # 计算每个epoch正确的个数
        epoch_acc_count += (output.argmax(axis=1) == labels.view(-1)).sum()
        train_count += len(images)
        
    # 每个epoch对应的准确率
    epoch_acc = epoch_acc_count / train_count
    
    # 打印信息
    print("【EPOCH: 】%s" % str(epoch + 1))
    print("训练损失为%s" % str(running_loss))
    print("训练精度为%s" % (str(epoch_acc.item() * 100)[:5]) + '%')
    
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        best_model = model.state_dict()
        
     # 在训练结束保存最优的模型参数
    if epoch == epochs - 1:
        # 保存模型
        torch.save(best_model, save_path)
        
print('Finished Training')
with open('class_dict.pk', 'rb') as f:
    class_dict = pickle.load(f)
    
data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

img_path = r'./images/test/0f2ac7b8-c769-4f0b-a915-5680d0dad34c.jpg'

img = Image.open(img_path)

img = data_transform(img)

img = torch.unsqueeze(img, dim=0)

pred = class_dict[model(img).argmax(axis=1).item()]
print('【预测结果分类】:%s' % pred)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/433046.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于springboot+vue实现电子商务平台管理系统项目【项目源码+论文说明】

基于springboot实现电子商务平台管理系统演示 研究的目的和意义 据我国IT行业发布的报告表明,近年来,我国互联网发展呈快速增长趋势,网民的数量已达8700万,逼近世界第一,并且随着宽带的实施及降价,每天约有…

【机器学习】包裹式特征选择之递归特征消除法

🎈个人主页:豌豆射手^ 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:机器学习 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进…

基于Arduino的智能寻迹小车设计

目 录 摘 要 Ⅰ Abstract Ⅱ 引 言 1 1系统方案设计 3 1.1 方案论证 3 1.2 项目的总体设计 4 2 项目硬件设计 6 2.1 Arduino平台简介 6 2.2 ATmega328P单片机的最小系统 8 2.3 寻迹模块的设计 9 2.4 驱动模块的设计 11 2.5 电源模块的设计 14 2.6 按键电路的设计 15 2.7 蜂鸣器…

c++|内存管理

c|内存管理 C/C内存分布strlen 和 sizeof的区别 c语言动态内存管理方式malloccallocrealloc例题 c管理方式new/delete操作内置类型new/delete操作自定义类型证明 new 和 delete 的底层原理operator new与operator delete函数operator new 和 operator delete的 用法构造函数里面…

独家揭秘:AI大模型的神秘面纱

AI大模型,是当下人工智能领域里备受瞩目的技术,在推动科技进步和社会发展方面发挥着重要作用。然而,AI大模型的神秘面纱始终让人们充满好奇和探究。 首先,让我们来揭开AI大模型的面纱。在人工智能领域中,大模型是指参…

Idea 开启热部署 Devtools

一、背景 当我们在 idea 中修改代码的时候,idea 并不会自动的重启去响应我们修改的内容,而是需要我们手动的重新启动项目才可以生效,这个是非常不方便,但是可以在 idea 中开启这个自动热部署的功能。 我的 idea 版本为 2022.3.3 。…

Mosquitto介绍

一、Mosquitto介绍 Eclipse Mosquitto是一个开源的MQTT消息代理(服务器)软件。提供轻量级的,支持可发布/可订阅的的消息推送模式,使设备对设备之间的短消息通信变得简单,比如现在应用广泛的低功耗传感器,手…

怎么将电脑excel文档内的数据转换为图片形式

你平时在办公室会遇到格式转换的问题吗?比如PDF转Word,WPS转PDF,PDF转TXT,图片转PDF等。边肖最近在工作过程中遇到了类似的问题。为了更方便的查看表格,需要将Excel表格转换成图片格式。遇到这样的问题,很多…

Excel小技巧 (2) - 如何去除和增加前导0

1. 如何去除前导0 公式:SUBSTITUTE(A2,0,""),然后拖动十字架,同步所有列数据,轻松搞定。 2. 如何补充前导0 公式:TEXT(D2,"0000000") ,0的个数是数字的完整位数。然后拖动十字架&a…

LiveNVR监控流媒体Onvif/RTSP功能-视频广场点击在线或离线时展示状态记录快速查看通道离线原因

LiveNVR视频广场点击在线或离线时展示状态记录快速查看通道离线原因 1、状态记录1.1、点击在线查看1.2、点击离线查看 2、RTSP/HLS/FLV/RTMP拉流Onvif流媒体服务 1、状态记录 1.1、点击在线查看 可以点击视频广场页面中, 在线 两个字查看状态记录 1.2、点击离线查…

解决Windows自定义快捷键打开快捷方式慢的问题

主要是微软拼音的自学习在捣鬼。 关闭自学习即可。

免费IP地址证书

IP地址证书,又称为IP证书或IP地址所有权证书,是一种证明特定IP地址归属和合法使用的电子凭证。它通常由权威机构颁发,如互联网地址分配机构(IANA)或其下属的区域互联网注册管理机构(RIRs)。IP地…

MySQL 元数据锁及问题排查(Metadata Locks MDL)

"元数据"是用来描述数据对象定义的,而元数据锁(Metadata Lock MDL)即是加在这些定义上。通常我们认为非锁定一致性读(简单select)是不加锁的,这个是基于表内数据层面,其依然会对表的元…

第106讲:Mycat实践指南:范围分片下的水平分表详解

文章目录 1.Mycat水平拆分的分片规则2. Mycat水平拆分之范围分片2.1.使用范围分片水平分表的背景2.2.水平分表范围分片案例2.3.准备测试的表结构2.4.配置Mycat实现范围分片的水平分表2.4.1.配置Schema配置文件2.4.2.配置Rule分片规则配置文件2.4.3.配置Server配置文件2.4.4.重启…

高级语言讲义2018计专(仅高级语言部分)

1.编写完整程序解决中国古代数学家张丘健在他的《算经》中提出的”百钱百鸡问题“:鸡翁一,值钱五;鸡母一,值钱三;鸡雏三,值钱一;百钱买百鸡,翁,母,雏各几何 …

x6.js 流程图绘制笔记,常用函数

官方参考网站如下:https://antv-x6.gitee.io/zh/docs/tutorial/about 安装x6 输入以下命令 npm install antv/x6 --save 引用插件代码如下: import { Graph } from antv/x6; 创建绘制区域 this.guiX6 new Graph({container: document.querySelect…

个人社区 项目测试

目 录 一.背景及介绍二.功能详情三.手动测试1.编写测试用例2.测试 一.背景及介绍 该项目采用了前后端分离技术,把我们的数据保存到数据库中,操作对象是用户和个人文章编辑保存,前端的页面实现了登录,列表,编辑&#x…

基于单片机的蓝牙无线密码锁设计

目 录 摘 要 Ⅰ Abstract Ⅱ 引 言 1 1 系统总体设计 3 1.1 系统设计要求 3 1.2 系统设计思路 3 2 系统硬件设计 5 2.1 设计原理 5 2.2 主控模块 5 2.3 芯片模块 8 2.4 矩阵键盘模块 9 2.5 液晶显示模块 10 2.6 继电器驱动模块 12 2.7 蜂鸣器模块 13 2.8 蓝牙模块 14 3 系统软…

鸿蒙4.0-DevEco Studio界面工程

DevEco Studio界面工程 DevEco Studio 下载与第一个工程新建的第一个工程界面回到Project工程结构来看 DevEco Studio 下载与第一个工程 DevEco Studio 下载地址: https://developer.harmonyos.com/cn/develop/deveco-studio#download 学习课堂以及文档地址&#x…

Docker 快速入门实操教程(完结)

Docker 快速入门实操教程(完结) Docker,启动! 如果安装好Docker不知道怎么使用,不理解各个名词的概念,不太了解各个功能的用途,这篇文章应该会对你有帮助。 前置条件:已经安装Doc…
最新文章