如何更改一个训练好的网络的部分架构+重新训练部分参数,冻结不需要的参数+配合学习器更改(以Deeplabv3+为例)

这里先给出deeplav3+的架构(只给出主体部分):

class DeepLabV3Plus(BaseNet):
    def __init__(self, backbone, nclass):
        super(DeepLabV3Plus, self).__init__(backbone)

        low_level_channels = self.backbone.channels[0]
        high_level_channels = self.backbone.channels[-1]

        self.head = ASPPModule(high_level_channels, (12, 24, 36))

        self.reduce = nn.Sequential(nn.Conv2d(low_level_channels, 48, 1, bias=False),
                                    nn.BatchNorm2d(48),
                                    nn.ReLU(True))

        self.fuse = nn.Sequential(nn.Conv2d(high_level_channels // 8 + 48, 256, 3, padding=1, bias=False),
                                  nn.BatchNorm2d(256),
                                  nn.ReLU(True),

                                  nn.Conv2d(256, 256, 3, padding=1, bias=False),
                                  nn.BatchNorm2d(256),
                                  nn.ReLU(True),
                                  nn.Dropout(0.1, False))

        self.classifier = nn.Conv2d(256, nclass, 1, bias=True)

    def base_forward(self, x):
        h, w = x.shape[-2:]

        c1, _, _, c4 = self.backbone.base_forward(x)

        c4 = self.head(c4)
        c4 = F.interpolate(c4, size=c1.shape[-2:], mode="bilinear", align_corners=True)

        c1 = self.reduce(c1)

        out = torch.cat([c1, c4], dim=1)
        out = self.fuse(out)

        out = self.classifier(out)
        out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)

        return out

假设刚开始我们的模型的nclass设置为15,整个网络模型已经经过了一次训练,保存的模型参数为best.pth 现在我想改变这个网络的最后一层self.classifier,将它的输出改为2个通道(原来是15个通道)

1 首先我们加载预训练模型

model = DeepLabV3Plus(backbone, nclass)
model.load_state_dict(torch.load('best.pth'))

2 然后,我们需要更改那一层的分类器

in_features = model.classifier.in_channels
new_out_features = 15  # 新的类别数量
#构建新的分类层,其他参数可以自定义
model.classifier = nn.Conv2d(in_features, new_out_features, kernel_size=1, bias=True)

3 对参数的可学习性进行设置(冻结 or no)

# 先冻结所有层的参数
for param in model.parameters():
    param.requires_grad = False
# 然后将新的层的参数可学习性设置为True
model.classifier.requires_grad = True  

4 优化器要做出相应改动(如果本身就是选择grad为true的参数进行更新,那不需要更改)

# 定义优化器,只更新分类器层的参数
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, momentum=0.9)

这样,得到的model便可以进行使用了。

我利用gpt写了一段训练的完整流程代码(pass 部分自行填充即可):

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.datasets import YourDataset  # 替换为你的数据集类
from tqdm import tqdm

# 定义数据集
transform = Compose([
    Resize((256, 256)),
    ToTensor()
])
dataset = YourDataset(root='path/to/your/dataset', transform=transform)  # 替换为你的数据集路径

# 定义数据加载器
batch_size = 32
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 定义模型
class DeepLabV3Plus(nn.Module):
    def __init__(self, backbone, nclass):
        super(DeepLabV3Plus, self).__init__()
        # 你的模型定义

    def forward(self, x):
        # 前向传播逻辑
        pass

backbone = None  # 你的骨干网络
nclass = 10  # 类别数量
model = DeepLabV3Plus(backbone, nclass)

# 加载预训练模型
model.load_state_dict(torch.load('aaa.pth'))

# 修改分类器层
in_features = model.classifier.in_channels
new_out_features = 10  # 新的类别数量
model.classifier = nn.Conv2d(in_features, new_out_features, kernel_size=1, bias=True)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练过程
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in tqdm(data_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        images, labels = images.to(device), labels.to(device)

        # 梯度清零
        optimizer.zero_grad()

        # 前向传播
        outputs = model(images)

        # 计算损失
        loss = criterion(outputs, labels)

        # 反向传播
        loss.backward()

        # 参数更新
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(dataset)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}')

