Pytorch:利用torchvision调用各种网络的预训练模型,完成CIFAR10数据集的各种分类任务

2023.7.19

cifar10百科:

[ 数据集 ] CIFAR-10 数据集介绍_cifar10_Horizon Max的博客-CSDN博客

torchvision各种预训练模型的调用方法:

pytorch最全预训练模型下载与调用_pytorch预训练模型下载_Jorbol的博客-CSDN博客

CIFAR10数据集下载并转换为图片:

文件结构:

import torchvision
from torch.utils.data import DataLoader
import os
import numpy as np
import imageio  # 引入imageio包

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# 路径将测试集作为验证集
train_path = './dataset/cifar-10-batches-py/train'
test_path = './dataset/cifar-10-batches-py/val'

for i in range(10):
    file_name = train_path + '/'+ str(i)
    if not os.path.exists(file_name):
        os.mkdir(file_name)

for i in range(10):
    file_name = test_path + '/' + str(i)
    if not os.path.exists(file_name):
        os.mkdir(file_name)

# 解压 返回解压后的字典
def unpickle(file):
    import pickle as pk
    fo = open(file, 'rb')
    dict = pk.load(fo, encoding='iso-8859-1')
    fo.close()
    return dict


# begin unpickle
root_dir = "./dataset/cifar-10-batches-py"
# 生成训练集图片
print('loading_train_data_')
for j in range(1, 6):
    dataName = root_dir + "/data_batch_" + str(j)  # 读取当前目录下的data_batch1~5文件。
    Xtr = unpickle(dataName)
    print(dataName + " is loading...")

    for i in range(0, 10000):
        img = np.reshape(Xtr['data'][i], (3, 32, 32))  # Xtr['data']为图片二进制数据
        img = img.transpose(1, 2, 0)  # 读取image
        picName = root_dir + '/train/' + str(Xtr['labels'][i]) + '/' + str(i + (j - 1) * 10000) + '.jpg'
        imageio.imsave(picName, img)  # 使用的imageio的imsave类
    print(dataName + " loaded.")

# 生成测试集图片(将测试集作为验证集)
print('loading_val_data_')
testXtr = unpickle(root_dir + "/test_batch")
for i in range(0, 10000):
    img = np.reshape(testXtr['data'][i], (3, 32, 32))
    img = img.transpose(1, 2, 0)
    picName = root_dir + '/val/' + str(testXtr['labels'][i]) + '/' + str(i) + '.jpg'
    imageio.imsave(picName, img)

 训练代码:

                                                           AlexNet 结构

1,查询需要的模型网站并填入

# 预训练模型官网
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
}

2, 加载预训练模型(在10分类的情况下,AlexNet的效果比ResNet18的效果稍好)

先print(Model),,看fc层的参数,设置好fc层

3,以AlexNet10分类任务来看 

  

# Model = models.densenet161(pretrained=True)

Model = models.resnet18(pretrained=True)
for param in Model.parameters():
param.requires_grad = True
# print(Model)
# Model.fc = nn.Linear(2208, class_num)
Model.fc = nn.Linear(512, class_num)

 此时,我们需要将在导入AlexNet模型下载的官网:

# 预训练模型官网:
model_urls = {'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',}

# 加载模型:

Model = models.alexnet(pretrained=True)
for param in Model.parameters():
param.requires_grad = True
print(Model)

   

# 根据上面Classifier层的最后(6)Linear的输入通道 in_features = 4096 更改 model的fc层

Model.fc = nn.Linear(4096, class_num)

from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import os
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]
                               )


class MyDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.names_list = []

        for dirs in os.listdir(self.root_dir):
            dir_path = self.root_dir + '/' + dirs
            for imgs in os.listdir(dir_path):
                img_path = dir_path + '/' + imgs
                self.names_list.append((img_path, dirs))

    def __len__(self):
        return len(self.names_list)

    def __getitem__(self, index):
        image_path, label = self.names_list[index]
        if not os.path.isfile(image_path):
            print(image_path + '不存在该路径')
            return None
        image = Image.open(image_path).convert('RGB')

        label = np.array(label).astype(int)
        label = torch.from_numpy(label)


        if self.transform:
            image = self.transform(image)

        return image, label


