Conditional DETR解读---带anchor的DETR

DETR存在的问题

1.收敛速度慢

2.对小目标物体检测效果不好,因为transformer计算量大,受限于计算规模,CNN提取特征时只采取了最后一层特征,没有用FPN等结构。所以对于小目标检测效果不好。

论文主要观点

  • 通过对DETRdecoder中的attentionmap进行可视化,发现query查询到的区域都是物体的extremity末端区域。所以论文认为attention尝试找到物体的边界区域。

  • 论文中认为DETRtransofmer结构中的信息主要可以分为两部分,一部分是与图像的特征(颜色纹理等)相关的信息,称为content,比如encoder或decoder的输出信息。另一部分是代表空间上的信息,称为spatial,比如position embedding等。

  • detr中的CNN与encoder只涉及图像特征向量提取;decoder中的self-attn只涉及query之间的交互去重;所以收敛慢的最可能原因发生在cross attn

  • Cross attention中的K包含encoder输出信息(content key Ck)与position embedding(spatial Key Pk),Q包含self attention的输出(content query Cq)和object query(spatial query Pq)信息。论文中发现去掉cross attention中的object基本不掉点,所以收敛慢很可能是content query难学习导致的。

  • 提出了reference point的概念,为每个query设定一个检测范围,使得匹配更加稳定,加快了收敛

  • 原始detr混合两者学习,使得content query难学习。所以将content与spatial进行解耦

在这里插入图片描述

变为

在这里插入图片描述

网络结构

在这里插入图片描述

对于object query生成了一个2D坐标embedding(上图中的s),用于限定当前query的预测范围。最终decoder的输出的是相对与s的偏移量

bbox回归输出

在这里插入图片描述

其中f是decoer的输出,S表示x,y的坐标。最终b是[x,y,w,h]的向量。

classifier分类输出

在这里插入图片描述

f是decoder的输出,输出每个候选框的类别

decoder Pq生成:

提出了reference point的概念,即图中的s,是一个2d的坐标(q_num,B,2),由object queries经过一个线性层生成,代表了每个query的预测范围。

s经过sigmoid和position embedding后(图中的Ps),跟FFN(decoder embedding)(即图中的T)做内积。得到空间特征Pq

在这里插入图片描述

在这里插入图片描述

代码spatial query这一部分的实现:

# query_pos [num_query,batch,d_model]
# reference_points_before_sigmoid [num_query,batch,2]  从query预测一个坐标,代表了这个query预测的大概范围
reference_points_before_sigmoid = self.ref_point_head(query_pos)    # [num_queries, batch_size, 2]
reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
for layer_id, layer in enumerate(self.layers):# 图里的s,代表了query的预测大概范围obj_center = reference_points[..., :2].transpose(0, 1)      # [num_queries, batch_size, 2]# For the first decoder layer, we do not apply transformation over p_s## pos_transformation代表图里的T,表示decoder embedding的特征经过ffn后其实得到的是相对于s的偏移量if layer_id == 0:pos_transformation = 1else:pos_transformation = self.query_scale(output)# get sine embedding for the query vectorquery_sine_embed = gen_sineembed_for_position(obj_center)     # apply transformation# 最终的Pq,代表空间特征信息query_sine_embed = query_sine_embed * pos_transformationoutput = layer(output, memory, tgt_mask=tgt_mask,memory_mask=memory_mask,tgt_key_padding_mask=tgt_key_padding_mask,memory_key_padding_mask=memory_key_padding_mask,pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,is_first=(layer_id == 0))

decoder中cross attention的实现


# ========== Begin of Cross-Attention =============
# Apply projections here
# shape: num_queries x batch_size x 256
q_content = self.ca_qcontent_proj(tgt)
k_content = self.ca_kcontent_proj(memory)
v = self.ca_v_proj(memory)num_queries, bs, n_model = q_content.shape
hw, _, _ = k_content.shape# k的位置编码
k_pos = self.ca_kpos_proj(pos)# For the first decoder layer, we concatenate the positional embedding predicted from 
# the object query (the positional embedding) into the original query (key) in DETR.
if is_first:q_pos = self.ca_qpos_proj(query_pos)q = q_content + q_posk = k_content + k_pos
else:q = q_contentk = k_contentq = q.view(num_queries, bs, self.nhead, n_model//self.nhead)
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead)
# decoder embedding cat spatial query
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
k = k.view(hw, bs, self.nhead, n_model//self.nhead)
# encoder embdeding cat position embedding
k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead)
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)tgt2 = self.cross_attn(query=q,key=k,value=v, attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0]               
# ========== End of Cross-Attention =============

