使用 pytorch训练自己的图片分类模型

如何自己训练一个图片分类模型,如果一切从头开始,对于一般公司或个人基本是难以实现的。其实,我们可以利用一个现有的图片分类模型,加上新的分类,这种方式叫做迁移学习,就是把现有的模式知识,转移到新的模型。Pytorch 官网提供已经训练好的模型,可以在此基础上训练自己的模型。我们用的模型是 VGG 分类模型,首先,先运行一个已经训练好的模型可做 1000 个分类。

安装依赖

# 去官网根据系统进行下载
pip3 install torch torchvision torchaudio
pip3 install tqdm

现有模型进行图片识别

可以去百度上下载一个狗或者鸟的图片,运行下面的程序进行识别。

# 导入软件包
import numpy as np
import json
from PIL import Image

import torch
import torchvision
from torchvision import models, transforms

#生成VGG-16模型的实例
use_pretrained = True  # 使用已经训练好的参数
net = models.vgg16(pretrained=use_pretrained)
net.eval()  # 设置为推测模式

# 对输入图片进行预处理的类
class BaseTransform():
    """
    调整图片的尺寸,并对颜色进行规范化。

    Attributes
    ----------
    resize : int
       指定调整尺寸后图片的大小
    mean : (R, G, B)
       各个颜色通道的平均值
    std : (R, G, B)
       各个颜色通道的标准偏差
    """

    def __init__(self, resize, mean, std):
        self.base_transform = transforms.Compose([
            transforms.Resize(resize),  #将较短边的长度作为resize的大小
            transforms.CenterCrop(resize),  #从图片中央截取resize × resize大小的区域
            transforms.ToTensor(),  #转换为Torch张量
            transforms.Normalize(mean, std)  #颜色信息的正规化
        ])

    def __call__(self, img):
        return self.base_transform(img)

# 根据输出结果对标签进行预测的后处理类
class ILSVRCPredictor():
    """
    根据ILSVRC数据,从模型的输出结果计算出分类标签

    Attributes
    ----------
    class_index : dictionary
           将类的index与标签名关联起来的字典型变量
    """

    def __init__(self, class_index):
        self.class_index = class_index

    def predict_max(self, out):
        """
        获得概率最大的ILSVRC分类标签名

        Parameters
        ----------
        out : torch.Size([1, 1000])
            从Net中输出结果

        Returns
        -------
        predicted_label_name : str
            预测概率最高的分类标签的名称
        """
        maxid = np.argmax(out.detach().numpy())
        predicted_label_name = self.class_index[str(maxid)][1]

        return predicted_label_name
# 载入ILSVRC的标签信息,并生成字典型变量
ILSVRC_class_index = json.load(open('./data/imagenet_class_index.json', 'r'))

# 生成ILSVRCPredictor的实例
predictor = ILSVRCPredictor(ILSVRC_class_index)

# 读取输入的图像
image_file_path = './data/jww2.webp'
img = Image.open(image_file_path)  # [ 高度 ][ 宽度 ][ 颜色RGB]

# 完成预处理后,添加批次尺寸的维度
resize = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transform = BaseTransform(resize, mean, std)  #创建预处理类
img_transformed = transform(img)  # torch.Size([3, 224, 224])
inputs = img_transformed.unsqueeze_(0)  # torch.Size([1, 3, 224, 224])

# 输入数据到模型中,并将模型的输出转换为标签
out = net(inputs)  # torch.Size([1, 1000])
result = predictor.predict_max(out)

# 输出预测结果
print("输入图像的预测结果:", result)

我识别的是一只吉娃娃的图片,结果正确,Chihuahua。

现有的模型已经可以正常工作了,下面就是添加新的分类了,这里使用了蚂蚁和蜜蜂。把 1000 个分类改为了 2个分类。
net.classifier[6] = nn.Linear(in_features=4096, out_features=2)

import glob
import os.path as osp
import random
import numpy as np
import json
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

torch.manual_seed(1234)
np.random.seed(1234)
random.seed(1234)




