修改YOLOv5的模型结构第二弹

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

上节说到了通过修改YOLOv5的common.py来修改模型的结构,修改的是模块的内部结构,具体某些模块组织顺序等,如插入一个新的模型,则需要在yolo.py文件中修改

yolo.py文件

yolo.py中有几个主要的组成部分

parse_model函数

主要负责读取传入的–cfg配置指定的模型配置文件。例如经常使用的yolov5s.yaml,通过配置文件创建实际的模块对象,并将模块拼接起来。

Detect类

主要用来构建Detect层,将输入的feature map通过一个卷积操作和公式计算得到想要的shape,为后面计算损失或者NMS做准备

Model类

这个类实现的是整个模型的搭建。YOLOv5的作者在其中还加入了很多功能,例如:特征可视化、打印模型信息、TTA推理增强、融合Conv+Bn加速推理、模型搭载NMS功能、autoshape函数等等。

修改模型

任务

参考C3模块,创建一个C2模块,并插入到模型的第二、三层之间
C2模型
总模型结构
可以发现C2模块就是上一篇文章中,对模型C3模块的修改结果。

步骤

由于要插入一个新的模块,首先就是要在commm.py里仿造C3模块,新增一个可以创建C2模块的方法

class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
    
class C2(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_1 = c2 // 2
        c_2 = c2 - c1  # hidden channels
        self.cv1 = Conv(c1, c_2, 1, 1)
        self.cv2 = Conv(c1, c_1, 1, 1)
        self.m = nn.Sequential(*(Bottleneck(c_2, c_2, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)

然后修改yolo.py中构建模型的部分,增加C2模块

def parse_model(d, ch):  # model_dict, input_channels(3)
    # Parse a YOLOv5 model.yaml dictionary
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
                Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                BottleneckCSP, C3, C2, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3,C2, C3TR, C3Ghost, C3x}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)

最后修改模型的配置文件,将yolov5s.yaml另存为yolov5s_aug.yaml并修改

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 3, C2, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

如此改造就完成了,可以使用训练集训练一下当前的yolov5s_aug模型
训练的结果如下:
修改后的模型

对比没有修改前的模型训练结果

修改前模型
对比发现增加这层模块没有对整个模型带来好的结果,每种分类的检测结果都变得更差了,上篇文章也分析过,C2模块由于缺少了最后的卷积重排,会导致特征的表达变差。

根据经验,cat操作后面最好是经过全连接层或者卷积再提取一下特征,直接使用的效果比较差。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/158023.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【图解算法】- 异位词问题:双指针+哈希表

一 - 前言 介绍&#xff1a;大家好啊&#xff0c;我是hitzaki辰。 社区&#xff1a;&#xff08;完全免费、欢迎加入&#xff09;日常打卡、学习交流、资源共享的知识星球。 自媒体&#xff1a;我会在b站/抖音更新视频讲解 或 一些纯技术外的分享&#xff0c;账号同名&#xff…

代码随想录算法训练营第23天|669. 修剪二叉搜索树 108.将有序数组转换为二叉搜索树 538.把二叉搜索树转换为累加树

JAVA代码编写 669. 修剪二叉搜索树 给你二叉搜索树的根节点 root &#xff0c;同时给定最小边界low 和最大边界 high。通过修剪二叉搜索树&#xff0c;使得所有节点的值在[low, high]中。修剪树 不应该 改变保留在树中的元素的相对结构 (即&#xff0c;如果没有被移除&#x…

企业是否需要单独一套设备管理系统?

在现代企业中&#xff0c;设备管理是一个至关重要的环节。随着科技的不断进步和信息化的发展&#xff0c;企业对设备管理的要求也越来越高。为了提高设备管理的效率和准确性&#xff0c;许多企业开始考虑是否需要单独一套设备管理系统。本文将从设备管理系统的介绍、和其他系统…

腾讯云服务器怎么买便宜?腾讯云服务器优惠链接

现在&#xff0c;让我们一起探索如何在腾讯云服务器上购买便宜的云服务器吧&#xff01; 首先&#xff0c;我们来看看都有哪些便宜的腾讯云服务器值得我们入手吧&#xff01; 首先是轻量2核2G3M服务器&#xff0c;只需要一年88元就能轻松拥有&#xff0c;对于刚开始接触云服务…

linux查看pcie速率

先通过lspci命令查找pcie对应设备编号&#xff0c;如下图中为01:00.0 再通过以下命令查找上一步编号对应设备带宽信息&#xff0c;如下图中为8GT/s lspci -n -s 01:00.0 -vvv | grep --color Width

简单解决网页的验证码

翻到一个网站,展开需要验证码,而验证码需要关注微信公众号,懒得弄,所以有了这篇文章 首先,先看一下F12中的网络(Network),发现并没有使用网络动态验证 那么这个验证码必定是写在资源文件中的 在确定按钮上看到如下元素监听(Event Listeners) 进入打断点 成功断下 单步跟到…

数据库常见面试题

存储引擎 InnoDB&#xff08;默认&#xff09; 存储引擎的对比 MYISAM被MangoDB替代了 MEMORY被Redis替代了 索引 是一种高效获取数据的数据结构 索引结构 二叉树&#xff0c;红黑树&#xff08;都不合适&#xff09; B树 插入超过5个数&#xff0c;会从中间分裂 B树 …

苍穹外卖项目笔记(2)

1 Nginx 反向代理和负载均衡 1.1 概念 【Tips】可以看到前端请求地址和后端接口地址并不匹配&#xff0c;这里涉及到 nginx 反向代理 &#xff0c;就是将前端发送的动态请求由 nginx 转发到后端服务器 使用 nginx 作反向代理的好处&#xff1a; 提高访问速度&#xff08;在请…

时间序列预测中的4大类8种异常值检测方法(从根源上提高预测精度)

一、本文介绍 本文给大家带来的是时间序列预测中异常值检测&#xff0c;在我们的数据当中有一些异常值&#xff08;Outliers&#xff09;是指在数据集中与其他数据点显著不同的数据点。它们可能是一些极端值&#xff0c;与数据集中的大多数数据呈现明显的差异。异常值可能由于…

Linux系统下安装go

目录 下载go安装包解压包并安装添加环境变量验证是否安装成功 下载go安装包 官网地址&#xff1a;go 解压包并安装 复制好包的下载链接后使用下面命令进行安装&#xff1a; curl -O https://storage.googleapis.com/golang/go1.11.1.linux-amd64.tar.gz mkdir -p ~/installe…

结构体数组保存进二进制文件的简单做法

作者&#xff1a;朱金灿 来源&#xff1a;clever101的专栏 为什么大多数人学不会人工智能编程&#xff1f;>>> 最近面临这样一个需求&#xff1a;以比较节省存储空间的存储一组坐标点到文件&#xff0c;要求程序能够跨平台读写这种文件。思考了一下&#xff0c;比较…

Kafka学习笔记(一)

目录 第1章 Kafka概述1.1 消息队列&#xff08;Message Queue&#xff09;1.1.1 传统消息队列的应用场景1.1.2 消息队列的两种模式 1.2 定义 第2章 Kafka快速入门2.1 安装部署2.1.1 集群规划2.1.2 jar包下载2.1.3 集群部署 2.2 Kafka命令行操作 第3章 Kafka架构深入3.1 Kafka工…

23111704[含文档+PPT+源码等]计算机毕业设计springboot办公管理系统oa人力人事办公

文章目录 **软件开发环境及开发工具&#xff1a;****功能介绍&#xff1a;****实现&#xff1a;****代码片段&#xff1a;** 编程技术交流、源码分享、模板分享、网课教程 &#x1f427;裙&#xff1a;776871563 软件开发环境及开发工具&#xff1a; 前端技术&#xff1a;jsc…

实例解释遇到前端报错时如何排查问题

前端页面报错&#xff1a; 1、页面报错500&#xff0c;首先我们可以知道是服务端的问题&#xff0c;需要去看下服务端的报错信息&#xff1a; 2、首先我们查看下前端是否给后端传了id: 我们可以看到接口是把ID返回了&#xff0c;就需要再看下p_id是什么情况了。 3、我们再次请…

【C++】多线程的学习笔记(3)——白话文版(bushi

前言 好久没有继续写博客了&#xff0c;原因就是去沉淀了一下偷懒了一下 现在在学网络编程&#xff0c;c的多线程也还在学 这一变博客就讲讲c中的Condition Variable库吧 Condition Variable的简介 官方原文解释 翻译就是 条件变量是一个对象&#xff0c;它能够阻止调用…

腾讯云服务器秒杀什么时候开始?腾讯云服务器秒杀时间

腾讯云服务器秒杀什么时候开始呢&#xff1f;我们一起来揭晓答案&#xff01; 腾讯云服务器秒杀活动即日起至2023-11-30 23:59:59&#xff0c;每日0点限量秒杀。这意味着&#xff0c;每一天的开始&#xff0c;你都有机会抢到心仪的服务器。秒杀活动入口&#xff1a;https://te…

面试题-3

1.说一下原型链 原型就是一个普通对象,它是为构造函数实例共享属性和方法&#xff0c;所有实例中引用原型都是同一个对象 使用prototype可以把方法挂载在原型上&#xff0c;内存值保存一致 _proto_可以理解为指针,实例对象中的属性,指向了构造函数的原型(prototype) 2.new操…

image图片之间的间隙消除

多个图片排列展示&#xff0c;水平和垂直方向的间隔如何消除 垂直方向 vertical-align 原因&#xff1a; vertical-align属性主要用于改变行内元素的对齐方式&#xff0c;行内元素默认垂直对齐方式是基线对齐&#xff08;baseline&#xff09; 这是因为图片属于行内元素&…

双极性集成电路芯片 D7312,可用于小型收录机中作前置放大电路。电源开关冲击噪音小、 反应快

一块双极性集成电路芯片 D7312。可用于小型收录机中作前置放大电路。 主要特点&#xff1a; ● 含ALC电路和ALC检波电路。 ● 外接元件少。 ● 增益高&#xff0c;噪声低。 ● 静态电流小 ● 电源开关冲击噪音小、 反应快 ● 具有过热保护功能 …

爬取全国高校数据 (高校名称,高校所在地,高校类型,高校性质,高校特色,高校隶属,学校网站)

爬取全国高校数据 网站&#xff1a; 运行下面代码得到网站. import base64 # 解码 website base64.b64decode(IGh0dHA6Ly9jb2xsZWdlLmdhb2thby5jb20vc2NobGlzdC8.encode(utf-8)) print(website)分析&#xff1a; 我们需要爬取的字段&#xff0c;高校名称&#xff0c;高校所…
最新文章