DETR整体模型结构解析

DETR流程

  1. Backbone用卷积神经网络抽特征。最后通过一层1*1卷积转化到d_model维度fm(B,d_model,HW)。

  2. position embedding建立跟fm维度相同的位置编码(B,d_model,HW)。

  3. Transformer Encoder,V为fm,K,Q为fm+position embedding。因为V代表的是图像特征。所以不添加位置编码

  4. Transformer Decoder。生成一个固定大小(query_num)的object query(B,q_num,d_model)比如100个预测框。Decoder输入tgt与object query形状相同。代码中为torch.zero()。第一层selfattention K,V为tgt+query,Q为tgt。第二层Q为上一层输出+query。V为encoder输出,K为encoder输出+position。这里V仍然代表图像特征所以不添加位置编码

  5. 用输出的100个object query框和ground truth框做一个匹配,然后在一一配对好的框中去计算目标检测的loss(分类loss与回归loss(L1+IOU))

  6. 二分图匹配与匈牙利算法

    DETR 预测了一组固定大小的 N = 100 个边界框

    将 ground-truth 也扩展成 N = 100 个检测框

    使用一个额外的特殊类标签 ϕ 来表示在未检测到任何对象,或者认为是背景类别。

    这样预测和真实都是两个100 个元素的集合了

    采用匈牙利算法进行二分图匹配,对预测集合和真实集合的元素进行一一对应,使得匹配损失最小。

  7. 推理过程不需要二分图匹配,只需要取最大得分框即可

代码详细参考:

transformer 在 CV 中的应用(二) DETR 目标检测网络 -

网络结构

参数说明:B:batchsize大小,C通道数,H,W:CNN输出特征图的高宽。d_model设定的特征维度大小如512。
Q,K,V:自注意力矩阵。l_q:Q矩阵的长度,l_kv:K,V矩阵的长度。KV矩阵的长度必须相同,Q矩阵长度可以跟KV矩阵长度不同
Q矩阵维度:(B,l_q,d_model)
K矩阵维度:(B,l_kv,d_model)
V矩阵维度:(B,l_kv,d_model)
object_query维度(B,q_num,d_model)

Backbone:

img→CNNbackbone→fm特征图(B,C,H,W) → fm特征图输入到transformer中时要再经过一层卷积将通道数转化成d_ model。C→d_model.

position embedding(B,d_model,H*W)。backbone通过CNN提取图像特征,然后通过特征图生成尺度对应的位置编码。

position embedding:

位置编码官方实现了两种,一种是固定位置编码,另一种是自学习位置编码,这里就介绍固定位置编码。

位置编码要考虑 x, y 两个方向,图像中任意一个点 (h, w) 有一个位置,这个位置编码长度为 256 ,前 128 维代表 h 的位置编码, 后 128 维代表 w 的位置编码,把这两个 128 维的向量拼接起来就得到一个 256 维的向量,它代表 (h, w) 的位置编码。位置编码的计算公式如下图所示

在这里插入图片描述

Transformer
DETRtransformer结构图
在这里插入图片描述

接受CNN提取的特征(B,d_model,HW),位置编码(B,d_model,HW),querys(B,query_num,d_model)

encoder:q,k添加位置编码。v代表图像本身特征,不添加位置编码。multi_head_attention跟FFN后都带了两个残差连接。

# post代表norm放在后面
def forward_post(self,src,src_mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):q = k = self.with_pos_embed(src, pos)  #q,k增加positionsrc2 = self.self_attn(q, k, value=src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]src = src + self.dropout1(src2)   # 残差src = self.norm1(src)# ffnsrc2 = self.linear2(self.dropout(self.activation(self.linear1(src))))# 残差src = src + self.dropout2(src2)src = self.norm2(src)return src

decoder:

设定一个object queries(num_query,d_model)

有两层multihead self attention

  • 第一层obquery添加到K,Q上作为position embedding

第二层的Q来于decoder,K,V来自于encoder输出。

第二层self attention K添加编码,Q增加object queries。V代表图像特征,不添加任何信息

KV要有相同维度,Q可以跟KV在长度维度上不同,d_model维度相同

softmax(QKt/(d^0.5))V→矩阵乘法Q*Kt:(l_q,d_model)@(d_model_l_kv)→(l_q,l_kv)

再乘以V(l_q,l_kv)@(l_kv,d_model)→(l_q,d_model)

