目标检测7-DETR算法剖析与实现

文章目录

  • 端到端目标检测框架DETR
    • 背景介绍
    • 模型结构
    • 模块解析
      • 数据
      • 模型结构
    • 动手实现`DETR`


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


端到端目标检测框架DETR

背景介绍

DETRFacebook AINicolas Carion等于202005月提交的论文中提出的。

论文地址: https://arxiv.org/abs/2005.12872
开源代码: https://github.com/facebookresearch/detr

DETR(DEtection TRansformer)将目标检测问题看成是集合预测的问题,所谓集合预测set prediction是指一次输出一张图像中的所有待检测对象。

DETR使用transformer来做目标检测,直接预测检测框到检测框中心点归一化的距离。在模型训练时,Proposal Assignment使用的算法是一对一的匈牙利算法,通过query的方式获取最后的输出。以上介绍的策略,使得DETR实现了目标检测算法的端到端训练,不需要使用NMS和先验anchor

模型结构

从上面这个图可以看到DETR的架构相当简单,输入一张图像,直接输出的就是所有的检测框,不需要复杂的编解码,不需要NMS

模块解析

数据

官方源码中数据定义在CocoDetection类中,这个类继承自torchvision.datasets.CocoDetection只需要传入COCO格式数据集的图像和json标注文件即可,

COCO格式数据集文件夹路径:

.
├── annotations
│   ├── train.json
│   └── val.json
└── images
    ├── train
    └── val

其中,标签文件bounding box的格式为:

left top width height

CoCoDetection类中有一个self.prepare属性,这是一个函数,其中会将ltwh格式的检测框变换成x1y1x2y2格式的检测框。

DETR源码中使用的变换函数不是从torchvision中导入的,而是自定义的,可以看到在Normalize中,不仅处理了图像数据,还将检测框从x1y1x2y2格式变换成了cxcywh格式,并相对于图像的宽高进行了归一化,其值变换到了[0,1]

class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes
        return image, target

模型结构

DETR的模型结构其实很简单,先是将图像输入到几层卷积神经网络中得到特征图feature map,然后使用src = src.flatten(2).permute(2, 0, 1)将特征图WH维度拉平将图像变换成长度为L=W*H的序列数据。

根据序列的长度和每个Token的通道数生成位置编码。

feature map生成的序列和位置编码信息相加作为transformer的输入src

除了输入的特征序列之外,还输入了图像数据的掩码src_mask。原因是因为一个batch输入的图像宽高不一定相同,源码中的处理方式是取一个batch中尺寸最大的图像尺寸,其余图像往右下方向补0,最后变成尺寸一致的图像用于计算。这是为了避免padding-0参与计算,需要将src_mask输入到transformer中。

DETR使用的位置编码是针对图像的带mask的二维位置编码

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

其在x/y单个方向上使用位置编码的方法同标准的transformer,然后再将x,y上的两个位置编码分别进行了合并。

DETR源码中使用的transformertorch.nn.Transformer也不太一样

DETRtransformer中将位置编码信息输入到编码器和解码器的每一层,在encoder中将pos加在输入的feature上组成qk

class Encoder:
    ...
    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

decoder中将pos加在了encoder的输出memory作为k的值,query_postgt相加的值作为q来计算多头注意力:

class Decoder:
    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

DETR中实现的transformer中还将每层decoder输出都保存下来以计算检测框,用来辅助训练


class DETRTransformerDecoder():

    ...
    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)   

transformer输出的特征输入到计算评分和检测框的两支多层感知积网络中就能预测检测框了:

class DETR:
    ...
    def forward(self, x):
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] # shape: [BATCH, NUM_QUERY, D_MODEL]

        outputs_class = self.class_embed(x)
        outputs_coord = self.box_embed(x).sigmoid()

以上就是模型的整体结构。

模型输出的num_query个预测框和真值框之间的匹配通过匈牙利算法来实现。匈牙利算法会实现预测框和真值框的一对一匹配,避免了对同个对象生成重复的检测框。在使用anchor的检测算法中,为了减轻候选框中正样本和负样本不平衡的问题,通常会使用多个proposal box来预测一个对象,以提升算法的召回率,代价是预测推理时也会对一个对象生成多个预测框,需要使用NMS算法进行处理。

标签匹配使用的代价包括三部分,分别是分类代价,检测框回归相关的L1距离和GIoU

import torch

