[杂记]mmdetection3.x中的数据流与基本流程详解(数据集读取, 数据增强, 训练)


之前跑了一下mmdetection 3.x自带的一些算法, 但是具体的代码细节总是看了就忘, 所以想做一些笔记, 方便初学者参考. 其实比较不能忍的是, 官网的文档还是空的…

在这里插入图片描述

这次想写其中的数据流是如何运作的, 包括从读取数据集的样本与真值, 到数据增强, 再到模型的forward当中.


0. MMDetection整体组成部分

让我们首先回顾一下C++的标准模板库(STL)是怎样设计的. STL的三个核心组件是容器, 算法与迭代器. 容器, 例如vector, queue等等, 他们是负责存储数据的, 算法是负责进行一些操作, 例如排序, 查找等等. 而迭代器是容器与算法之间的桥梁, 也就是算法可以通过迭代器去访问容器, 使得算法可以独立于容器的类型进行操作. 三个部分相辅相成, 就达到了泛型编程的理念.

再让我们回顾一下一套深度学习的代码包含什么部分. 从大的方面来说, 需要有数据的读取与增强(DataLoader), 模型的定义, 损失函数的计算, 负责梯度传播的优化器, 在验证(测试)集上的评估等. 同理, MMDetection也是按照这种方式来的, 并且每个部分接口相通, 就可以实现更广义的模型定义和训练方式.

mmengine/registry/__init__.py中, 我们可以看到, MMEngine(或者说MMDetection)总体有这些类型的模块:

from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, HOOKS,
                   INFERENCERS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS,
                   MODELS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                   OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
                   TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
                   WEIGHT_INITIALIZERS)

那么以上这么多模块可以分成几类, 分别负责什么呢? 按照我个人的理解, MMDetection的整体组成部分可以表示为下图:

在这里插入图片描述

为了节省空间, 优化器相关并未画出

1. 认识config文件

mmdetection设计的核心思想是通过字典来配置整个的训练过程和模型定义, 这些字典放在一个.py的config文件中. 一般来说,config文件最重要的就是数据加载(train_dataloader, val_dataloader和test_dataloader), 模型定义(model)和训练与测试过程(train_pipeline, test_pipeline). 除此之外, 还有一些训练, 测试配置(train_cfg, test_cfg)等等. 具体config的例子可以参照官网Learn about configs.

需要注意的是, mmdetection中字典定义class的方式, 往往是键type表示类的名字, 之后的其他键都是类初始化需要的参数. 例如, 如果我想自定义一个模型, 叫做MyModel, 定义在当前目录下的./models/my_model.py中, 定义方式如下:


from mmdet.registry import MODELS  # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmdet.models.mot.base import BaseMOTModel  # 一个模型基类

@MODELS.register_module()  # 装饰器 在模型库中"注册"
class MyModel(BaseMOTModel):
    def __init__(self, 
                 arg1=..., 
                 arg2=..., 
                 arg3=...):
        ...
	def loss(self, inputs, data_samples):  # 前向传播, inputs是输入tensor, data_samples是包含标签的列表
		...

如果按上述方式定义了模型, 那么在我们的配置文件中, 就是这个样子:


# 必须将自定义类的py文件导入 这样可以自动register自定义模型 否则模型初始化时找不到

custom_imports = dict(
    imports=['models.my_model'],
    allow_failed_imports=False)

# 现在就可以愉快的传参了
models=dict(
    type='MyModel', 
    arg1=1, 
    arg2=[16, 128], 
    arg3=dict(channel=256), 
    ...
)

同样, 我们可以自定义DataLoader, Loss, 等等.

此外, dict是可以嵌套的, 例如mmdetection将检测模型分成了backbone, neck和head三部分, 那么如果我们又自定义了一个Head, 叫MyHead:


from mmdet.registry import MODELS  # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmengine.model import BaseModule  # 一个模型基类

@MODELS.register_module()  # 装饰器 在模型库中"注册"
class MyHead(BaseModule):
    def __init__(self, 
                 arg4=...):
        ...

