[PyTorch][chapter 52][迁移学习]

前言:

     迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。

      这里面主要结合前面ResNet18 例子,详细讲解一下迁移学习的流程


一  简介

    

迁移学习可以通过以下几种方式实现:

1.1 基于预训练模型的迁移:

      将已经在大规模数据集上预训练好的模型(如BERT、GPT等)作为一个通用的特征提取器,然后在新领域的任务上进行微调。

1.2  网络结构迁移:

将在一个领域中训练好的模型的网络结构应用到另一个领域中,并在此基础上进行微调。

1.3  特征迁移:

     将在一个领域中训练好的某些特征应用到另一个领域中,并在此基础上进行微调。

     word2vec

1.4 参数迁移:

       将在一个领域中训练好的模型的参数应用到另一个领域中,并在此基础上进行微调。

本文主要例子用的是 参数迁移


二  Flatten

    作用:

     输入的向量x [batch, c, w, h]=>[batch, c*w*h]

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 15:11:35 2023

@author: chengxf2
"""

import torch
from torch import optim,nn

class Flatten(nn.Module):
    
    def __init__(self):
        
        super(Flatten,self).__init__()
        
    
    def forward(self, x):
        
        a = torch.tensor(x.shape[1:])
        #dim 中 input 张量的每一行的乘积。
        shape = torch.prod(a).item()
        #print("\n ---new shape--- ",shape)
        return x.view(-1,shape)

三 迁移学习

   torchvision 已经提供好了一些分类器 resnet18,resnet152, 利用其训练好的参数,把最后的分类类型更改掉。

   from torchvision.models import resnet152
  from torchvision.models import resnet18

   注意:

          现有分类器分类的类型 > = 新分类器类型,再做transfer.

才能取得好的效果.

         

分类器分类类型
已有分类器[猫,狗,鸡,鸭】
新分类器[猫,狗]

     

   

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 14:56:35 2023

@author: chengxf2
"""

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023

