使用预训练模型构建自己的深度学习模型(迁移学习)

在深度学习的实际应用中,很少会去从头训练一个网络,尤其是当没有大量数据的时候。即便拥有大量数据,从头训练一个网络也很耗时,因为在大数据集上所构建的网络通常模型参数量很大,训练成本大。所以在构建深度学习应用时,通常会使用预训练模型

需求

训练一个模型来分类蚂蚁ants和蜜蜂bees

步骤

  1. 加载数据集
  2. 编写函数(训练并寻找最优模型)
  3. 编写函数(查看模型效果)
  4. 使用torchvision微调模型
  5. 使用tensorboard可视化训练情况

Pytorch保存和加载模型的两种方式

1.完整保存模型、加载模型

torch.save(net, 'mnist.pth')
net = torch.load('mnist.pth', 'map_location="cpu"')
# # Model class must be defined somewhere

load() 默认会将该张量加载到保存时所在的设备上,map_location 可以强制加载到指定的设备上

2. 保存、加载模型的状态字典(模型中的参数)

torch.save(model.state_dict(), PATH)
model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))

迁移学习全过程

1.加载数据集

ants和bees各有约120张训练图片。

每个类有75张验证图片,从零开始在 如此小的数据集上进行训练通常是很难泛化的。

由于我们使用迁移学习,模型的泛化能力会相当好。

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
cudnn.benchmark = True  
plt.ion()

lr_scheduler:学习率调度器,用于在训练过程中动态调整学习率

torch.backends.cudnn.benchmark = True:大部分情况下,设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,从而加速运算

plt.ion():这允许你在一个交互式环境中运行matplotlib

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

transforms.RandomResizedCrop(224):随机大小裁剪和缩放图像,使得裁剪后的图像尺寸为224x224像素。这个操作有助于模型学习不同尺度和长宽比的图像特征,从而提高模型的泛化能力。

transforms.RandomHorizontalFlip():以一定的概率(默认为0.5)对图像进行水平翻转。这也是一种数据增强技术,有助于模型学习对称性

transforms.Resize(256):将图像缩放到256x256像素

transforms.CenterCrop(224):从图像的中心裁剪出224x224像素的区域。这种裁剪方式确保每次裁剪的都是图像的中心部分,有助于在验证或测试时获得更一致的结果

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

可视化数据

def imshow(inp, title=None): 
    # 可视化一组 Tensor 的图片
    inp = inp.numpy().transpose((1, 2, 0)) 
    mean = np.array([0.485, 0.456, 0.406]) 
    std = np.array([0.229, 0.224, 0.225]) 
    inp = std * inp + mean 
    inp = np.clip(inp, 0, 1) 
    plt.imshow(inp) 
    if title is not None: 
        plt.title(title) 
    plt.pause(0.001) # 暂停一会儿,为了将图片显示出来
# 获取一批训练数据
inputs, classes = next(iter(dataloaders['train'])) 
# 批量制作网格
out = torchvision.utils.make_grid(inputs) 
imshow(out, title=[class_names[x] for x in classes]) 

2.编写函数(训练并寻找最优模型)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25): 
    """ 
    训练模型,并返回在验证集上的最佳模型和准确率 
    - criterion: 损失函数 

   Return: 
    - model(nn.Module): 最佳模型 
    - best_acc(float): 最佳准确率 
    """

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        # 训练集和验证集交替进行前向传播
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  
            else:
                model.eval()   

            running_loss = 0.0
            running_corrects = 0

            # 遍历数据集
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # 清空梯度,避免累加了上一次的梯度
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # 正向传播
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # 反向传播且仅在训练阶段进行优化
                    if phase == 'train':
                        loss.backward() 
                        optimizer.step()

                # 统计loss、准确率
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
           
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # 发现了更优的模型,记录起来
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # 加载训练的最好的模型
    model.load_state_dict(best_model_wts)
    return model

3.编写函数(查看模型效果)

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

4.使用torchvision微调模型

微调步骤图

加载预训练模型,并将最后一个全连接层重置

model = models.resnet18(pretrained=True) # 加载预训练模型
num_ftrs = model.fc.in_features # 获取低级特征维度 
model.fc = nn.Linear(num_ftrs, 2) # 替换新的输出层 
model = model.to(device) 
criterion = nn.CrossEntropyLoss() 
# 所有参数都参加训练 
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) 
# 每过 7 个 epoch 将学习率变为原来的 0.1 
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