head的实现

# hs代表decoder embedding,reference代表s(reference point)
hs, reference = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])
reference_before_sigmoid = inverse_sigmoid(reference)
outputs_coords = []
for lvl in range(hs.shape[0]):# 回归head hs输出相对于 reference的偏移量,得到检测框tmp = self.bbox_embed(hs[lvl])tmp[..., :2] += reference_before_sigmoidoutputs_coord = tmp.sigmoid()outputs_coords.append(outputs_coord)
outputs_coord = torch.stack(outputs_coords)
#分类head,hs输出分类结果
outputs_class = self.class_embed(hs)

总结思考

实际上conditional DETR有点像transfoermer版本的faster-RCNN。将特征信息与空间信息进行了解耦。reference point像anchor的概念,让网络自己为每个query设定一个anchor范围,从而使得二分匹配更加问题,所以加快了网络的收敛

作者论文解读:https://zhuanlan.zhihu.com/p/401916664
公式解释得更加详细

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/845223.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

芋道系统,springboot+vue3+mysql实现地址的存储与显示

1.效果图 2.前端实现&#xff1a; <el-form-item label"地址" prop"entrepriseAddress"><el-cascaderv-model"formData.entrepriseAddress"size"large":options"region"/></el-form-item> //导入组件 im…

【JMeter接口自动化】第7讲 Jmeter三个重要组件

线程组:是JMeter中最基本的元素之一&#xff0c;用于模拟并发用户访问目标系统。线程组定义了测试计划中的用户数量、用户行为和用户请求之间的关系。 添加方法:测试计划->添加->线程(用户)->线程组 在线程组中&#xff0c;您可以设置以下参数&#xff1a; 线程数&a…

一种改进的形态学滤波算法-以心电信号的基线校正和噪声抑制为例(MATLAB环境)

信号在釆集和传输过程中难免受到噪声源的干扰&#xff0c;反映非线性动力学行为的特征信息有可能被噪声所掩盖。尤其是在混沌振动信号噪声抑制方面&#xff0c;因为混沌信号的高度非线性及宽频特性&#xff0c;噪声和混沌信号往往具有重叠的带宽。传统的时域及频域降噪方法效果…

神经网络-------人工神经网络

一、什么是神经网络和神经元 人工神经网络&#xff08;英语&#xff1a;Artificial Neural Network&#xff0c;ANN&#xff09;&#xff0c;简称 神经网络&#xff08;Neural Network&#xff0c;NN&#xff09;或 类神经网络&#xff0c;是一种模仿生物神经网络&#xff08;…

AI实时免费在线图片工具3:人物换脸、图像编辑

1、FaceAdapter 人物换脸 https://huggingface.co/spaces/FaceAdapter/FaceAdapter 2、InstaDrag https://github.com/magic-research/InstaDrag

计算机网络之快重传和快恢复以及TCP连接与释放的握手

快重传和快恢复 快重传可以让发送方尽早得知丢失消息&#xff0c; 当发送消息M1,M2&#xff0c;M3,M4,M5后,假如消息M2丢失&#xff0c;那么按照算法会发送对M2报文前一个报文M1的重复确认&#xff08;M1正常接受到&#xff0c;已经发送了确认),然后之后收到M4,M5,也会发送两…

做项目时,怎么运用 SWOT 分析法进行项目或决策分析?

SWOT分析法是一种常用的战略工具&#xff0c;用于评估项目或决策的优势、劣势、机会和威胁。以下是在项目或决策分析中如何运用SWOT分析法的一般步骤&#xff1a; 步骤1&#xff1a;明确分析的目标 在进行SWOT分析之前&#xff0c;首先要明确分析的目标是什么。你可能想要分析…

element-plus关于表单数据自定义参数校验

element-plus关于表单数据自定义参数校验 核心点&#xff1a; 代码&#xff1a; <el-form-item :prop"tableData[ scope.$index ].score":rules"[{ required: true, message: 得分不能为空, trigger: blur },{ validator: (rule: any, value: any, ca…

【Python】解决Python报错:AttributeError: ‘generator‘ object has no attribute ‘xxx‘