if __name__ == '__main__':
    # 准备数据集
    train_data_path = './dataset/cifar-10-batches-py/train'
    val_data_path = './dataset/cifar-10-batches-py/val'

    # 数据长度
    train_data_length = len(train_data_path)
    val_data_length = len(val_data_path)

    # 分类的类别
    class_num = 10

    # 迭代次数
    epoch = 30

    # 学习率
    learning_rate = 0.00001

    # 批处理大小
    batch_size = 128

    # 数据加载器
    train_dataset = MyDataset(train_data_path, transform)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    val_dataset = MyDataset(val_data_path, transform)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    train_data_size = len(train_dataset)
    val_data_size = len(val_dataset)

    # 预训练模型官网
    model_urls = {'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',}

    # 调用预训练调整全连接层:搭建网络
    # Model = models.densenet161(pretrained=True)
    Model = models.alexnet(pretrained=True)
    for param in Model.parameters():
        param.requires_grad = True
    print(Model)
    # Model.fc = nn.Linear(2208, class_num)
    Model.fc = nn.Linear(4096, class_num)

    # 创建网络模型
    ModelOutput = Model.cuda()  # DenseNet161 ResNet18

    # 采用多GPU训练
    if torch.cuda.device_count() > 1:
        print("使用", torch.cuda.device_count(), "个GPUs进行训练")
        ModelOutput = nn.DataParallel(ModelOutput)
    else:
        ModelOutput = Model.to(device)  # .Cuda()数据是指放到GPU上
        print("使用", torch.cuda.device_count(), "个GPUs进行训练")

    # 定义损失函数
    loss_fn = nn.CrossEntropyLoss().cuda()  # 交叉熵函数

    # 定义优化器
    optimizer = optim.Adam(ModelOutput.parameters(), lr=learning_rate)

    # 记录验证的次数
    total_train_step = 0
    total_val_step = 0

    # 训练
    acc_list = np.zeros(epoch)
    print("{0:-^27}".format('Train_Model'))
    for i in range(epoch):
        print("----------epoch={}----------".format(i + 1))
        ModelOutput.train()
        for data in train_dataloader:  # data 是batch大小
            image_train_data, t_labels = data
            image_train_data = image_train_data.cuda()
            t_labels = t_labels.cuda()
            output = ModelOutput(image_train_data)
            loss = loss_fn(output, t_labels.long())

            # 优化器优化模型
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 反向传播
            optimizer.step()  # 优化更新参数

            total_train_step = total_train_step + 1
            print("train_times:{},Loss:{}".format(total_train_step, loss.item()))

        # 验证步骤开始
        ModelOutput.eval()
        total_val_loss = 0
        total_accuracy = 0
        with torch.no_grad():  # 测试的时候不需要对梯度进行调整,所以梯度设置不调整
            for data in val_dataloader:
                image_val_data, v_labels = data
                image_val_data = image_val_data.cuda()
                v_labels = v_labels.cuda()
                outputs = ModelOutput(image_val_data)
                loss = loss_fn(outputs, v_labels.long())
                total_val_loss = total_val_loss + loss.item()  # 计算损失值的和
                accuracy = 0

                for j in v_labels:  # 计算精确度的和

                    if outputs.argmax(1)[j] == v_labels[j]:
                        accuracy = accuracy + 1

                # accuracy = (outputs.argmax(1) == v_labels).sum()  # 计算一个数据的精确度
                total_accuracy = total_accuracy + accuracy

        val_acc = float(total_accuracy / val_data_size) * 100
        acc_list[i] = val_acc  # 记录验证集的正确率
        print('the_classification_is_correct :', total_accuracy, val_data_length)
        print("val_Loss:{}".format(total_val_loss))
        print("val_acc:{}".format(val_acc), '%')
        total_val_step += 1
        torch.save(ModelOutput, "Model_{}.pth".format(i + 1))
        # torch.save(ModelOutput.module.state_dict(), "Model_{}.pth".format(i + 1))
        print("{0:-^24}".format('Model_Saved'), '\n')
        print('val_max=', max(acc_list), '%', '\n')  # 验证集的最高正确率

测试代码:

import torch
from torchvision import transforms

import os
from PIL import Image

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # 判断是否有GPU

model = torch.load('Model_8.pth')  # 加载模型

path = "./dataset/cifar-10-batches-py/test/"  # 测试集

imgs = os.listdir(path)

test_num = len(imgs)
print(f"test_dataset_quantity={test_num}")