print('Finished Training')

制作不易,如有帮助请点赞一下哦!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/588175.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MATLAB 数据导入

MATLAB 数据导入(ImportData) 在MATLAB中导入数据意味着从外部文件加载数据。该importdata功能允许加载不同格式的各种数据文件。它具有以下五种形式 序号 功能说明 1 A importdata(filename) 从filename表示的文件中将数据加载到数组A中。 2 A i…

Electron+Vue3+Vite+ElectronForge整合-全部ts开发 - 一键启动两个服务 一键打包两个服务

说明 本文介绍一下 Electron Vue3 Vite Electron Forge 的高级整合操作。vue3 : 使用 TS 的语法开发; Electron : 使用 TS 的语法开发。 补充 : 目前Electron的开发还是以JS为主,不过我们可以直接使用TS开发,在执行和打包时&a…

UE5 蓝图入门

基础节点创建: 常量: 按住 1 ,点击鼠标左键,创建常量 二维向量: 按住 2 ,点击鼠标左键,创建二维向量 三维向量: 按住 3 ,点击鼠标左键 按 c 键打出一个注释框 参考视…

C# Winform父窗体打开新的子窗体前,关闭其他子窗体

随着Winform项目越来越多,界面上显示的窗体越来越多,窗体管理变得更加繁琐。有时候我们要打开新窗体,然后关闭多余的其他窗体,这个时候如果一个一个去关闭就会变得很麻烦,而且可能还会出现遗漏的情况。这篇文章介绍了三…

HR招聘测评,如何进行人才测评?

说起“人才测评”几个字,相信大家都不会陌生,很多人,尤其是求职者来说,则更加熟悉。在求职应聘中,已经有越来越多的企业开始采用人才测评进行人员选拔。了解人才测评的含义,知道人才测评如何进行&#xff0…

打破失联困境:门店如何利用AI智能名片B2B2C商城小程序重构与消费者的紧密连接?

在如今这个消费者行为日益碎片化的时代,门店经营者们时常感叹:消费者进店如同一场不期而遇的缘分,然而一旦离开门店,就仿佛消失在茫茫人海中,难以再觅其踪迹。这种“进店靠缘分,离店就失联”的困境&#xf…

本地大语言模型LLM的高效运行专家 | Ollama

Ollama简介 Ollama是一个开源的大型语言模型服务工具,它帮助用户快速在本地运行大模型。通过简单的安装指令,用户可以执行一条命令就在本地运行开源大型语言模型,如Llama 2。Ollama极大地简化了在Docker容器内部署和管理LLM的过程&#xff0…

平面模型上提取凸凹多边形------pcl