class HungarianMatcher(torch.nn.Module):
    ...
    @torch.no_grad()
    def forward(self, outputs, targets):
        ...
        cost_class = -out_prob[:, tgt_ids]
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        cost_giou = -giou(cxcywh2x1y1x2y2(out_bbox), cxcywh2x1y1x2y2(tgt_bbox))
        all_cost = self.cost_class * cost_class + \
                    self.cost_bbox * cost_bbox + \
                self.cost_giou * cost_giou

最后是模型训练时使用的损失函数,对于目标检测任务,DETRLoss包含2部分,分别是标签类别损失和检测框回归的L1损失和GIoU损失。


loss_ce = torch.nn.functional.cross_entropy(pred_logits.transpose(1, 2),target_classes_all, self.empty_weight)

loss_bbox = torch.nn.functional.l1_loss(src_boxes, target_boxes, reduction='none')
losses = {}

losses["loss_bbox"] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(giou(cxcywh2x1y1x2y2(src_boxes),
                            cxcywh2x1y1x2y2(target_boxes)))
losses['loss_giou'] = loss_giou.sum() / num_boxes

动手实现DETR

DETR的架构如此简洁,不需要太多的trick,参考DETR源码,很容易自己动手实现DETR目标检测算法。具体的实现见:

https://gitee.com/lx_r/object_detection_task/tree/main/detection/detr

运行程序会自动生成训练数据开始训练,若平台有GPU会自动调用GPU训练,如果没有GPU会使用CPU训练。

上面的实现中,与原始代码有些许不同:

  • 1)使用的是torch.nn中的transformerpos没有加到encoder的输出memory
  • 2)torch.nn中的transformer只给出了最后一层decoder上的输出,没有给出其他层decoder上的输出,所有没有使用辅助损失训练
  • 3)输入的是相同尺寸的方形图像,没有使用输入掩码



欢迎访问个人网络日志🌹🌹知行空间🌹🌹


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/401117.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

stm32——hal库学习笔记(定时器)

这里写目录标题 一、定时器概述(了解)1.1,软件定时原理1.2,定时器定时原理1.3,STM32定时器分类1.4,STM32定时器特性表1.5,STM32基本、通用、高级定时器的功能整体区别 二、基本定时器&#xff0…

消息队列-RabbitMQ:延迟队列、rabbitmq 插件方式实现延迟队列、整合SpringBoot

十六、延迟队列 1、延迟队列概念 延时队列内部是有序的,最重要的特性就体现在它的延时属性上,延时队列中的元素是希望在指定时间到了以后或之前取出和处理,简单来说,延时队列就是用来存放需要在指定时间被处理的元素的队列。 延…

揭秘离子交换工艺:解决地下水氟超标问题的绿色高效方案

在水处理领域,面对地下水氟化物超标的挑战,各类除氟工艺如活性氧化铝吸附法、电渗析法、反渗透法等各显其能。然而,在综合考虑处理效果、运行成本及环保效益后,离子交换工艺以其独特的技术优势和可持续性脱颖而出,成为…

Python环境下基于门控双注意力机制的滚动轴承剩余使用寿命RUL预测(Tensorflow模块)

机械设备的寿命是其从开始工作持续运行直至故障出现的整个时间段,以滚动轴承为例,其寿命为开始转动直到滚动体或是内外圈等元件出现首次出现故障前。目前主流的滚动轴承RUL预测分类方法包含两种:一是基于物理模型的RUL预测方法,二…

互联网高科技公司领导AI工业化,MatrixGo加速人工智能落地

作者:吴宁川 AI(人工智能)工业化与AI工程化正在引领人工智能的大趋势。AI工程化主要从企业CIO角度,着眼于在企业生产环境中规模化落地AI应用的工程化举措;而AI工业化则从AI供应商的角度,着眼于以规模化方式…

C++面试宝典第31题:有效的数独

题目 判断一个9 x 9的数独是否有效。只需要根据以下规则,验证已经填入的数字是否有效即可。 1、数字1-9在每一行只能出现一次。 2、数字1-9在每一列只能出现一次。 3、数字1-9在每一个以粗实线分隔的3 x 3宫内只能出现一次。 下图是一个部分填充的有效的数独,数独部分空格内已…

FITC Palmitate Conjugate,FITC-棕榈酸酯缀合物,可以用标准 FITC 滤光片组进行成像

FITC Palmitate Conjugate,FITC-棕榈酸酯缀合物,可以用标准 FITC 滤光片组进行成像 您好,欢迎来到新研之家 文章关键词:FITC Palmitate Conjugate,FITC-棕榈酸酯缀合物,FITC 棕榈酸酯缀合物,F…