这样, 如果MyModel的前向传播过程中需要一个head, 则代码大致是这个样子:


from mmdet.registry import MODELS  # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmdet.models.mot.base import BaseMOTModel  # 一个模型基类

@MODELS.register_module()  # 装饰器 在模型库中"注册"
class MyModel(BaseMOTModel):
    def __init__(self, 
                 arg1=..., 
                 arg2=..., 
                 arg3=...,
                 head=...):
        self.head = MODELS.build(head)  # 建立Head的模型, 类型是nn.Module
        ...
	def loss(self, inputs, data_samples):  # 前向传播, inputs是输入tensor, data_samples是包含标签的列表
		...  # 一些其他过程
		ret = self.head(inputs)  # forward
		...  # 后处理

配置文件中对应更改为:

如果按上述方式定义了模型, 那么在我们的配置文件中, 就是这个样子:


custom_imports = dict(
    imports=['models.my_model', '自定义HEAD所在的py文件'],
    allow_failed_imports=False)

models=dict(
    type='MyModel', 
    arg1=1, 
    arg2=[16, 128], 
    arg3=dict(channel=256), 
    head=dict(  # 定义head
        type='MyHead',
        arg4=256,
        ...
	)
    ...
)

篇幅所限, 自定义损失函数, 数据增强之类的就不一一列举了.

2. 数据流

我们接下来以检测与跟踪任务为例, 看看数据到底是如何被读入的. 我们以训练过程说明.

在训练过程中, 我们会初始化一个RUNNER类, 其读入我们的config文件并依次完成各种(模型, 数据加载, 优化器, 钩子等等)的初始化. 我们以官方提供的train.py为例:

runner = Runner.from_cfg(cfg)

from_cfg()是一个类方法(classmethod), 在其中我们实例化了Runner类.

随后, 我们调用Runnertrain()方法进行训练. 首先, 我们实例化训练循环:

        self._train_loop = self.build_train_loop(
            self._train_loop)  # type: ignore

训练循环就属于LOOP类型.

在这里, 我们以最常用的EpochBasedTrainLoop为例. 在EpochBasedTrainLoop的初始化函数中, 根据config文件中的train_dataloader字典实例化出torchDataLoader类():
在这里插入图片描述

        data_loader = DataLoader(
            dataset=dataset,
            sampler=sampler if batch_sampler is None else None,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            worker_init_fn=init_fn,
            **dataloader_cfg)
        return data_loader

当然, 我们知道torch的DataLoader类在调用的时候, 会调用到dataset(类别是torch.utils.data.Dataset)的__getitem__方法. 因此, 我们从__getitem__入手来探索数据流.