平面模型上提取凸凹多边形 pcl::PointCloud<pcl::PointXYZ>::Ptr PclTool::ExtractConvexConcavePolygons(pcl::PointCloud<pcl::PointXYZ>::Ptr cloud) {pcl::PointCloud<pcl::PointXYZ>::Ptr cloud_filtered(new pcl::PointCloud<pcl::PointXYZ>);p…

政安晨:【Keras机器学习示例演绎】(二十八)—— 使用 卷积神经网络与循环神经网络 架构进行视频分类

目录 数据收集 设置 定义超参数 数据准备 序列模型 推论 政安晨的个人主页&#xff1a;政安晨 欢迎 &#x1f44d;点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras机器学习实战 希望政安晨的博客能够对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正…

Android Handler用法

Android Handler用法 为什么要设计Handler机制&#xff1f;Handler的用法1、创建Handler2、Handler通信2.1 sendMessage 方式2.2 post 方式 Handler常用方法1、延时执行2、周期执行 HandlerThread用法主线程-创建Handler子线程-创建Handler FAQMessage是如何创建主线程中Looper…

微服务保护和分布式事务(Sentinel、Seata)笔记

一、雪崩问题的解决的服务保护技术了解 二、Sentinel 2.1Sentinel入门 1.Sentinel的安装 &#xff08;1&#xff09;下载Sentinel的tar安装包先 &#xff08;2&#xff09;将jar包放在任意非中文、不包含特殊字符的目录下&#xff0c;重命名为 sentinel-dashboard.jar &…

Docker容器---Harbor私有仓库部署与管理

一、搭建本地私有仓库 1、下载registry镜像 [rootlocalhost ~]#docker pull registry Using default tag: latest latest: Pulling from library/registry 79e9f2f55bf5: Pull complete 0d96da54f60b: Pull complete 5b27040df4a2: Pull complete e2ead8259a04: Pull comp…

vulnhub靶场之FunBox-1

一.环境搭建 1.靶场描述 Boot2Root ! This is a reallife szenario, but easy going. You have to enumerate and understand the szenario to get the root-flag in round about 20min. This VM is created/tested with Virtualbox. Maybe it works with vmware. If you n…

NASA数据集——NASA 标准二级(L2)暗目标(DT)气溶胶产品每 6 分钟在全球范围内对陆地和海洋上空的气溶胶光学厚度(AOT)产品

VIIRS/NOAA20 Dark Target Aerosol 6-Min L2 Swath 6 km 简介 NOAA-20&#xff08;前身为联合极地卫星系统-1&#xff08;JPSS-1&#xff09;&#xff09;--可见红外成像辐射计套件&#xff08;VIIRS&#xff09;NASA 标准二级&#xff08;L2&#xff09;暗目标&#xff08;D…

集合的基本操作

集合&#xff1a; 在java当中&#xff0c;含有着一些不同的存储数据的相关集合。分为单列集合&#xff08;Collection&#xff09;和双列集合(Map)。 Collection 首先学习Collection来进行展示&#xff1a; 以框框为例子&#xff0c;蓝色的代表的是接口&#xff0c;而红色的…

【Linux极简教程】常见实用命令不断更新中......

【Linux极简教程】常见实用命令不断更新中...... 常见问题1.Waiting for cache lock: Could not get lock /var/lib/dpkg/lock. It is held by process xxxx(dpkg) 常见问题 1.Waiting for cache lock: Could not get lock /var/lib/dpkg/lock. It is held by process xxxx(dp…

机器学习:基于Sklearn、XGBoost,使用逻辑回归、支持向量机和XGBClassifier预测股票价格

前言 系列专栏&#xff1a;机器学习&#xff1a;高级应用与实践【项目实战100】【2024】✨︎ 在本专栏中不仅包含一些适合初学者的最新机器学习项目&#xff0c;每个项目都处理一组不同的问题&#xff0c;包括监督和无监督学习、分类、回归和聚类&#xff0c;而且涉及创建深度学…

C语言——队列的实现

队列按照先进先出&#xff08;FIFO&#xff0c;First In First Out&#xff09;的原则管理数据。这意味着最先进入队列的元素会被最先移出&#xff0c;类似于排队等候服务的情况。队列通常有两个主要操作&#xff1a;入队&#xff08;enqueue&#xff09;&#xff0c;将元素添加…

DSP实时分析平台设计方案:924-6U CPCI振动数据DSP实时分析平台

6U CPCI振动数据DSP实时分析平台 一、产品概述 基于CPCI结构完成40路AD输入&#xff0c;30路DA输出的信号处理平台&#xff0c;处理平台采用双DSPFPGA的结构&#xff0c;DSP采用TI公司新一代DSP TMS320C6678&#xff0c;FPGA采用Xilinx V5 5VLX110T-1FF1136芯片&#xff…

《QT实用小工具·五十》动态增删数据与平滑缩放移动的折线图

1、概述 源码放在文章末尾 该项目实现了带动画、带交互的折线图&#xff0c;包含如下特点&#xff1a; 动态增删数值 自适应显示坐标轴数值 鼠标悬浮显示十字对准线 鼠标靠近点自动贴附 支持直线与平滑曲线效果 自定义点的显示类型与大小 自适应点的数值显示位置 根据指定锚点…