class ImageTransform():
    """
    图像的预处理类。训练时和推测时采用不同的处理方式
    对图像的大小进行调整,并将颜色信息标准化
    训练时采用 RandomResizedCrop 和 RandomHorizontalFlip 进行数据增强处理


    Attributes
    ----------
    resize : int
       指定调整后图像的尺寸
    mean : (R, G, B)
        各个颜色通道的平均值
    std : (R, G, B)
        各个颜色通道的标准偏差
    """

    def __init__(self, resize, mean, std):
        self.data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(
                    resize, scale=(0.5, 1.0)), #数据增强处理
                transforms.RandomHorizontalFlip(),  #数据增强处理
                transforms.ToTensor(),  # 转换为张量
                transforms.Normalize(mean, std)  # 归一化
            ]),
            'val': transforms.Compose([
                transforms.Resize(resize),  #调整大小
                transforms.CenterCrop(resize),  #从图像中央截取resize×resize大小的区域
                transforms.ToTensor(), #转换为张量
                transforms.Normalize(mean, std)  #归一化
            ])
        }

    def __call__(self, img, phase='train'):
        """
        Parameters
        ----------
        phase : 'train' or 'val'
            指定预处理所使用的模式
        """
        return self.data_transform[phase](img)

#  创建用于保存蚂蚁和蜜蜂的图片的文件路径的列表变量


def make_datapath_list(phase="train"):
    """
    创建用于保存数据路径的列表

    Parameters
    ----------
    phase : 'train' or 'val'
        指定是训练数据还是验证数据

    Returns
    -------
    path_list : list
       保存了数据路径的列表
    """

    rootpath = "./data/hymenoptera_data/"
    target_path = osp.join(rootpath+phase+'/**/*.jpg')
    print(target_path)

    path_list = []  #  保存到这里

    #  使用 glob 取得包括示例目录的文件路径
    for path in glob.glob(target_path):
        path_list.append(path)

    return path_list


class HymenopteraDataset(data.Dataset):
    """
    蚂蚁和蜜蜂图片的Dataset类,继承自PyTorch的Dataset类

    Attributes
    ----------
    file_list : 列表
        列表中保存了图片路径
    transform : object
        预处理类的实例
    phase : 'train' or 'test'
        指定是学习还是验证
    """

    def __init__(self, file_list, transform=None, phase='train'):
        self.file_list = file_list  # 文件路径列表
        self.transform = transform  # 预处理类的实例
        self.phase = phase  # 指定是train 还是val

    def __len__(self):
        '''返回图片张数'''
        return len(self.file_list)

    def __getitem__(self, index):
        '''
        获取预处理完毕的图片的张量数据和标签
        '''

        #载入第index张图片
        img_path = self.file_list[index]
        img = Image.open(img_path) #[高度][宽度][颜色RGB]

        #对图片进行预处理
        img_transformed = self.transform(
            img, self.phase)  # torch.Size([3, 224, 224])

        #从文件名中抽取图片的标签
        if self.phase == "train":
            label = img_path[30:34]
        elif self.phase == "val":
            label = img_path[28:32]

      #将标签转换为数字
        if label == "ants":
            label = 0
        elif label == "bees":
            label = 1

        return img_transformed, label

#  执行
train_list = make_datapath_list(phase="train")
val_list = make_datapath_list(phase="val")