for img_name in imgs:
    img = Image.open(path + img_name)

    test_transform = transforms.Compose([transforms.Resize((224, 224)),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]
                                        )

    img = test_transform(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    outputs = model(img)  # 将图片输入到模型中
    _, predicted = outputs.max(1)

    pred_type = predicted.item()
    print(img_name, 'pred_type:', pred_type)

在使用标签为9的卡车图像进行预测:

AlexNet:

 ResNet18:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/43634.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue中export和export default的使用

<script> export default {name: HelloWorld } $(function () {alert(引入成功) }) </script> 1、export的使用 比喻index.js要使用test.js中的数据&#xff0c;首先在test.js文件中进行导出操作 代码如下&#xff1a; export function list() {alert("list…

【表达式引擎】简单高效的轻量级Java表达式引擎:Aviator

简单高效的轻量级表达式引擎&#xff1a;Aviator 前言 Aviator 是一个高性能、、轻量级的表达式引擎&#xff0c;支持表达式动态求值。其设计目标为轻量级和高性能&#xff0c;相比于 Groovy 和 JRuby 的笨重&#xff0c;Aviator 就显得更加的小巧。与其他的轻量级表达式引擎不…

Python实战项目——物流行业数据分析(二)

今天我们对物流行业数据进行简单分析&#xff0c;数据来源&#xff1a;某企业销售的6种商品所对应的送货及用户反馈数据 解决问题&#xff1a; 1、配送服务是否存在问题 2、是否存在尚有潜力的销售区域 3、商品是否存在质量问题 分析过程&#xff1a; 依旧先进行数据处理 一…

【MATLAB第58期】基于MATLAB的PCA-Kmeans、PCA-LVQ与BP神经网络分类预测模型对比

【MATLAB第58期】基于MATLAB的PCA-Kmeans、PCA-LVQ与BP神经网络分类预测模型对比 一、数据介绍 基于UCI葡萄酒数据集进行葡萄酒分类及产地预测 共包含178组样本数据&#xff0c;来源于三个葡萄酒产地&#xff0c;每组数据包含产地标签及13种化学元素含量&#xff0c;即已知类…

C# List 详解四

目录 18.FindLast(Predicate) 19.FindLastIndex(Int32, Int32, Predicate) 20.FindLastIndex(Int32, Predicate) 21.FindLastIndex(Predicate) 22.ForEach(Action) 23.GetEnumerator() 24.GetHashCode() 25.GetRange(Int32, Int32) C#…

【搜索引擎Solr】配置 Solr 以获得最佳性能

Apache Solr 是广泛使用的搜索引擎。有几个著名的平台使用 Solr&#xff1b;Netflix 和 Instagram 是其中的一些名称。我们在 tajawal 的应用程序中一直使用 Solr 和 ElasticSearch。在这篇文章中&#xff0c;我将为您提供一些关于如何编写优化的 Schema 文件的技巧。我们不会讨…

Prompt 技巧指南-让 ChatGPT 回答更准确

随着 ChatGPT 等大型语言模型 (LLM)的兴起&#xff0c;人们慢慢发现&#xff0c;怎么样向 LLM 提问、以什么技巧提问&#xff0c;是获得更加准确的回答的关键&#xff0c;也由此产生了提示工程这个全新的领域。 提示工程(prompt engineering)是一门相对较新的领域&#xff0c;用…

多目标灰狼算法(MOGWO)的Matlab代码详细注释及难点解释(佳点集改进初始种群的MOGWO)

目录 一、外部种群Archive机制 二、领导者选择机制 三、多目标灰狼算法运行步骤 四、MOGWO的Matlab部分代码详细注释 五、MOGWO算法难点解释 5.1 网格与膨胀因子 5.2 轮盘赌方法选择每个超立方体概率 为了将灰狼算法应用于多目标优化问题,在灰狼算法中引入外部种群Archi…

【人工智能】大模型平台新贵——文心千帆

个人主页&#xff1a;【&#x1f60a;个人主页】 &#x1f31e;热爱编程&#xff0c;热爱生活&#x1f31e; 文章目录 前言大模型平台文心千帆发布会推理能力模型微调 作用 前言 在不久的之前我们曾讨论过在ChatGPT爆火的大环境下&#xff0c;百度推出的“中国版ChatGPT”—文…

【C++】STL使用仿函数控制优先级队列priority_queue

文章目录 前言一、priority_queue的底层实现二、使用仿函数控制priority_queue的底层总结 前言 本文章讲解CSTL的容器适配器&#xff1a;priority_queue的实现&#xff0c;并实现仿函数控制priority_queue底层。 一、priority_queue的底层实现 priority_queue叫做优先级队列&…

C进阶:文件操作

C语言文件操作 什么是文件 磁盘上的数据是文件。 但是在程序设计中&#xff0c;我们一般谈的文件有两种&#xff1a;程序文件&#xff08;例如.c,.h这一类编译&#xff0c;链接过程中的文件&#xff09;&#xff0c;数据文件。 程序文件 包括源程序文件&#xff08;后缀为.c&…

【PostgreSQL内核学习(十)—— 查询执行(可优化语句执行)】

可优化语句执行 概述物理代数与处理模型物理操作符的数据结构执行器的运行 声明&#xff1a;本文的部分内容参考了他人的文章。在编写过程中&#xff0c;我们尊重他人的知识产权和学术成果&#xff0c;力求遵循合理使用原则&#xff0c;并在适用的情况下注明引用来源。 本文主要…

macOS系统下编译linux-adk源码

1.下载 linux-adk源码 https://github.com/gibsson/linux-adk.git 2.安装libusb库 brew install libusb 3.修改Makefile CFLAGS += -Isrc -I/usr/local/Cellar/libusb/1.0.26/include/libusb-1.0 4.编译 make ./linux-adk -h 查看用法 查看系统已连接USB设备 system_p…

C#使用Linq和Loop计算集合的平均值、方差【标准差】

方差【标准差】 标准差公式是一种数学公式。标准差也被称为标准偏差&#xff0c;或者实验标准差&#xff0c;公式如下所示&#xff1a; 样本标准差方差的算术平方根ssqrt(((x1-x)^2 (x2-x)^2 ......(xn-x)^2)/n) 总体标准差σsqrt(((x1-x)^2 (x2-x)^2 ......(xn-x)^2)/n ) …

win11我们无法创建新的分区也找不到现有的分区

U盘重装系统的时候 提示&#xff1a;win11我们无法创建新的分区也找不到现有的分区 ShiftF10 &#xff0c;调出 命令提示符&#xff1b; diskpart list disk select disk 盘编号 clean convert gpt 参考&#xff1a;怎么解决我们无法创建新的分区也找不到现有的分区问题&#x…

OpenCV实现照片换底色处理

目录 1.导言 2.引言 3.代码分析 4.优化改进 5.总结 1.导言 在图像处理领域&#xff0c;OpenCV是一款强大而广泛应用的开源库&#xff0c;能够提供丰富的图像处理和计算机视觉功能。本篇博客将介绍如何利用Qt 编辑器调用OpenCV库对照片进行换底色处理&#xff0c;实现更加…

HTTP 缓存机制 强制缓存/协商缓存

为什么被缓存&#xff0c;如何命中缓存以及缓存什么时候生效的&#xff0c;我们却很少在实际开发中去了解。借助动画形式来从根上理解 HTTP 缓存机制及原理。 HTTP 缓存&#xff0c;对于前端的性能优化方面来讲&#xff0c;是非常关键的&#xff0c;从缓存中读取数据和直接向服…

REST和RPC的区别

1 REST REST 不是一种协议&#xff0c;它是一种架构。大部分REST的实现中使用了RPC的机制&#xff0c;大致由三部分组成&#xff1a; method&#xff1a;动词&#xff08;GET、POST、PUT、DELETE之类的&#xff09;Host&#xff1a;URI&#xff08;统一资源标识&#xff09;&…

前端Vue uni-app App/小程序/H5 通用tree树形结构图

随着技术的发展&#xff0c;开发的复杂度也越来越高&#xff0c;传统开发方式将一个系统做成了整块应用&#xff0c;经常出现的情况就是一个小小的改动或者一个小功能的增加可能会引起整体逻辑的修改&#xff0c;造成牵一发而动全身。 通过组件化开发&#xff0c;可以有效实现…

Apple M1 Pro macOS 切换中文输入法卡住

(macOS 在切换中文输入法时出现卡住的情况 1&#xff0c;切换为中文输入法后再次卡住2&#xff0c;杀死 简体中文输入方式的进程参考 将光标移到菜单栏的输入法切换为英文输入法 多次切换为英文输入法&#xff0c;可以切换为英文输入法 切换为英文输入法后电脑不卡顿了&#xf…