训练

model_ft = train_model(model, criterion, optimizer_ft, scheduler, num_epochs=3)

预估

visualize_model(model_ft)

 

5.使用tensorboard可视化训练情况

将以下代码块合理的加入到训练模块里

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
ep_losses, ep_acces = [], []

ep_losses.append(epoch_loss)
ep_acces.append(epoch_acc.item())

writer.add_scalars('loss', {'train': ep_losses[-2], 'val': ep_losses[-1]}, global_step=epoch)
writer.add_scalars('acc', {'train': ep_acces[-2], 'val': ep_acces[-1]}, global_step=epoch)
writer.close()

运行命令启动tensorboard

tensorboard --logdir runs

可以通过执行命令的终端看到tensorboard可视化页面地址,且该命令会在执行该命令的目录下生成一个runs文件夹(日志文件的保存位置)

以下是我训练了25个epochs的acc和loss的动态图,

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/578337.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【redis】Redis数据类型(二)Hash类型

目录 Hash类型介绍特性hash 的内部编码方式/底层结构hashtableziplistlistpack 适用场景举例 常用命令hset示例 hsetnx示例: hmset示例 hget示例 hmget示例 hgetall示例 hdel示例 hlen示例 hexists示例 hincrby示例 hincrbyfloat示例 hkeys示例 hvals示例 Hash类型介…

VS2019编译OSG3.7.0+OSGEarth3.3+OSGQt5.15.2时遇到的问题及解决方法

注:本次编译以文章《VS2019编译OSG3.7.0+OSGEarth3.3+OSGQt》为基础搜集资料并进行编译 一 OSG编译 1.Osg3.7.0编译中,cmake阶段按照文章步骤即可。 2.另外,还需要对以下三项进行设置,参照《OSG-OpenSceneGraph在WIN10与VS2022下的部署(OSG3.6.5+VS2022+Win10_x64)个…

RustGUI学习(iced)之小部件(二):如何使用滑动条部件

前言 本专栏是学习Rust的GUI库iced的合集,将介绍iced涉及的各个小部件分别介绍,最后会汇总为一个总的程序。 iced是RustGUI中比较强大的一个,目前处于发展中(即版本可能会改变),本专栏基于版本0.12.1. 概述…

mybatis基本使用

文章目录 1. mybatis2. 基本使用(1) maven坐标(2) 配置文件编写(3) 数据库操作(4) 注解查询 2. 基本配置(1) 读取外部配置文件(2) mapper映射 3. 映射文件查询删除/修改/新增 动态sql 1. mybatis MyBatis 是一款优秀的持久层框架,它支持自定义 SQL、存储过程以及高…

CSS盒子模型(如果想知道CSS有关盒子模型的知识点,那么只看这一篇就足够了!)

前言:在网页制作的时候,我们需要将网页中的元素放在指定的位置,那么我们如何将元素放到指定的位置上呢?这时候我们就需要了解盒子模型。 ✨✨✨这里是秋刀鱼不做梦的BLOG ✨✨✨想要了解更多内容可以访问我的主页秋刀鱼不做梦-CSD…

sCrypt全新上线RUNES功能

sCrypt智能合约平台全新上线一键etch/mint RUNES功能! 请访问 https://runes.scrypt.io/ 或点击阅读原文体验! 关于sCrypt sCrypt是BSV区块链上的一种智能合约高级语言。比特币使用基于堆栈的Script语言来支持智能合约,但是用原生Script编…

网络靶场实战-物联网安全Unicorn框架初探

背景 Unicorn 是一款基于 QEMU 的快速 CPU 模拟器框架,可以模拟多种体系结构的指令集,包括 ARM、MIPS、PowerPC、SPARC 和 x86 等。Unicorn使我们可以更好地关注 CPU 操作, 忽略机器设备的差异。它能够在虚拟内存中加载和运行二进制代码,并提…

密码加密案例

文章目录 描述思路错误关于增强for循环改变不了数组的值这一现象的疑问代码反思 描述 思路错误 应该是将其放入数组,而不是单纯的读到,因为你要对每一位数字进行操作 关于增强for循环改变不了数组的值这一现象的疑问 我们尝试使用增强for循环 键盘输…