#执行
size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
train_dataset = HymenopteraDataset(
    file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')

val_dataset = HymenopteraDataset(
    file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')

#指定小批次尺寸
batch_size = 32

#创建DataLoader
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)

#集中到字典变量中
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

#确认执行结果
batch_iterator = iter(dataloaders_dict["train"])  #转换成迭代器
inputs, labels = next(
    batch_iterator) #取出第一个元素

# 载入已经学习完毕的VGG−16模型
#创建VGG−16模型的实例
use_pretrained = True #指定使用已经训练好的参数
net = models.vgg16(pretrained=use_pretrained)

#指定使用已经训练好的参数
net.classifier[6] = nn.Linear(in_features=4096, out_features=2)

#设定为训练模式
net.train()

print('网络设置完毕 :载入已经学习完毕的权重,并设置为训练模式')

# #设置损失函数
criterion = nn.CrossEntropyLoss()

params_to_update = []

#需要学习的参数名称
update_param_names = ["classifier.6.weight", "classifier.6.bias"]

#除了需要学习的那些参数外,其他参数设置为不进行梯度计算,禁止更新
for name, param in net.named_parameters():
    if name in update_param_names:
        param.requires_grad = True
        params_to_update.append(param)
        print(name)
    else:
        param.requires_grad = False

optimizer = optim.SGD(params=params_to_update, lr=0.001, momentum=0.9)


def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):

    #epoch循环
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-------------')

        # 每个epoch中的学习和验证循环
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()  #将模式设置为训练模式
            else:
                net.eval()   #将模式设置为验证模式

            epoch_loss = 0.0  #epoch的合计损失
            epoch_corrects = 0 #epoch的正确答案数量

            #为了确认训练前的验证能力,省略epoch=0时的训练
            if (epoch == 0) and (phase == 'train'):
                continue

            #载入数据并切取出小批次的循环
            for inputs, labels in tqdm(dataloaders_dict[phase]):

                #初始化optimizer
                optimizer.zero_grad()

                #计算正向传播(forward)
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(inputs)
                    loss = criterion(outputs, labels) #计算损失
                    _, preds = torch.max(outputs, 1)  #预测标签
                    
  
                    ##训练时的反向传播
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    #计算迭代的结果
                    # 计算迭代的结果
                    epoch_loss += loss.item() * inputs.size(0)  
                    # 更新正确答案数量的总和
                    epoch_corrects += torch.sum(preds == labels.data)

            #显示每个epoch的loss和正解率
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double(
            ) / len(dataloaders_dict[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

num_epochs=2
train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

在这里插入图片描述
通过运行结果可以看到,首次没有训练直接在原始模型进行测试,正确率 33%,第二轮,经过 8 次迭代学习,正确率提高到 72%,这里比较奇怪的是验证集的正确率更高。原因是训练集做了数据增广,有些图片是变形的,所以识别起来更加困难。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/573428.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【智能算法】金豺优化算法(GJO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2022年,N Chopra等人受到金豺狩猎行为启发,提出了金豺优化算法(Golden Jackal Optimization, GJO)。 2.算法原理 2.1算法思想 GJO 模拟金豺协同狩猎…

20240425在Ubuntu20.04下检测HDD机械硬盘

20240425在Ubuntu20.04下检测HDD机械硬盘 2024/4/25 14:28 百度:免费 HDD 机械硬盘坏道检测 ubuntu HDD机械硬盘 坏道检测 https://blog.csdn.net/anny0001/article/details/136001767 ubuntu 坏道扫描 Mystery_zero 已于 2024-02-02 22:20:46 修改badblocks -b 819…

Exploiting CXL-based Memory for Distributed Deep Learning——论文泛读

ICPP 2022 Paper CXL论文阅读笔记整理 问题 深度学习(DL)正被广泛用于解决不同领域的科学应用中的复杂问题。DL应用程序使用大规模高性能计算(HPC)系统来训练给定的模型,需要消耗大量数据。这些工作负载具有很大的内…

k8s使用calico网络插件时,集群内节点防火墙策略配置方法

前言 我们在内网使用k8s时,有时候需要针对整个集群的节点设置防火墙,阻止一些外部访问,或者是仅允许白名单内的ip访问,传统做法是使用firewall之类的防火墙软件,但是,使用firewall存在如下问题&#xff1a…

Unity inputSystem 读取输入值的方法

1:通过关在 PlayerInput 获取 设置后之后在同意物体上挂载C# 脚本 通过事件获得 2: 生成 C#脚本 通过C# 脚本获得 3:通过回调函数

redis中的缓存穿透问题

缓存穿透 缓存穿透问题: 一般请求来到后端,都是先从缓存中查找数据,如果缓存中找不到,才会去数据库中查询数据。 而缓存穿透就是基于这一点,不断发送请求查询不存在的数据,从而使数据库压力过大&#xff…

python+vue得物文具玩具礼品商城系统flask-django

网站素材:收集好看的素材,然后使用PS做出适合网页尺寸的图片。在需求分析阶段以前期调研结果为基础,理解系统功能、性能、可靠性等要求,采用数据流图、实体联系图、状态转换图、数据字典等给出系统的逻辑模型。在设计阶段&#xf…

【静态分析】静态分析笔记07 - 指针分析基础

参考: 【课程笔记】南大软件分析课程7——指针分析基础(课时9/10) - 简书 -------------------------------------------------------------- 1. 指针分析规则 规则:采用推导形式,横线上面是条件,横线下…

【VTKExamples::Meshes】第十八期 OBBDicer

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例OBBDicer,并解析接口vtkOBBDicer,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO 1. …

GaussDB轻量化运维管理工具介绍

前言 本期课程将从管理平台的架构出发,结合平台的实例管理、实例升级、容灾管理和监控告警的功能和操作介绍,全面覆盖日常运维操作,带您理解并熟练运用GaussDB运维平台完成运维工作。 一、GaussDB 运维管理平台简介 开放生态层 友好Web界面…

解决office2016专业增强版 “你的许可证并非正版,你可能是盗版软件的受害者“

问题描述:安装完office后,用kms已经激活成功,但是一直在上面显示“你的许可证不是正版,并且你可能是盗版软件的受害者,使用正版Office,避免干扰并保护你的文件安全。” 尝试过网上的各种方法都没用,后面发现是用的HEU …

分享:9.3版本无缝导入AVEVA PDMS高版本工程12.0,12.1,E3D

9.3版本可以无缝导入AVEVA PDMS的工程。 UKP3d导入AVEVA PDMS工程的方法 http://47.94.91.234/forum.php?modviewthread&tid163583&fromuid6 (出处: 优易软件-工厂设计软件专家) (从AVEVA PDMS导出时元件和等级的功能我们正做收尾工作,到时可以…

Kafka---总结篇

kafka架构 主要概念 broker: 存储消息的机器 控制器controller (1)使用zookeeper, 除了提供一般的broker功能之外,还负责选举分区首领。通过在zookeepr中创建一个名为 /controller的临时节点称为 controller。每个选出的contro…

百科词条创建要多久成功?

在互联网信息爆炸的时代,百科词条作为权威的知识分享平台,其重要性不言而喻。那么,创建一个百科词条需要多久才能成功呢?创建百科词条是一个相当需要有耐心的工作,接下来伯乐网络传媒就来给大家讲一讲。 一、影响百科词…

node-sass报错如何解决

npm install 安装的时候 报node-sass错误 这个一看就是node版本兼容性导致的问题 node-sass与node版本不匹配 下面是常见的node版本和对应的node-sass版本 解决办法 1.单独安装node-sass npm install node-sass9.0.0 还是报上面的错误!!!&a…

论文笔记:Leveraging Language Foundation Models for Human Mobility Forecasting

SIGSPATIAL 2022 1intro 语言模型POI客流量预测 2 方法 3 实验

Midjourney如何利用quality控制图片质量,让细节更丰富

hello 小伙伴们,我是你们的老朋友——树下,今天分享Midjourney提示词常用参数——quality,通过更给quality的值可以生成质量更好的图片,让细节更丰富,那么这个参数是怎么用的呢?话不多说,直接开…

2014NOIP普及组真题 3. 螺旋矩阵

线上OJ: 一本通:http://ybt.ssoier.cn:8088/problem_show.php?pid1967 背景知识: 螺旋矩阵可以采用模拟的方式生成。就是顺时针四个方向 第1步、是第 1 行,方向为从左到右,数值1。当向右遇到 边界n 或者 格子已填过数…

基于卷积神经网络的手写数字识别

⚠申明: 未经许可,禁止以任何形式转载,若要引用,请标注链接地址。 全文共计3077字,阅读大概需要3分钟 🌈更多学习内容, 欢迎👏关注👀【文末】我的个人微信公众号&#xf…

海外短剧:跨文化的新浪潮与看剧系统的搭建,海外短剧系统搭建开发定制

在全球化的大潮下,海外短剧作为一种新兴的文化交流方式,正逐渐受到越来越多人的喜爱。这种融合了各地文化元素、叙事手法新颖独特的短剧形式,不仅丰富了观众的视觉体验,也为影视媒体和想拓展海外市场的企业带来了无限商机。 一、…