卷积神经网络|迁移学习-猫狗分类完整代码实现

还记得这篇文章吗?迁移学习|代码实现

在这篇文章中,我们知道了在构建模型时,可以借助一些非常有名的模型,这些模型在ImageNet数据集上早已经得到了检验。

同时torchvision模块也提供了预训练好的模型。我们只需稍作修改,便可运用到自己的实际任务中!

我们仍然按照这个步骤开始我们的模型的训练

  • 准备一个可迭代的数据集

  • 定义一个神经网络

  • 将数据集输入到神经网络进行处理

  • 计算损失

  • 通过梯度下降算法更新参数

import torch import torchvisionimport torchvision.transforms as transformsimport torch.nn as nnimport torch.optim as optimimport matplotlib.pyplot as pltfrom torchvision import models

数据集准备

cifar10_train = torchvision.datasets.CIFAR10(    root = 'cifar10/',    train = True,    download = True)cifar10_test=torchvision.datasets.CIFAR10(    root = 'cifar10/',    train = False,    download = True)
transform = transforms.Compose([        transforms.ToTensor(),        transforms.Resize((224,224))    ])

cifar2_train=[(transform(img),[3,5].index(label)) for img,label in cifar10_train if label in [3,5]]
cifar2_test=[(transform(img),[3,5].index(label)) for img,label in cifar10_test if label in [3,5]]
train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64,shuffle=True)test_loader = torch.utils.data.DataLoader(cifar2_test, batch_size=64,shuffle=True)

数据集使用CIFAR-10数据集中的猫和狗

CIFAR-10数据集类别

种类       标签

  • plane       0

  • car           1

  • bird         2

  • cat           3

  • deer         4

  • dog          5

  • frog         6

  • horse       7

  • ship         8

  • truck        9

可以看到其中cat和dog的标签分别为3和5

借助:

[3,5].index(label)

我们可以将cat标签变为0dog标签变为1,从而回到二分类问题。

举个例子:

>>> [3,5].index(3)0>>> [3,5].index(5)1

定义模型

参考这篇文章:迁移学习|代码实现

#网络搭建network=models.resnet18(pretrained=True)
for param in network.parameters():    param.requires_grad=False
network.fc=nn.Linear(512,2)#损失函数criterion=nn.CrossEntropyLoss()#优化器optimizer=optim.SGD(network.fc.parameters(),lr=0.01,momentum=0.9)
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")network=network.to(device)

训练模型:

for epoch in range(10):    total_loss = 0    total_correct = 0    for batch in train_loader:   # Get batch        images, labels =batch        images=images.to(device)        labels=labels.to(device)                    optimizer.zero_grad()  #告诉优化器把梯度属性中权重的梯度归零,否则pytorch会累积梯度        preds = network(images)        loss = criterion(preds, labels)        loss.backward()        optimizer.step()                total_loss += loss.item()        _,prelabels=torch.max(preds,dim=1)        total_correct += int((prelabels==labels).sum())    accuracy = total_correct/len(cifar2_train)    print("Epoch:%d  ,  Loss:%f  , Accuracy:%f "%(epoch,total_loss,accuracy))
  • Epoch:0  ,  Loss:78.549439  , Accuracy:0.788900

  • Epoch:1  ,  Loss:77.828066  , Accuracy:0.801500

  • Epoch:2  ,  Loss:66.151785  , Accuracy:0.828100

  • Epoch:3  ,  Loss:76.204446  , Accuracy:0.816800

  • Epoch:4  ,  Loss:68.886606  , Accuracy:0.828100

  • Epoch:5  ,  Loss:71.129405  , Accuracy:0.821200

  • Epoch:6  ,  Loss:66.096364  , Accuracy:0.829900

  • Epoch:7  ,  Loss:65.504227  , Accuracy:0.827700

  • Epoch:8  ,  Loss:76.303878  , Accuracy:0.817100

  • Epoch:9  ,  Loss:70.546953  , Accuracy:0.820700