在MMDetection的设计中, 数据集的类都是继承于MMengine中的BaseDataset, 其中的__getitem__是这样写的:
在这里插入图片描述

    def __getitem__(self, idx: int) -> dict:
        if not self._fully_initialized:
            print_log(
                'Please call `full_init()` method manually to accelerate '
                'the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        if self.test_mode:
            data = self.prepare_data(idx)
            if data is None:
                raise Exception('Test time pipline should not get `None` '
                                'data_sample')
            return data

        for _ in range(self.max_refetch + 1):
            data = self.prepare_data(idx)
            # Broken images or random augmentations may cause the returned data
            # to be None
            if data is None:
                idx = self._rand_another()
                continue
            return data

        raise Exception(f'Cannot find valid image after {self.max_refetch}! '
                        'Please check your image path and pipeline')

我们可以看到, 在__getitem__中最核心的是self.prepare_data(idx). 按照这种思路一级一级向上查找, 我们就可以总结出如下图的数据读取流程:

在这里插入图片描述
其中, 数据增强pipeline是一系列类型为TRANSFORMS类的列表, 再每经过一次数据增强时, 字典都会被更新.

我们以较为常用的随机便宜(RandomShift)来说, 其是这样定义的:


@TRANSFORMS.register_module()
class RandomShift(BaseTransform):

    def __init__(self,
        ...

    @autocast_box_type()
    def transform(self, results: dict) -> dict:  # transform方法, 更新字典, 图像与对应的边界框等都需要被更新
        """Transform function to random shift images, bounding boxes.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Shift results.
        """
        if self._random_prob() < self.prob:
            img_shape = results['img'].shape[:2]

            random_shift_x = random.randint(-self.max_shift_px,
                                            self.max_shift_px)
            random_shift_y = random.randint(-self.max_shift_px,
                                            self.max_shift_px)
            new_x = max(0, random_shift_x)
            ori_x = max(0, -random_shift_x)
            new_y = max(0, random_shift_y)
            ori_y = max(0, -random_shift_y)

            # TODO: support mask and semantic segmentation maps.
            bboxes = results['gt_bboxes'].clone()
            bboxes.translate_([random_shift_x, random_shift_y])

            # clip border
            bboxes.clip_(img_shape)

            # remove invalid bboxes
            valid_inds = (bboxes.widths > self.filter_thr_px).numpy() & (
                bboxes.heights > self.filter_thr_px).numpy()
            # If the shift does not contain any gt-bbox area, skip this
            # image.
            if not valid_inds.any():
                return results
            bboxes = bboxes[valid_inds]
            results['gt_bboxes'] = bboxes
            results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
                valid_inds]

            if results.get('gt_ignore_flags', None) is not None:
                results['gt_ignore_flags'] = \
                    results['gt_ignore_flags'][valid_inds]

            # shift img
            img = results['img']
            new_img = np.zeros_like(img)
            img_h, img_w = img.shape[:2]
            new_h = img_h - np.abs(random_shift_y)
            new_w = img_w - np.abs(random_shift_x)
            new_img[new_y:new_y + new_h, new_x:new_x + new_w] \
                = img[ori_y:ori_y + new_h, ori_x:ori_x + new_w]
            results['img'] = new_img

        return results

需要注意的是, 经过pipeline后, 字典最终会被更新成如下形式:

dict = {'inputs': torch.Tensor, 
        'data_samples': DetDataSample或TrackDataSample等}

其中'inputs'键对应的值就是转换为tensor的图片, 而'data_samples'键对应的值是表示样本的类, 在检测任务中, 是DetDataSample, 跟踪任务中, 是TrackDataSample. DetDataSample类有许多成员, 包括该样本(图片)的目标的边界框真值, 分割真值等:

在这里插入图片描述

class DetDataSample(BaseDataElement):
    """A data structure interface of MMDetection. They are used as interfaces
    between different components.

    The attributes in ``DetDataSample`` are divided into several parts:

        - ``proposals``(InstanceData): Region proposals used in two-stage
            detectors.
        - ``gt_instances``(InstanceData): Ground truth of instance annotations.
        - ``pred_instances``(InstanceData): Instances of detection predictions.
        - ``pred_track_instances``(InstanceData): Instances of tracking
            predictions.
        - ``ignored_instances``(InstanceData): Instances to be ignored during
            training/testing.
        - ``gt_panoptic_seg``(PixelData): Ground truth of panoptic
            segmentation.
        - ``pred_panoptic_seg``(PixelData): Prediction of panoptic
           segmentation.
        - ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation.
        - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.

以上过程可以借用MMEngine文档里的一个图说明:

在这里插入图片描述

最终, 模型的forward, loss, predict等方法都是接收inputs: torch.Tensordata_samples作为输入, 例如:

在这里插入图片描述

def loss(self, inputs: Tensor, data_samples: TrackSampleList,
             **kwargs) -> Union[dict, tuple]:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/396188.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

打字侠,提供免费的五笔打字练习

在当今数字化时代&#xff0c;打字已成为生活和工作中不可或缺的技能之一。特别是在办公室环境中&#xff0c;快速准确地输入文字对提高工作效率至关重要。而对于许多中文输入法用户来说&#xff0c;五笔输入法因其高效和便捷而备受青睐。 然而&#xff0c;掌握五笔输入法并非…

JVM原理

一、java虚拟机的生命周期&#xff1a; Java虚拟机的生命周期 一个运行中的Java虚拟机有着一个清晰的任务&#xff1a;执行Java程序。程序开始执行时他才运行&#xff0c;程序结束时他就停止。你在同一台机器上运行三个程序&#xff0c;就会有三个运行中的Java虚拟机。 Java虚拟…

一休哥助手网页版如何使用

一休哥助手网页版可以使用GPT4提问了&#xff0c;具体操作流程如下&#xff1a; 1.登录网页版一休哥助手&#xff08;首次打开页面时&#xff0c;初始化久一点&#xff0c;请耐心等一下&#xff09; https://www.fudai.fun 2.登录后就可以使用GPT4了 3.你还可以自定义系统角色…

vtkBoarderWidget及图片坐标包含计算

开发环境&#xff1a; Windows 11 家庭中文版Microsoft Visual Studio Community 2019VTK-9.3.0.rc0vtk-example demo解决问题&#xff1a;移动图片到坐标轴的中心&#xff0c;创建一个vtkBoarderWidget控件&#xff0c;移动控件&#xff0c;计算控件与图片的包含关系 关键点…

K3s v1.26.0-rc.0-k3s1 部署Harbor私库权限配置

在K3s服务端配置 cat >> /etc/rancher/k3s/registries.yaml <<EOF mirrors: "harbor.baize-k3s.org": endpoint: - "https://harbor.baize-k3s.org" configs: "harbor.baize-k3s.org": auth: username: admin password: Harbor1…

LiveGBS流媒体平台GB/T28181常见问题-基础配置流媒体服务配置中本地|内网IP外网IP(可选)外网IP收流如何配置

LiveGBS常见问题基础配置流媒体服务配置中本地|内网IP外网IP外网IP收流如何配置&#xff1f; 1、流媒体服务配置2、播放提示none rtp data receive3、多网卡服务器4、收流端口配置5、端口区间可以如何配置6、搭建GB28181视频直播平台 1、流媒体服务配置 LiveGBS中基础配置-》流…

ssm在线学习平台-计算机毕业设计源码09650

目 录 摘要 1 绪论 1.1 选题背景及意义 1.2国内外现状分析 1.3论文结构与章节安排 2 在线学习平台系统分析 2.1 可行性分析 2.2 系统业务流程分析 2.3 系统功能分析 2.3.1 功能性分析 2.3.2 非功能性分析 2.4 系统用例分析 2.5本章小结 3 在线学习平台总体设计 …

HCIA-HarmonyOS设备开发认证V2.0-IOT硬件子系统-I2C

目录 一、 I2C 概述二、I2C 模块相关API三、接口调用实例四、I2C HDF驱动开发4.1、开发步骤(待续...) 坚持就有收获 一、 I2C 概述 I2C&#xff08;Inter Integrated Circuit&#xff09;集成电路间总线是由 Philips 公司开发的一种简单、双向二线制同步串行总线。I2C 以主从方…

Unity老项目Android 13支持

Unity老项目Android 13支持 前言 Google官方要求新、老app在一定时间要求内需要面向Android 12、Android 13构建&#xff0c;不然不给app过审。我们之前是面向Android API 30构建的&#xff0c;现在需要支持面向Android API 33构建。 https://developer.android.com/about/ver…

为什么2023年是AI视频的突破年,以及对2024年的预期#a16z

2023年所暴露的AI生成视频的各种问题&#xff0c;大部分被OpenAI发布的Sora解决了吗&#xff1f;以下为a16z发布的总结&#xff0c;在关键之处&#xff0c;我做了OpenAI Sora的对照备注。 推荐阅读&#xff0c;了解视频生成技术进展。 Why 2023 Was AI Video’s Breakout Year,…

怎么清理mac系统缓存系统垃圾文件 ?怎么清理mac系统DNS缓存

很多使用苹果电脑的用户都喜欢在同时运行多个软件&#xff0c;不过这样会导致在运行一些大型软件的时候出现不必要的卡顿现象&#xff0c;这时候我们就可以去清理下内存&#xff0c;不过很多人可能并不知道正确的清内存方式&#xff0c;下面就和小编一起来看看吧。 mac系统是一…

力扣94 二叉树的中序遍历 (Java版本) 递归、非递归

文章目录 题目描述递归解法非递归解法 题目描述 给定一个二叉树的根节点 root &#xff0c;返回 它的 中序 遍历 。 示例 1&#xff1a; 输入&#xff1a;root [1,null,2,3] 输出&#xff1a;[1,3,2] 示例 2&#xff1a; 输入&#xff1a;root [] 输出&#xff1a;[] 示…

chrome版本117驱动下载路,解决版本不匹配问题

&#x1f525; 交流讨论&#xff1a;欢迎加入我们一起学习&#xff01; &#x1f525; 资源分享&#xff1a;耗时200小时精选的「软件测试」资料包 &#x1f525; 教程推荐&#xff1a;火遍全网的《软件测试》教程 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1…

伦敦金适合现在进行投资吗?

伦敦金作为一种贵金属投资品种&#xff0c;近年来在全球范围内受到了越来越多的关注。那么&#xff0c;伦敦金适合现在进行投资吗&#xff1f;在回答这个问题之前&#xff0c;我们先来了解一下什么是伦敦金。 伦敦金&#xff0c;顾名思义&#xff0c;是指在伦敦市场上交易的黄…

小白都能看懂的力扣算法详解——哈希表(一)

&#xff01;&#xff01;本篇所选题目及解题思路均来自​​​​​​代码随想录 (programmercarl.com) 一 LC242.有效的字母异位词 题目要求&#xff1a; 给定两个字符串 s 和 t &#xff0c;编写一个函数来判断 t 是否是 s 的字母异位词。 注意&#xff1a;若 s 和 t 中每个字…

阿里云服务器镜像是什么?如何选择镜像?

阿里云服务器镜像怎么选择&#xff1f;云服务器操作系统镜像分为Linux和Windows两大类&#xff0c;Linux可以选择Alibaba Cloud Linux&#xff0c;Windows可以选择Windows Server 2022数据中心版64位中文版&#xff0c;阿里云服务器网aliyunfuwuqi.com来详细说下阿里云服务器操…

社交商业策略:揭秘Facebook Shops的成功之道

随着数字化时代的不断发展&#xff0c;社交媒体已经成为了商业活动的重要平台之一。在这个趋势下&#xff0c;Facebook作为全球最大的社交媒体平台之一&#xff0c;不仅仅是人们交流互动的场所&#xff0c;更成为了商家开展电子商务的重要渠道。其中&#xff0c;Facebook Shops…

python毕设选题 - 大数据商城人流数据分析与可视化 - python 大数据分析

文章目录 0 前言课题背景分析方法与过程初步分析&#xff1a;总体流程&#xff1a;1.数据探索分析2.数据预处理3.构建模型 总结 最后 0 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设题目缺少创新和亮点&#xff0c;往往达不到…

如何在Windows 10中停止正在进行的更新?这里有你需要知道的细节

本文介绍如何取消正在进行的Windows更新。说明适用于Windows 10家庭版和专业版。 如何在Windows下载更新后取消更新 如果你还没有完全安装Windows 10更新&#xff0c;但你的电脑已经下载了该文件&#xff0c;并且关闭和重置选项已更改为“更新并关闭”和“更新并重新启动”&a…

EI级 | Matlab实现TCN-GRU-MATT、TCN-GRU、TCN、GRU多变量时间序列预测对比

EI级 | Matlab实现TCN-GRU-MATT、TCN-GRU、TCN、GRU多变量时间序列预测对比 目录 EI级 | Matlab实现TCN-GRU-MATT、TCN-GRU、TCN、GRU多变量时间序列预测对比预测效果基本介绍程序设计参考资料 预测效果 基本介绍 【EI级】Matlab实现TCN-GRU-MATT、TCN-GRU、TCN、GRU多变量时间…