&#x1f9d1; 博主简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…

网络分层与各层网络协议介绍

一.OSI七层模型 1.OSI&#xff08;Open Systems Interconnection&#xff09;七层模型是由国际标准化组织&#xff08;ISO&#xff09;提出的一种网络通信协议的参考模型&#xff0c;用于标准化网络通信的过程。 OSI模型将网络通信分为七个层次&#xff0c;每个层次负责不同的…

Java集合-List(Collection子接口)及其子类(ArrayList、Vector、LinkedList)

List接口是 Collection接口的子接口。 1、List集合类中数据有序&#xff0c; 即添加顺序和取出顺序有序&#xff0c;而且可以重复。 2、List集合类中每个元素都有其对应的顺序索引&#xff0c;即支持索引。例&#xff0c;list.get(2)&#xff1b;取第三个元素。 3、实现类有很多…

家政预约小程序10公众号集成

目录 1 使用测试号3 工作流配置4 配置关注事件脚本5 注册开放平台6 获取公众号access_token6 实现关注业务逻辑总结 我们本次实战项目构建的相当于一个预约平台&#xff0c;既有家政企业&#xff0c;也有家政服务人员还有用户。不同的人员需要收到不同的消息&#xff0c;比如用…

99.网络游戏逆向分析与漏洞攻防-ui界面的设计-角色信息显示的界面与功能

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 如果看不懂、不知道现在做的什么&#xff0c;那就跟着做完看效果&#xff0c;代码看不懂是正常的&#xff0c;只要会抄就行&#xff0c;抄着抄着就能懂了 内容…

机器人学导论P115求雅可比矩阵python实现

代码如下&#xff1a; import numpy as np import matplotlib.pyplot as plt import seaborn as sns import plotly.express as px import plotly.graph_objects as go from plotly.subplots import make_subplots from numpy import sin from numpy import cos plt.rcParams[…

【考试100】安全员B证《建设工程安全生产技术》单选题

​ 题库来源&#xff1a;考试100 【考试100】安全员B证《建设工程安全生产技术》单选题 1&#xff0e;在悬空部位作业时&#xff0c;操作人员应&#xff08; &#xff09; A.遵守操作规定 B.进行安全技术交底 C.戴好安全帽 D.系好安全带 【考试100答案】&#xff1a;D…

【R基础】如何开始学习R-从下载R及Rstudio开始

文章目录 概要下载R流程下载Rstudio流程下载完成-打开 概要 提示&#xff1a;如何开始学习R-从下载R及Rstudio开始&#xff0c;此处我只是想下载指定版本R4.3.3 下载R流程 链接: R官网 文件下载到本地 下载文件展示 按照向导指示安装 下载Rstudio流程 链接: Rstudio官网…

低代码与人工智能的深度融合:行业应用的广泛前景

引言 在当今快速变化的数字化时代&#xff0c;企业面临着越来越多的挑战和机遇。低代码平台和人工智能技术的兴起&#xff0c;为企业提供了新的解决方案&#xff0c;加速了应用开发和智能化转型的步伐。 低代码平台的基本概念及发展背景 低代码平台是一种软件开发方法&#x…

解决MYSQL5.7版本only_full_group_by报错解决方法

问题 出现this is incompatible with sql_modeonly_full_group_by这个语句就说明启动了only_full_group_by规则了 介绍only_full_group_by规则&#xff1a; 这种情况可能是5.7版本的规则比较严格&#xff0c;当启用“only_full_group_by”模式时&#xff0c;MySQL会对执行GROU…

SpringBoot中MyBatisPlus的使用

MyBatis Plus 是 MyBatis 的增强工具&#xff0c;提供了许多强大的功能&#xff0c;简化了 MyBatis 的使用。下面是在 Spring Boot 中使用 MyBatis Plus 的步骤&#xff1a; 添加依赖&#xff1a;在 Maven 或 Gradle 的配置文件中添加 MyBatis Plus 的依赖。 配置数据源&#…

Day10:平面转换、渐变色

目标&#xff1a;使用位移、缩放、旋转、渐变效果丰富网页元素的呈现方式。 一、平面转换 1、简介 作用&#xff1a;为元素添加动态效果&#xff0c;一般与过渡配合使用。 概念&#xff1a;改变盒子在平面内的形态&#xff08;位移、旋转、缩放、倾斜&#xff09;。 平面转换…