如何将cocos2d-x js打包部署到ios上 Mac M1系统

项目环境 cocos2d-x 3.13 xcode 12 mac m1 big sur 先找到你的项目 使用xcode软件打开上面这个文件 打开后应该是这个样子 执行编译运行就好了 可能会碰到的错误 在xcode11版本以上都会有这个错误,这是因为iOS11废弃了system。 将上面代码修改为 #if (CC_TARGE…

基于springboot+vue的高校学科竞赛系统(前后端分离)

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战,欢迎高校老师\讲师\同行交流合作 ​主要内容:毕业设计(Javaweb项目|小程序|Pyt…

详细描述一下CrossOver2024版本的用途和作用?

当然可以。CrossOver 是一款由 CODE WEAVERS 公司开发的软件,其主要目标是在 macOS 和 Linux 系统上实现与 Windows 应用程序的兼容性。它不同于传统的虚拟机,如 Parallels 或 VMware,因为它并不在 macOS 上创建一个完整的 Windows 环境。相反…

机房预约系统(个人学习笔记黑马学习)

1、机房预约系统需求 1.1系统简介 学校现有几个规格不同的机房,由于使用时经常出现“撞车“现象,现开发一套机房预约系统,解决这一问题。 1.2身份简介 分别有三种身份使用该程序 学生代表:申请使用机房教师:审核学生的预约申请管理员:给学生、教师创建账…

HarmonyOS开发技术全面分析

系统定义 HarmonyOS 是一款 “ 面向未来 ” 、面向全场景(移动办公、运动健康、社交通信、媒体娱乐等)的分布式操作系统。在传统的单设备系统能力的基础上,HarmonyOS提出了基于同一套系统能力、适配多种终端形态的分布式理念,能够…

网络安全“三保一评”深度解析

“没有网络安全就没有国家安全”。近几年,我国法律法规陆续发布实施,为承载我国国计民生的重要网络信息系统的安全提供了法律保障,正在实施的“3保1评”为我国重要网络信息系统的安全构筑了四道防线。 什么是“3保1评”? 等保、分…

LVGL8.1在Windows显示图片

1、将这些宏的值改成1,以便支持这些格式: 2、 这两个地方: LV_USE_FS_WIN32 设置符号,大小写字母、“\”、“”等符号都可以。 LV_FS_WIN32_PATH 为一个目录,图片放入此目录。 3、载入图片: “M:color.pn…

WebServer -- 定时器处理非活动连接(上)

目录 🍍函数指针 🌼基础知识 🐙整体概述 🎂基础API sigaction 结构体 sigaction() sigfillset() SIGALRM, SIGTERM 信号 alarm() socketpair() send() 📕信号通知流程 统一事件源 信号处理机制 &#x…

书生·浦语大模型实战营第二节课作业

使用 InternLM-Chat-7B 模型生成 300 字的小故事(基础作业1)。 熟悉 hugging face 下载功能,使用 huggingface_hub python 包,下载 InternLM-20B 的 config.json 文件到本地(基础作业2)。 下载过程 进阶…

【医学大模型】大模型 + 长期慢病的预测和管理

大模型 长期慢病的预测和管理 提出背景长期慢病框架慢性疾病检测框架如何实现多提示工程为什么使用多提示 慢性疾病管理框架个性化提示工程医学知识注入 提出背景 论文:https://arxiv.org/abs/2401.12988 慢性疾病是指那些需要长期管理和治疗的疾病,包…

# CCF系列会议截稿时间订阅

[晓理紫]CCF系列会议截稿时间订阅 VX关注{晓理紫}免费,每日更新最新CCF系列会议信息,如感兴趣,请转发给有需要的同学,谢谢支持!! VX关注{晓理紫}免费 NETYS (Non-CCF) The International Conference on Networked Systems Deadline: Fri Mar 8th 2024 19:59:00 CST (2…

navicat连接云服务器(宝塔)

下面介绍两种navicat连接云服务器(宝塔)的方法 一、通过ssh配置(安全) 打开navicat,配置新链接的SSH(主机:填写公网IP,用户名和密码是服务器的账号密码) 在常规填写数据…

Android Studio创建项目时gradle下载慢

先停止当前Sync,找到gradle-wrapper.properties文件,将distributionUrl修改为腾讯镜像源: distributionUrlhttps\://mirrors.cloud.tencent.com/gradle/gradle-6.5-bin.zip
最新文章