计算机视觉目标检测-DETR网络

目录

  • 摘要
  • abstract
  • DETR目标检测网络详解
    • 二分图匹配和损失函数
  • DETR总结
  • 总结

摘要

DETR(DEtection TRansformer)是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题,摒弃了锚框设计和非极大值抑制(NMS)等复杂后处理步骤。DETR使用卷积神经网络提取图像特征,并将其通过位置编码转换为输入序列,送入Transformer的Encoder-Decoder结构。Decoder通过固定数量的目标查询(Object Queries),预测类别和边界框位置。DETR创新性地引入匈牙利算法进行二分图匹配,确保预测与真实值的唯一对应关系,且采用交叉熵损失和L1-GIoU损失进行优化。在COCO数据集上的实验表明,DETR在大目标检测中表现优异,并能灵活迁移到其他任务,如全景分割。

abstract

DETR (DEtection TRansformer) is an end-to-end target detection method based on Transformer architecture proposed by Facebook AI. By modeling object detection as a set prediction problem, it eliminates complex post-processing steps such as anchor frame design and non-maximum suppression (NMS). DETR uses convolutional neural networks to extract image features and convert them via positional encoding into input sequences that feed into Transformer’s Encoder-Decoder structure. Decoder predicts categories and bounding box positions with a fixed number of Object Queries. DETR innovates by introducing the Hungarian algorithm for bipartite graph matching to ensure a unique relationship between the prediction and the true value, and optimizes with cross-entropy losses and L1-GIoU losses. Experiments on the COCO dataset show that DETR performs well in large target detection and can be flexibly migrated to other tasks, such as panoramic segmentation.

下图是目标检测中检测器模型的发展:
在这里插入图片描述

DETR目标检测网络详解

DETR(DEtection TRansformer)是由Facebook AI在2020年提出的一种基于Transformer架构的端到端目标检测方法。与传统的目标检测方法(如Faster R-CNN、YOLO等)不同,DETR直接将目标检测建模为一个集合预测问题,摆脱了锚框设计和复杂的后处理(如NMS)。结果在 COCO 数据集上效果与 Faster RCNN 相当,在大目标上效果比 Faster RCNN 好,且可以很容易地将 DETR 迁移到其他任务例如全景分割。
在这里插入图片描述
简单来说,就是通过CNN提取图像特征(通常 Backbone 的输出通道为 2048,图像高和宽都变为了 1/32),并经过input embedding+positional encoding操作转换为图像序列(如下图所说,就是类似[N, HW, C]的序列)作为transformer encoder的输入,得到了编码后的图像序列,在图像序列的帮助下,将object queries(下图中说的是固定数量的可学习的位置embeddings)转换/预测为固定数量的类别+bbox预测。相当于Transformer本质上起了一个序列转换的作用。
在这里插入图片描述
下图为DETR的详细结构:
在这里插入图片描述
DETR中的encoder-decoder与transformer中的encoder-decoder对比:

  1. spatial positional encoding:新提出的二维空间位置编码方法,该位置编码分别被加入到了encoder的self attention的QK和decoder的cross attention的K,同时object queries也被加入到了decoder的两个attention(第一个加到了QK中,第二个加入了Q)中。而原版的Transformer将位置编码加到了input和output embedding中。
  2. DETR在计算attention的时候没有使用masked attention,因为将特征图展开成一维以后,所有像素都可能是互相关联的,因此没必要规定mask。
  3. object queries的转换过程:object queries是预定义的目标查询的个数,代码中默认为100。它的意义是:根据Encoder编码的特征,Decoder将100个查询转化成100个目标,即最终预测这100个目标的类别和bbox位置。最终预测得到的shape应该为[N, 100, C],N为Batch Num,100个目标,C为预测的100个目标的类别数+1(背景类)以及bbox位置(4个值)。
  4. 得到预测结果以后,将object predictions和ground truth box之间通过匈牙利算法进行二分匹配:假如有K个目标,那么100个object predictions中就会有K个能够匹配到这K个ground truth,其他的都会和“no object”匹配成功,使其在理论上每个object query都有唯一匹配的目标,不会存在重叠,所以DETR不需要nms进行后处理。
  5. 分类loss采用的是交叉熵损失,针对所有predictions;bbox loss采用了L1 loss和giou loss,针对匹配成功的predictions。

匈牙利算法是用于解决二分图匹配的问题,即将Ground Truth的K个bbox和预测出的100个bbox作为二分图的两个集合,匈牙利算法的目标就是找到最大匹配,即在二分图中最多能找到多少条没有公共端点的边。匈牙利算法的输入就是每条边的cost 矩阵
在这里插入图片描述

二分图匹配和损失函数