测试模型:

correct=0total=0network.eval()with torch.no_grad():    for batch in test_loader:        imgs,labels=batch        imgs=imgs.cuda()        labels=labels.cuda()                preds=network(imgs)        _,prelabels=torch.max(preds,dim=1)        #print(prelabels.size())        total=total+labels.size(0)        correct=correct+int((prelabels==labels).sum())    #print(total)    accuracy=correct/total    print("Accuracy: ",accuracy)

Accuracy:  0.8025

这里使用的预训练模型是resnet18,我们也可以使用VGG16模型,同时记得改变最后一个全连接层的输出参数,使得其满足我们自己的任务。

除了预训练模型之外,我们还可以对一些超参数进行调整,使最后的效果变得更好!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/302429.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

我的阿里云服务器被攻击了

服务器被DDoS攻击最恶心,尤其是阿里云的服务器受攻击最频繁,因为黑客都知道阿里云服务器防御低,一但被攻击就会进入黑洞清洗,轻的IP停止半小时,重的停两个至24小时,给网站带来很严重的损失。而处理 ddos 攻…

Spring Cloud Gateway整合Sentinel

日升时奋斗,日落时自省 目录 1、实现整合 1.1、添加框架依赖 1.2、设置配置文件 1.3、设置限流和熔断规则 1.3.1、限流配置 Route ID限流配置 API限流配置 1.3.2、熔断配置 2、实现原理 先前Sentinel针对是业务微服务,没有整合Sentinel到Spring…

若依CRUD搬砖开始,Java小白入门(十)

背景 经过囫囵吞枣的学习若依框架,对于ruoyi-framework,common,安全,代码生成等模块都看了一圈,剩余的调度模块,这个暂时不深入,剩余的是ruoyi-system,就是用mybatis完成的&#xf…

c/c++基础 自增自减运算符 大白讲解i++/i--/++i/--i

后置运算符:i表示在使用x之后,再使x的值加1,即ii1; 前置运算符:i表示在使用x之前,先使x的值加1,即ii1. 前缀运算和后缀运算的区别:前缀运算是“先变后用”,而后缀运算是“先用后变”…

JavaScript:Date 对象-时间日期

Date 对象-时间日期: - JS中所有的关于时间信息都需要通过Date对象来表示 // 创建一个Date对象 // 如果直接使用new Date()创建时间对象,它会默认创建一个表示代码执行时刻的对象var d new Date();// 如果希望创建一个指定的时间的Date的对象,需要传递…

《代码整洁之道之程序员的职业素养》-验收测试测试策略

Tips:此文为阅读Bob大叔的《代码整洁之道》一书的摘抄小记,谨慎“食用” 一、验收测试 重视沟通,专业开发人员既要做好开发也要做好沟通。“输入糟糕,输出也会糟糕”,职业程序员需要重视与团队及业务部门的沟通&…

IP3005A 超高精度内置MOSFET 单节锂电池保护IC 英集芯

描述 IP3005系列IC是一款超高精度的单节锂离子/ 锂聚合物电池保护芯片,它内置功率MOSFET,全 集成了超高精度的过充电压、过放电压、过放电流、 过充电流检测保护电路。 IP3005采用了精确的电压判断电路,让过充电压,过充恢复电压&…

一天一个设计模式---单例模式

概念 单例模式是一种创建型设计模式,其主要目的是确保一个类只有一个实例,并提供一个全局访问点。这意味着在应用程序中的任何地方,只能有一个实例存在,而不会创建多个相同类型的实例。 具体内容 单例模式通常包括以下几个要素…

如何解决vscode中文路径的问题

首先我们进入设备 搜索“区域”,选择“区域设置” 点击管理语言设置 点击更改系统区域设置,勾选“Beta 版: 使用 Unicode UTF-8 提供全球语言支持(U)”,电脑会叫你重启,你重启就行了