@author: chengxf2
"""

import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from PokeDataset import Pokemon
from torchvision.models import resnet152
from torchvision.models import resnet18

from util import Flatten

batchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

root ='pokemon'
resize =224

csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)

train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()

def evalute(model, loader):
    
    total =len(loader.dataset)
    correct =0
    for x,y in loader:
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    
    acc = correct/total
    
    return acc   
        
        

def main():
    
    trained_model = resnet152(pretrained=True)
    
    model = nn.Sequential(*list(trained_model.children())[:-1],
        Flatten(),
        nn.Linear(in_features=2048, out_features=5))
    
   
    
    optimizer = optim.Adam(model.parameters(),lr =lr) 
    criteon = nn.CrossEntropyLoss()
    
    best_epoch=0,
    best_acc=0
    viz.line([0],[-1],win='train_loss',opts =dict(title='train loss'))
    viz.line([0],[-1],win='val_loss',  opts =dict(title='val_acc'))
    global_step =0
    
    
  
    for epoch in range(epochs):
        print("\n --main---: ",epoch)
        for step, (x,y) in enumerate(train_loader):
            #x:[b,3,224,224] y:[b]

             x = x.to(device)
             y = y.to(device)
             #print("\n --x---: ",x.shape)
             
             logits =model(x)
             loss = criteon(logits, y)
             #print("\n --loss---: ",loss.shape)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             
             viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
             global_step +=1
             
        if epoch %2 ==0:
            
             val_acc = evalute(model, val_loader)
             
             if val_acc>best_acc:
                 best_acc = val_acc
                 best_epoch =epoch
                 torch.save(model.state_dict(),'best.mdl')
             print("\n val_acc ",val_acc)
             viz.line([val_acc],[global_step],win='val_loss',update='append')
             
    print('\n best acc',best_acc, "best_epoch: ",best_epoch)
    
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt')
    
    test_acc = evalute(model, test_loader)
    print('\n test acc',test_acc)
                 

if __name__ == "__main__":
    
    main()


参考:
https://blog.csdn.net/qq_44089890/article/details/130460700

课时107 迁移学习实战_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/81812.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】make/makefile自动化构建工具

文章目录 前言一、什么是make/makefile?二、依赖关系和依赖方法2.1 makefile中创建文件2.2 makefile中删除文件2.3 stat指令查看文件的三种时间(ACM)2.4 伪目标文件(.PHONY) 三、Makefile中的一些特殊符号3.1 $ 和 $^3…

数据结构 | 堆

本文简要总结堆的概念。 更新:2023 / 8 / 20 数据结构 | 堆 堆概念方法插入步骤 删除步骤 示例大根堆堆插入删除堆排序 代码实现Python大根堆1.2. heapq 小根堆1.2. heapq 参考链接 堆 概念 如果谈到堆排序,那么必然要说说什么是 大根堆 max heap 和 …

函数极限与连续性——张宇老师学习笔记

Latex 源代码以及成品PDF(Debug版本):https://wwsk.lanzouc.com/itaDI15vddcb Latex编译Debug版本: $ xelatex 函数极限与连续性.texLatex编译Relese版本(无例题、习题,只有概念定义)&#xf…

arm:day4

1. 实现三盏灯的点亮 .text .global _start_start: led1初始化函数LED_INIT: 1 通过RCC_AHB4_ENSETR寄存器&#xff0c;设置GPIOE F组控制器使能 0x50000A28[5:4]1ldr r0,0X50000A28ldr r1,[r0]orr r1,r1,#(0X3<<4)str r1,[r0] 2.1 通过GPIOE_MODER寄存器&#xff0c;…

FFmpeg5.0源码阅读——VideoToobox硬件解码

摘要&#xff1a;本文描述了FFmpeg中videotoobox解码器如何进行解码工作&#xff0c;如何将一个编码的码流解码为最终的裸流。   关键字&#xff1a;videotoobox,decoder,ffmpeg   VideoToolbox 是一个低级框架&#xff0c;提供对硬件编码器和解码器的直接访问。 它提供视频…

RabbitMq-2安装与配置

Rabbitmq的安装 1.上传资源 注意&#xff1a;rabbitmq的版本必须与erlang编译器的版本适配 2.安装依赖环境 //打开虚拟机 yum install build-essential openssl openssl-devel unixODBC unixODBC-devel make gcc gcc-c kernel-devel m4 ncurses-devel tk tc xz3.安装erlan…

第3天----在一行句子中寻找最长最短单词

今天我们将学习如何在一行句子中寻找(第一次出现的)最长最短单词。本节内容会或多或少地利用到第一讲/第二讲的知识点&#xff0c;需要的同学可以先去看看前面的内容。 一、小试牛刀&#xff1a; 题目描述 输入 1 行句子&#xff08;不多于 200 个单词&#xff0c;每个单词长度…

Spring学习笔记+SpringMvc+SpringBoot学习笔记

壹、核心概念&#xff1a; 1.1. IOC和DI IOC&#xff08;Inversion of Control&#xff09;控制反转&#xff1a;对象的创建控制权由程序转移到外部&#xff0c;这种思想称为控制反转。/使用对象时&#xff0c;由主动new产生对象转换为由外部提供对象&#xff0c;此过程种对象…

[四次挥手]TCP四次挥手握手由入门到精通(知识精讲)

⬜⬜⬜ &#x1f430;&#x1f7e7;&#x1f7e8;&#x1f7e9;&#x1f7e6;&#x1f7ea;(*^▽^*)欢迎光临 &#x1f7e7;&#x1f7e8;&#x1f7e9;&#x1f7e6;&#x1f7ea;&#x1f430;⬜⬜⬜ ✏️write in front✏️ &#x1f4dd;个人主页&#xff1a;陈丹宇jmu &am…

人工智能与云计算实训室建设方案

一、 人工智能与云计算系统概述 人工智能&#xff08;Artificial Intelligence&#xff0c;简称AI&#xff09;是一种模拟人类智能的科学和工程&#xff0c;通过使用计算机系统来模拟、扩展和增强人类的智能能力。人工智能涉及多个领域&#xff0c;包括机器学习、深度学习、自然…

mysql的两张表left join 进行关联后,索引进行优化案例

一 mysql的案例 1.1 不加索引情况 1.表1没加索引 2.表2没加索引 3.查看索引 1.2 添加索引 1.表1添加索引 2.表2添加索引 3.查看

使用navicat连接postgresql报错问题解决

使用navicat连接postgresql报错问题解决 一、问题现象&#xff1a; 最近使用Navicat来连接postgreSQL数据库&#xff0c;发现连接不上&#xff0c;报错信息如下&#xff1a; 自己百度了一下&#xff0c;发现pgsql 15版本以后&#xff0c;有些系统表的列名改了&#xff0c;pg_…

一文科普,配资门户网是什么?

配资门户网是一个为投资者提供配资服务的平台。配资是指通过借用他人资金进行投资交易的一种金融操作方式。配资门户网作为一个线上平台&#xff0c;为投资者提供了方便、快捷的配资服务。 配资门户网提供了多种不同的配资方案&#xff0c;以满足不同投资者的需求。投资者可以…

录制游戏视频的软件有哪些?分享3款软件!

“有录制游戏视频的软件推荐吗&#xff1f;最近迷上了网游&#xff0c;想录制点自己高端操作的游戏画面&#xff0c;但是不知道用什么软件录屏比较好&#xff0c;就想问问大家&#xff0c;有没有好用的录制游戏视频软件。” 在游戏领域&#xff0c;玩家们喜欢通过录制游戏视频…

根据源码,模拟实现 RabbitMQ - 实现消息持久化,统一硬盘操作(3)

目录 一、实现消息持久化 1.1、消息的存储设定 1.1.1、存储方式 1.1.2、存储格式约定 1.1.3、queue_data.txt 文件内容 1.1.4、queue_stat.txt 文件内容 1.2、实现 MessageFileManager 类 1.2.1、设计目录结构和文件格式 1.2.2、实现消息的写入 1.2.3、实现消息的删除…

HCIP实验之MPLS

目录 一&#xff0c;实验题目 ​编辑 拓扑与IP地址规划如图所示 二&#xff0c;实验思路 三&#xff0c;实验步骤 3.1 私网部分IP地址配置 3.2 LSP部分配置 3.3 启动OSPF协议 3.4 启动MPLS协议 3.5 启动MPLS VPN 3.6 实现公网私网互通 3.7 配置BGP 3.8 双向重发布 …

常见的 Python 错误及其解决方案

此文整理了一些常见的 Python 错误及其解决方案。 1、SyntaxError: invalid syntax 说明&#xff1a;无效的语法是最常见的错误之一&#xff0c;通常是由于编写代码时违反了 Python 的语法规则。可能的原因&#xff1a; 忘记在 if、while、for 等语句后写冒号&#xff0c;或者…

我和 TiDB 的故事 | 远近高低各不同

作者&#xff1a; ShawnYan 原文来源&#xff1a; https://tidb.net/blog/b41a02e6 Hi, TiDB, Again! 书接上回&#xff0c; 《我和 TiDB 的故事 | 横看成岭侧成峰》 &#xff0c;一年时光如白驹过隙&#xff0c;这一年我好似在 TiDB 上投入的时间总量不是很多&#xff0…

vite打包配置以及性能优化

vite打包配置以及性能优化 安装插件 首先该安装的插件&#xff0c;你要安装一下吧 这三个是基本的插件&#xff0c;其他优化的插件下面会介绍到 "vite": "4.4.6","vite-plugin-html": "^3.2.0","vitejs/plugin-vue": &qu…

Eureka:集群环境配置

创建三个集群 导包 <!-- 导包--><dependencies><!-- Eureka -server --><dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-eureka-server</artifactId><version>1.…