思考
DETR 预测了一组固定大小的 N = 100 个边界框,这比图像中感兴趣的对象的实际数量大得多。怎么样来计算损失呢?或者说预测出来的框我们怎么知道对应哪一个 ground-truth 的框呢?

为了解决这个问题,第一步是将 ground-truth 也扩展成 N = 100 个检测框。使用了一个额外的特殊类标签 ϕ \phiϕ 来表示在未检测到任何对象,或者认为是背景类别。这样预测和真实都是两个100 个元素的集合了。这时候采用匈牙利算法进行二分图匹配,即对预测集合和真实集合的元素进行一一对应,使得匹配损失最小。
σ ^ = arg ⁡ min ⁡ G ∈ G N ∑ i N L m a t c h ( y i , y ^ σ ( i ) ) \hat{\sigma}=\arg\min_{\mathrm{G\in G_N}}\sum_{\mathrm{i}}^{\mathrm{N}}\mathcal{L}_{\mathrm{match}}\left(\mathrm{y_i},\hat{\mathrm{y}}_{\mathrm{\sigma(i)}}\right) σ^=argGGNminiNLmatch(yi,y^σ(i))
L m a t c h ( y i , y ^ σ ( i ) ) = − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L b o x ( b i , b ^ σ ( i ) ) \mathcal{L}_{\mathrm{match}}\left(\mathrm{y_i},\hat{\mathrm{y}}_{\mathrm{\sigma(i)}}\right)=-1_{\{\mathrm{c_i}\neq\varnothing\}}\hat{\mathrm{p}}_{\mathrm{\sigma(i)}}\left(\mathrm{c_i}\right)+1_{\{\mathrm{c_i}\neq\varnothing\}}\mathcal{L}_{\mathrm{box}}\left(\mathrm{b_i},\hat{\mathrm{b}}_{\mathrm{\sigma(i)}}\right) Lmatch(yi,y^σ(i))=1{ci=}p^σ(i)(ci)+1{ci=}Lbox(bi,b^σ(i))
对于那些不是背景的,获得其对应的预测是目标类别的概率,然后用框损失减去预测类别概率。这也就是说不仅框要近,类别也要基本一致,是最好的。经过匈牙利算法之后,我们就得到了 ground truth 和预测目标框之间的一一对应关系。然后就可以计算损失函数了。

下面是利用pytorch实现DETR的代码:
位置编码部分:

class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:x.size(0), :]

用于为序列数据(如Transformer中的输入)添加位置信息。位置编码帮助模型保留序列中元素的位置信息,这是因为Transformer模型本身不具备位置信息感知能力。
使用正弦和余弦函数优点
优点:
正弦和余弦具有周期性和平滑性;
不同维度具有不同频率,编码了多尺度的位置信息。
作用:保留序列的位置信息,使模型能够感知数据的顺序。

编码可视化结果:

import matplotlib.pyplot as pltimport torch
import torch.nn as nn# 位置编码
class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:x.size(0), :]pe = PositionalEncoding(d_model=16, max_len=100)
x = torch.zeros(100, 1, 16)
encoded = pe(x).squeeze(1).detach().numpy()plt.figure(figsize=(10, 5))
plt.imshow(encoded, aspect='auto', cmap='viridis')
plt.colorbar(label='Encoding Value')
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Positional Encoding Visualization')
plt.show()

在这里插入图片描述
上图反应以下几点变化
不同维度的变化

  1. 低频维度(如 d=0,1):颜色变化缓慢,代表位置之间编码的相似性较高,捕捉全局信息。
  2. 高频维度(如 d=14,15):颜色变化迅速,代表位置之间编码差异较大,捕捉局部信息。

同一位置的编码:
值的分布(正弦和余弦的相互作用)保证了每个位置在多维空间中具有唯一性。

时间步的相对差异:
相邻位置(如第1和第2位置)在高维上的值差异较大,这为模型提供了感知时间步变化的能力。

encoder-decoder:

class Transformer(nn.Module):def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6):super().__init__()self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_encoder_layers)self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_decoder_layers)def forward(self, src, tgt, src_mask=None, tgt_mask=None):memory = self.encoder(src, mask=src_mask)output = self.decoder(tgt, memory, tgt_mask=tgt_mask)return output

DETR模型:

# DETR模型
class DETR(nn.Module):def __init__(self, num_classes, num_queries, backbone='resnet50'):super().__init__()self.num_queries = num_queries# Backboneself.backbone = models.resnet50(pretrained=True)self.conv = nn.Conv2d(2048, 256, kernel_size=1)# Transformerself.transformer = Transformer(d_model=256)self.query_embed = nn.Embedding(num_queries, 256)self.positional_encoding = PositionalEncoding(256)# Prediction headsself.class_embed = nn.Linear(256, num_classes + 1)  # +1 for no-object classself.bbox_embed = nn.Linear(256, 4)def forward(self, images):# Feature extractionfeatures = self.backbone(images)features = self.conv(features)h, w = features.shape[-2:]# Flatten and add positional encodingsrc = features.flatten(2).permute(2, 0, 1)  # (HW, N, C)src = self.positional_encoding(src)# Query embeddingquery_embed = self.query_embed.weight.unsqueeze(1).repeat(1, images.size(0), 1)  # (num_queries, N, C)# Transformerhs = self.transformer(src, query_embed)# Predictionoutputs_class = self.class_embed(hs)outputs_coord = self.bbox_embed(hs).sigmoid()  # Normalized to [0, 1]return {'pred_logits': outputs_class, 'pred_boxes': outputs_coord}

DETR总结

DETR通过Transformer实现端到端的目标检测,无需(如NMS)复杂的后处理。相比传统检测器,DETR具有简洁的架构和强大的全局建模能力,但训练时对数据和计算资源的需求较高。

总结

DETR简化了目标检测的流程,摒弃了传统检测器中繁琐的锚框设计和后处理步骤,架构更简洁,且依托于Transformer的全局建模能力,在捕捉长距离特征关系方面表现出色。相比传统方法,DETR在目标数量固定的场景下,能够更高效地处理目标检测任务。其优点包括易迁移、多任务适用性和端到端优化能力,但其劣势在于训练时间较长、计算资源消耗较大,尤其是在小目标检测和训练数据量不足的情况下效果略显不足。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/65645.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Vim Masterclass 笔记09】S06L22:Vim 核心操作训练之 —— 文本的搜索、查找与替换操作(第一部分)

文章目录 S06L22 Search, Find, and Replace - Part One1 从光标位置起,正向定位到当前行的首个字符 b2 从光标位置起,反向查找某个字符3 重复上一次字符查找操作4 定位到目标字符的前一个字符5 单字符查找与 Vim 命令的组合6 跨行查找某字符串7 Vim 的增…

springboot 默认的 mysql 驱动版本

本案例以 springboot 3.1.12 版本为例 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>3.1.12</version><relativePath/> </parent> 点击 spring-…

计算机网络(二)——物理层和数据链路层

一、物理层 1.作用 实现相信计算机节点之间比特流的透明传输&#xff0c;尽可能屏蔽具体传输介质和物理设备的差异。 2.数据传输单位 比特。 3.相关通信概念 ①信源和信宿&#xff1a;即信号的发送方和接收方。 ②数据&#xff1a;即信息的实体&#xff0c;比如图像、视频等&am…

sql server cdc漏扫数据

SQL Server的CDC指的是“变更数据捕获”&#xff08;Change Data Capture&#xff09;。这是SQL Server数据库提供的一项功能&#xff0c;能够跟踪并记录对数据库表中数据所做的更改。这些更改包括插入、更新和删除操作。CDC可以捕获这些变更的详细信息&#xff0c;并使这些信息…

AI数字人+文旅:打造数字文旅新名片

在数字化浪潮的推动下&#xff0c;人工智能技术正以前所未有的速度渗透到我们生活的每一个角落。特别是在文化和旅游领域&#xff0c;AI数字人的出现&#xff0c;不仅为传统文旅产业注入了新的活力&#xff0c;也为游客带来了全新的体验。 肇庆AI数字人——星湖 “星湖”是肇…

做一个 简单的Django 《股票自选助手》显示 用akshare 库(A股数据获取)

图&#xff1a; 股票自选助手 这是一个基于 Django 开发的 A 股自选股票信息查看系统。系统使用 akshare 库获取实时股票数据&#xff0c;支持添加、删除和更新股票信息。 功能特点 支持添加自选股票实时显示股票价格和涨跌幅一键更新所有股票数据支持删除不需要的股票使用中…

Protobuf编码规则详解

Protobuf编码规则详解 1 Message 结构1.1 tag1.1.1 字段编号(field_num)1.1.2 传输类型(wire_type) 1.2 字段顺序1.3 默认值 2 编码2.1 Varint编码2.1.1 Varint编码过程2.1.2解码过程2.1.3 存储2.1.4 小结2.2 有符号整数(sint32和sint64)编码的问题与zigzag优化 3 编码实践3.1测…

【docker】exec /entrypoint.sh: no such file or directory

dockerfile生成的image 报错内容&#xff1a; exec /entrypoint.sh: no such file or directory查看文件正常在此路径&#xff0c;但是就是报错没找到。 可能是因为sh文件的换行符使用了win的。

计算机的错误计算(二百零七)