树莓派点亮led(1)

更换清华源 树莓派更换国内源(清华源)_树莓派更换清华源-CSDN博客 查看python版本 安装pipx 安装引脚 查看引脚 #安装gpio 创建文件夹 创建py文件 运行python文件 ubuntu传递文件到树莓派 1、启用ubuntu端的新终端 2拷贝文件到home目录下的用户文件夹…

AI人工智能学习路线图

学习人工智能 AI 的路线通常包括以下几个步骤:了解人工智能的基本概念和历史,包括机器学习、神经网络、深度学习等技术。学习数学基础知识,包括线性代数、微积分、概率论和统计学等。学习编程基础知识,包括 Python、C 等编程语言。…

手把手教学git-idea在实际开发中如何使用(适用于包装/实习同学)

TOC 前言 当前git主流的使用方式有可视化工具和git命令行, 这里主要介绍可视化工具(idea中的git)的使用方法, 其他比较好用的可视化工具还有SourceTree git-idea idea中git相关页面和功能的介绍 图一 图二 图三 图四 合并代码解决冲突: 合并代码我知道的方法有三种…

研究生写爬虫险些锒铛入狱,起因竟是为爱冲锋?

我国目前并未出台专门针对网络爬虫技术的法律规范,但在司法实践中,相关判决已屡见不鲜,K 哥特设了“K哥爬虫普法”专栏,本栏目通过对真实案例的分析,旨在提高广大爬虫工程师的法律意识,知晓如何合法合规利用…

2024了,我不允许你还不会:Qt查看与调试源码

一、人人都是大佬,谦(卑)虚(心)长远进步 作为一个Qt的开发者,下面这段代码你已经快到了“相看两不厌”的状态了吧。你有没有好奇过,a.exec() 到底干了什么? 我不允许你再说 这是Qt …

stable diffusion 人物高级提示词(五)场景、特效、拍摄手法、风格

一、场景 场景Promptindoor室内outdoor室外cityscape城市景色countryside乡村beach海滩forest森林mountain山脉snowfield雪原skyscraper摩天大楼ancient monument古代遗迹cathedral大教堂library图书馆museum博物馆office building办公大楼restaurant餐厅street market街头市场…

Spring——Spring整合MyBatis

Spring整合MyBatis 1.创建工程 1.1.pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"…

数据结构——队列(Queue)

目录 1.队列的介绍 2.队列工程 2.1 队列的定义 2.1.1 数组实现队列 2.1.2 单链表实现队列 2.2 队列的函数接口 2.2.1 队列的初始化 2.2.2 队列的数据插入&#xff08;入队&#xff09; 2.2.3 队列的数据删除&#xff08;出队&#xff09; 2.2.4 取队头数据 2.2.5 取队…

SpringBoot+策略模式实现多种文件存储模式

一、策略模式 背景 针对某种业务可能存在多种实现方式&#xff1b;传统方式是通过传统if…else…或者switch代码判断&#xff1b; 弊端&#xff1a; 代码可读性差扩展性差难以维护 策略模式简介 策略模式是一种行为型模式&#xff0c;它将对象和行为分开&#xff0c;将行…

PyTorch|view(),改变张量维度

在构建自己的网络时&#xff0c;了解数据经过每个层后的形状变化是必须的&#xff0c;否则&#xff0c;网络大概率会出现问题。PyToch张量有一个方法&#xff0c;叫做view(),使用这个方法&#xff0c;我们可以很容易的对张量的形状进行改变&#xff0c;从而符合网络的输入要求。…

pgAdmin和asdf postgres的安装

安装pgAdmin&#xff1a; curl https://www.pgadmin.org/static/packages_pgadmin_org.pub | sudo apt-key addsudo sh -c echo "deb https://ftp.postgresql.org/pub/pgadmin/pgadmin4/apt/$(lsb_release -cs) pgadmin4 main" > /etc/apt/sources.list.d/pgadmi…