def forward_post(self, tgt, memory,tgt_mask: Optional[Tensor] = None,memory_mask: Optional[Tensor] = None,tgt_key_padding_mask: Optional[Tensor] = None,memory_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None,query_pos: Optional[Tensor] = None):''':param tgt: query_pos rep 2tensor of shape (bs, c, h, w) ->tgt = torch.zeros_like(query_embed):param memory::param tgt_mask::param memory_mask::param tgt_key_padding_mask::param memory_key_padding_mask::param pos::param query_pos::return:'''q = k = self.with_pos_embed(tgt, query_pos)tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0]tgt = tgt + self.dropout1(tgt2)tgt = self.norm1(tgt)tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),key=self.with_pos_embed(memory, pos),value=memory, attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0]tgt = tgt + self.dropout2(tgt2)tgt = self.norm2(tgt)tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))tgt = tgt + self.dropout3(tgt2)tgt = self.norm3(tgt)return tgt
  • DETR在计算attention的时候没有使用masked attention,因为将特征图展开成一维以后,所有像素都可能是互相关联的,因此没必要规定mask。

  • object queries的转换过程:object queries是预定义的目标查询的个数,代码中默认为100。它的意义是:根据Encoder编码的特征,Decoder将100个查询转化成100个目标,即最终预测这100个目标的类别和bbox位置。最终预测得到的shape应该为[N, 100, C],N为Batch Num,100个目标,C为预测的100个目标的类别数+1(背景类)以及bbox位置(4个值)

得到预测结果以后,将object predictions和ground truth box之间通过匈牙利算法进行二分匹配:假如有K个目标,那么100个object predictions中就会有K个能够匹配到这K个ground truth,其他的都会和“no object”匹配成功,使其在理论上每个object query都有唯一匹配的目标,不会存在重叠,所以DETR不需要nms进行后处理。

匹配

匈牙利匹配算法

匈牙利匹配算法,二分图匹配算法

scipy.optimize.linear_sum_assignment(cost_matrix, maximize=False)
#cost_matrix 二分图开销矩阵

https://blog.csdn.net/CV_Autobot/article/details/129096035

https://blog.csdn.net/lemonxiaoxiao/article/details/108672039

query与gt匹配

transformer通过query输出n_q数量的bbox与对应分类置信度

真实框[gt1,gt2,…gtn]

每个bbox与gt之间有一个距离度量。

距离度量由三部分组成:真实类别的置信度得分+边界框的L1loss+边界框的IOUloss

通过匈牙利算法找出距离最小的query_bbox为gt对应的prebbox

loss训练

整体流程

pred输出→100(num_queries)class,100(num_queries)boxes

gt(tagert)→100class,100boxes(包含背景类)

pred,gt→计算相互loss,得到二分图成本矩阵,然后计算匈牙利匹配算法→return匹配上的classes与boxes

匹配成功的框→计算真正的class损失(),box回归损失(GLOUloss)。。

预测框与真实框的差异来自于两方面:1.二分图匹配时带来的差异。2.预测框与真实框之间的差异。

  • 分类损失:交叉熵损失,针对所有predictions。没有匹配到的querybbox应该分类为背景

  • 回归损失:bbox loss采用了L1 loss和giou loss,针对匹配成功的querybbox

  • cardinality 损失,对应函数是 loss_cardinality ; cardinality 损失是计算预测有物体的个数的绝对损失,值是为了记录,不参与反向传播

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/18805.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

非量表题如何进行信效度分析

效度是指设计的题确实在测量某个东西,一般问卷中使用到。如果是量表类的数据,其一般是用因子分析这种方法去验证效度水平,其可通过因子分析探究各测量量表的内部结构情况,分析因子分析得到的内部结构与自己预期的内部结构进行对比…

大模型预训练结果到底是什么?

近日参加一个线下 AI 交流会议,会上有个非本行业的老师提问:“大家说的训练好的大模型到底是什么?是像 Word 软件一样可以直接使用的程序吗?” 这个问题看似简单,却一下把我问住了。的确,我们这些身处 AI 领…

Kafka原生API使用Java代码-生产者-发送消息

文章目录 1、生产者发送消息1.1、使用EFAK创建主题my_topic31.2、根据kafka官网文档写代码1.3、pom.xml1.4、KafkaProducer1.java1.5、使用EFAK查看主题1.6、再次运行KafkaProducer1.java1.7、再次使用EFAK查看主题 1、生产者发送消息 1.1、使用EFAK创建主题my_topic3 1.2、根…

STM32 OTA需要注意问题

一、OTA设计思路(问题) 1、根据stm32f405 flash分布,最初将flash划分为四个区域,分别是Bootloader、APP1、APP2、参数区,设备上电后,进入Bootloader程序,判断OTA参数,根据参数来确定…

APP逆向之调试的开启

很基础的一个功能设置,大佬轻喷。 背景 在开始进行对APP逆向分析的时候,需要对APP打开调试模式。 打开调试的模式有多种方式可以通过直接改包方式也可以通过借助第三方工具进行打开调试模式。 下面就整理下这个打开调试模式的一些方式。 改包修改模…

Java面试题分享-敏感词替换 java 版本

入职啦最近更新了一些后端笔试、面试题目,大家看看能快速实现吗? 关注 入职啦 微信公众号,每日更新有用的知识,Python,Java,Golang,Rust,javascript 等语言都有 不要再用replaceAll做…