uniapp使用地图开发app

使用uniapp开发app中使用到地图的坑: 1、简单使用地图的功能比较简单,仅使用到地图选点和定位功能:(其中问题集中在uni.chooseLocation中)下面是api官网地址 uni.getLocation(OBJECT) | uni-app官网 官方建议app端使…

迁移学习基础知识

简介 使用迁移学习的优势: 1、能够快速的训练出一个理想的结果 2、当数据集较小时也能训练出理想的效果。 注意:在使用别人预训练的参数模型时,要注意别人的预处理方式。 原理: 对于浅层的网络结构,他们学习到的…

视频批量剪辑新纪元:轻松调整音频采样率,一键实现高效视频处理!

视频剪辑已成为我们日常生活和工作中不可或缺的一部分。然而,面对大量的视频文件,如何高效地进行批量剪辑,同时又能轻松调整音频采样率,成为了许多视频制作人员、自媒体从业者、教育者和学生的共同需求。 第一步,进入…

[C++基础学习]----02-C++运算符详解

前言 C中的运算符用于执行各种数学或逻辑运算。下面是一些常见的C运算符及其详细说明:下面详细解释一些常见的C运算符类型,包括其原理和使用方法。 正文 01-运算符简介 算术运算符: a、加法运算符():对两个…

4.27日学习打卡----初学Redis(四)

4.27日学习打卡 目录: 4.27日学习打卡一. Redis的配置文件二. Redis构建Web应用实践环境搭建redis的优点引入本地缓存Google 开源工具GuavaGuava实现本地缓存 一. Redis的配置文件 在Redis的解压目录下有个很重要的配置文件 redis.conf ,关于Redis的很多…

达梦(DM) SQL日期操作及分析函数

达梦DM SQL日期操作及分析函数 日期操作SYSDATEEXTRACT判断一年是否为闰年周的计算确定某月内第一个和最后一个周末某天的日期确定指定年份季度的开始日期和结束日期补充范围内丢失的值按照给定的时间单位查找使用日期的特殊部分比较记录 范围处理分析函数定位连续值的范围查找…

如何通过安全数据传输平台,保护核心数据的安全传输?

在数字化的浪潮中,企业的数据安全传输显得尤为关键。随着网络攻击手段的日益复杂,传统的数据传输方式已不再安全,这就需要我们重视并采取有效的措施,通过安全数据传输平台来保护核心数据。 传统的数据传输面临的主要问题包括&…

Bun 入门到精通(一)

Bun 是什么? Bun 是用于 JavaScript 和 TypeScript 应用程序的多合一工具包。它作为一个名为 bun 的可执行文件提供。 其核心是 Bun 运行时,这是一个快速的 JavaScript 运行时,旨在替代 Node.js。它是用 Zig 编写的,并由 JavaSc…

数字文旅重塑旅游发展新格局:以数字化转型为突破口,提升旅游服务的智能化水平,为游客带来全新的旅游体验

随着信息技术的迅猛发展,数字化已成为推动各行各业创新发展的重要力量。在旅游业领域,数字文旅的兴起正以其强大的驱动力,重塑旅游发展的新格局。数字文旅以数字化转型为突破口,通过提升旅游服务的智能化水平,为游客带…

C#基础|OOP、类与对象的认识

哈喽,你好,我是雷工! 所有的面向对象的编程语言,都是把我们要处理的“数据”和“行为”封装到类中。 以下为OOP的学习笔记。 01 什么是面向对象编程(OOP)? 设计类:就是根据需求设计…

论文精读InstructPix2Pix: Learning to Follow Image Editing Instructions

InstructPix2Pix: Learning to Follow Image Editing Instructions 我们提出了一种根据人类指令编辑图像的方法:给定输入图像和告诉模型该做什么的书面指令,我们的模型遵循这些指令来编辑图像。 为了获得这个问题的训练数据,我们结合了两个大型预训练模…

输入输出重定向,追加重定向(Linux)

文章目录 一、输出重定向二、追加重定向三.输入重定向总结 一、输出重定向 我们在使用echo内容时,会把内容显示在显示器上。 echo自动换行。 我们如果输入 echo “hello linux” >file.txt 我们运行一下就会发现系统中多了一个file.txt的文件,如果这…
最新文章