摘要 利用两个数学大模型计算 arccot(0.125664e2)的值&#xff0c;结果保留16位有效数字。 实验表明&#xff0c;它们的输出中分别仅含有3位和1位正确数字。 例1. 计算 arccot(0.125664e2)的值&#xff0c;结果保留16位有效数字。 下面是与一个数学解题器的对话。 以上为与…

MCANet: 基于多模态字幕感知的大语言模型训练无关视频异常检测

目录 摘要01 引言02 相关工作2.1 视频异常检测2.2 基于视频的大语言模型&#xff08;VLLMs&#xff09; 03 方法论3.1 问题定义3.2 MCANet3.3 图像字幕分支3.4 音频字幕分支3.5 基于LLM的异常评分3.6 视频-文本分数优化 04 实验4.1 数据集和评估指标4.2 实现细节4.3 定性结果4.…

WMS仓库管理系统,Vue前端开发,Java后端技术源码(源码学习)

一、项目背景和建设目标 随着企业业务的不断扩展&#xff0c;仓库管理成为影响生产效率、成本控制及客户满意度的重要环节。为了提升仓库作业的透明度、准确性和效率&#xff0c;本方案旨在构建一套全面、高效、易用的仓库管理系统&#xff08;WMS&#xff09;。该系统将涵盖库…

【Uniapp-Vue3】创建自定义页面模板

大多数情况下我们都使用的是默认模板&#xff0c;但是默认模板是Vue2格式的&#xff0c;如果我们想要定义一个Vue3模板的页面就需要自定义。 一、我们先复制下面的模板代码&#xff08;可根据自身需要进行修改&#xff09;&#xff1a; <template><view class"…

【Go】:图片上添加水印的全面指南——从基础到高级特性

前言 在数字内容日益重要的今天&#xff0c;保护版权和标识来源变得关键。为图片添加水印有助于声明所有权、提升品牌认知度&#xff0c;并防止未经授权的使用。本文将介绍如何用Go语言实现图片水印&#xff0c;包括静态图片和带旋转、倾斜效果的文字水印&#xff0c;帮助您有…

springCloudGateWay使用总结

1、什么是网关 功能: ①身份认证、权限验证 ②服务器路由、负载均衡 ③请求限流 2、gateway搭建 2.1、创建一个空项目 2.2、引入依赖 2.3、加配置 3、断言工厂 4、过滤工厂 5、全局过滤器 6、跨域问题

【UE5 C++课程系列笔记】22——多线程基础——FRunnable和FRunnableThread

目录 1、FRunnable 1.1 概念 1.2 主要成员函数 &#xff08;1&#xff09;Init 函数 &#xff08;2&#xff09;Run 函数 &#xff08;3&#xff09;Stop 函数 &#xff08;4&#xff09;Exit 函数 2、FRunnableThread 2.1 概念 2.2 主要操作 &#xff08;1&#xff…

《图解HTTP》 学习日记

1.了解WEB以及网络基础 1.1使用HTTP协议访问WEB web页面显示:根据web浏览器地址栏中输入指定的URL,web浏览器从web服务端获取文件资源(resource)等信息&#xff0c;从而显示出web页面 1.2网络基础TCP/IP 通常使用的网络(包括 互联网)是在tcp/ip协议族的基础上运作的&#xf…

【Docker】docker compose 安装 Redis Stack

注&#xff1a;整理不易&#xff0c;请不要吝啬你的赞和收藏。 前文 Redis Stack 什么是&#xff1f; 简单来说&#xff0c;Redis Stack 是增强版的 Redis &#xff0c;它在传统的 Redis 数据库基础上增加了一些高级功能和模块&#xff0c;以支持更多的使用场景和需求。Redis…

kubesphere前端源码运行

一、下载源码 源码是react&#xff0c;下载地址是 GitHub - kubesphere/console at v3.3.2 然后直接用git下拉就可以了 下拉完成后差不多是这样一个目录结构&#xff0c;记得切分支到3.3.2 二、下载依赖 1、node & yurn 想要运行源码首先需要node&#xff0c;使用刚才…

蓝桥杯历届真题 #分布式队列 (Java,C++)

文章目录 题目解读[蓝桥杯 2024 省 Java B] 分布式队列题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示 思路完整代码 题目解读 题目链接 [蓝桥杯 2024 省 Java B] 分布式队列 题目描述 小蓝最近学习了一种神奇的队列&#xff1a;分布式队列。简单来说&#x…

PySide6 Qt for Python Qt Quick参考网址

Qt QML BOOK&#xff1a; 《Qt for Python》 -Building an Application https://www.qt.io/product/qt6/qml-book/ch19-python-build-app#signals-and-slots Qt for Python&#xff1a;与C版本的差异即BUG处理&#xff08;常见的DLL文件确实的问题等&#xff09; Qt for Pyt…