DNF手游攻略:开荒必备攻略!

DNF手游马上就要开服了,今天给大家带来最完整的DNF手游入门教程。这篇攻略主要讲述了 DNF手游开服第一天要注意的事项,这是一个新手必备的技能书,可以让你在开服的时候,少走一些弯路,让你更快完成任务!废话…

蓝牙Mesh模块多跳大数据量高带宽传输数据方法

随着物联网技术的飞速发展,越来越多的设备需要实现互联互通。蓝牙Mesh网络作为一种低功耗、高覆盖、易于部署的无线通信技术,已经成为物联网领域中的关键技术之一。在蓝牙Mesh网络中,节点之间可以通过多个跳数进行通信,从而实现大…

【OrangePi AIpro】香橙派 AIpro 为AI而生

产品简介 OrangePi AIpro(8T):定义边缘智能新纪元的全能开发板 在当今人工智能与物联网技术融合发展的浪潮中,OrangePi AIpro(8T)凭借其强大的硬件配置与全面的接口设计,正逐步成为开发者手中的创新利器。这款开发板不仅代表了香橙派与华为…

最新淘宝死店全自动采集私信筛选脚本,号称日赚500+【采集软件+使用教程】

原理: 利用脚本自动采集长时间未登录店铺,然后脚本自动私信对应的店铺,看看商家是不是不回消息来判断是否是死店,再下单购买死店的产品,超过48小时不发货就可以联系客服获得赔付,一单利润百分之5%-30%&…

配置阿里yum源

配置阿里yum源(这个很重要):https://developer.aliyun.com/article/1480470 1.备份系统自带yum源配置文件 mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backup2.下载ailiyun的yum源配置文件 2.1 CentOS7 wge…

Ansible03-Ansible Playbook剧本详解

目录 写在前面5. Ansible Playbook 剧本5.1 YAML语法5.1.1 语法规定5.1.2 示例5.1.3 YAML数据类型 5.2 Playbook组件5.3 Playbook 案例5.3.1 Playbook语句5.3.2 Playbook1 分发hosts文件5.3.3 Playbook2 分发软件包,安装软件包,启动服务5.3.3.1 任务拆解…

5.28.1 使用卷积神经网络检测乳腺癌

深度学习技术正在彻底改变医学图像分析领域,因此在本研究中,我们提出了卷积神经网络 (CNN) 用于乳腺肿块检测,以最大限度地减少手动分析的开销。CNN 架构专为特征提取阶段而设计,并采用了更快的 R-CNN 的区域提议网络 (RPN) 和感兴…

py黑帽子学习笔记_scapy

简介 代码简洁:相比于前两个博客总结,很多socket操作,如果使用scapy仅需几行代码即可实现 获取邮箱身份凭证 编写基础嗅探器,脚本可显示任何收到的一个包的详细情况 直接运行 尝试监听邮件收发,监听指定端口&#x…

NTP服务的DDoS攻击:原理和防御

NTP协议作为一种关键的互联网基础设施组件,旨在确保全球网络设备间的时钟同步,对于维护数据一致性和安全性至关重要。然而,其设计上的某些特性也为恶意行为者提供了发动大规模分布式拒绝服务(DDoS)攻击的机会。以下是NTP服务DDoS攻击及其防御…

【深度学习实战—9】:基于MediaPipe的坐姿检测

✨博客主页:王乐予🎈 ✨年轻人要:Living for the moment(活在当下)!💪 🏆推荐专栏:【图像处理】【千锤百炼Python】【深度学习】【排序算法】 目录 😺一、Med…

5个免费下载音乐的网站,喜欢听什么就搜什么

以下5个音乐下载网站,中国人不骗中国人,全部免费。个个曲库丰富,喜欢听什么就搜什么,还能下载mp3格式,点赞收藏即刻拥有! 1、MyFreeMP3 tools.liumingye.cn/music/ MyFreeMP3是一个提供音乐播放和下载服…

富凡行是什么软件,来具体聊一聊它的详情,感兴趣的不要错过了

目前做网络项目的人很多,也就衍生出了很多的软件、项目、平台。接触过了很多的产品,感触颇深,确实市面上的东西差别都很大,有好的,有不好的。 我也是喜欢在网上做点副业,自己捣鼓一下,毕竟互联网…

2024-5-29 石群电路-17

2024-5-29,星期三,17:26,天气:晴,心情:晴.今天又是阳光明媚的一天,没有什么特别的事情发生,给女朋友做了好吃的,吃了西瓜,加油学习,嘻嘻嘻~~~~ 今…

四川易点慧电商抖音小店信誉之店

在当下这个电商飞速发展的时代,如何在众多网店中挑选出一家既可靠又值得信赖的店铺,成为了消费者们关注的焦点。四川易点慧电子商务有限公司抖音小店以其卓越的品质和诚信的经营,逐渐在抖音平台上崭露头角,成为了众多消费者心中的…