【代码实现】DETR原文解读及代码实现细节

1 模型总览

在这里插入图片描述

宏观上来说,DETR主要包含三部分:以卷积神经网络为主的骨干网(CNN Backbone)、以TRM(Transformer)为主的特征抽取及交互器以及以FFN为主的分类和回归头,如DETR中build()函数所示。DETR最出彩的地方在于,它摒弃了非端到端的处理过程,如NMS、anchor generation等,以集合预测的方式来端到端建模目标检测过程,并且将Transformer引入到目标检测中,打开新领域的大门)。

def build(args):backbone = build_backbone(args)transformer = build_transformer(args)model = DETR(backbone,# 骨干网transformer,# 重点部分num_classes=81,num_queries=100,# object query数量,作用相当于spatial embeddingaux_loss=args.aux_loss)matcher = build_matcher(args)# 二分图匹配weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}weight_dict['loss_giou'] = args.giou_loss_coeflosses = ['labels', 'boxes', 'cardinality']postprocessors = {'bbox': PostProcess()}return model, criterion, postprocessors

2 DETR的基本流程

  • CNN backbone 提取图像的 feature
  • Transformer Encoder 通过 self-attention 建模全局关系对 feature 进行增强
  • Transformer Decoder 的输入是 object queries(spatial embedding) 和 Transformer encoder 的输出(content embedding),主要包含 self-attention 和 cross-attention 的过程。Self-attention 主要是对每个 query 之间做交互,让每个 query 能看到其他 query 在查询什么东西,从而不重复,类似与 NMS 的作用;Cross- attention 主要是将 object query 当做查询,encoder feature 当做 key,为了查询和 query 有关的区域。
  • 对 Decoder 输出的查询好了的 query,使用 FFN 提取出目标框的位置和类别信息

顺着上边的基本流程,从代码入手一点点理解原文的思想,下面开始!

3 backbone

首先是构建backbone模块的函数

def build_backbone(args):position_embedding = build_position_encoding(args)backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)model = Joiner(backbone, position_embedding)model.num_channels = backbone.num_channelsreturn model

3.1 根据特征图生成位置编码

def build_position_encoding(args):N_steps = args.hidden_dim // 2if args.position_embedding in ('v2', 'sine'):# TODO find a better way of exposing other argumentsposition_embedding = PositionEmbeddingSine(N_steps, normalize=True)elif args.position_embedding in ('v3', 'learned'):position_embedding = PositionEmbeddingLearned(N_steps)else:raise ValueError(f"not supported {args.position_embedding}")return position_embedding

对于输入特征x,假定其尺寸为BxCxHxW,位置编码需要在H W两个维度上进行位置编码,所以一般会将hidden_dim(C, 通道)切分为两部分,一部分代表H另一部分代表W,最后在通道维度上进行拼接。DETR中位置编码主要是sine和learning position embedding。我自己仿照TRM写了一个可学习位置编码的实现,可以运行试试

# 验证learnable pos embedding机制
x = torch.randn((8, 3, 32, 32))
h,w=x.shape[-2:]
row_embed,col_embed=nn.Embedding(50,256),nn.Embedding(50,256)i,j=torch.arange(w,device=x.device),torch.arange(h,device=x.device)
x_emb,y_emb=col_embed(i),row_embed(j)
x_cat=x_emb.unsqueeze(0).repeat(h,1,1)
y_cat=y_emb.unsqueeze(1).repeat(1,w,1)pos=torch.cat([x_cat,y_cat],dim=-1)
pos_learn=pos.permute(2,0,1).unsqueeze(0).repeat(x.shape[0],1,1,1)# shape:(8,512,32,32)

构建backbone

这部分不太难,相关注释已经写在代码块中

class BackboneBase(nn.Module):def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):super().__init__()for name, parameter in backbone.named_parameters():if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:parameter.requires_grad_(False)if return_interm_layers:# 是否返回中间层,在多尺度融合操作时会用到return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}else:# 一般情况下只返回最后一层的输出return_layers = {'layer4': "0"}# IntermediateLayerGetter作用类似于Sequential,将多个神经层组合并可以指定返回中间输出self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)self.num_channels = num_channelsdef forward(self, tensor_list: NestedTensor):xs = self.body(tensor_list.tensors)# 将数据传入网络,实例化网络得到输出,xs即为经过resnet四部分后的输出out: Dict[str, NestedTensor] = {}# 定义输出格式for name, x in xs.items():# 如果返回中间层,out可以按照name存储,返回最后一层则只有layer4m = tensor_list.maskassert m is not Nonemask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]# 利用插值法产生不同尺度下的maskout[name] = NestedTensor(x, mask)return out

将上述二者对应组合

class Joiner(nn.Sequential):def __init__(self, backbone, position_embedding):super().__init__(backbone, position_embedding)def forward(self, tensor_list: NestedTensor):xs = self[0](tensor_list)# [0]代表backbone的输出out: List[NestedTensor] = []pos = []for name, x in xs.items():out.append(x)# position encodingpos.append(self[1](x).to(x.tensors.dtype))# [1]代表position_embeddingreturn out, pos# 返回抽取后的特征及对应的位置编码

Transformer

在这里插入图片描述
TRM部分其实跟Attention is all you need的模型结构完全相同,不同的部分只是decoder部分输入。在原始transformer中,decoder的输入是对应目标序列融合位置编码后的embedding,而在本文中,则使用初始化为全0的tgt作为目标序列,然后再融合query_embed。这里非常容易混淆的一点是:tgt全零序列才是content embedding, 代码中的query_embed是代表目标框集合位置的spatial embedding
由于encoder部分主要是特征提取,对边界定位影响不大,所以我们考虑decoder的cross-attention部分,其输入主要包括三部分:query key value

  • queries:每个 query 都是 decoder 第一层 self-attention 的输出( content query )+ object query( spatial query ),这里的 object query 就是 DETR 中提出的概念,每个 object query 都是候选框的信息,经过 FFN 后能输出位置和类别信息(本文 object query 个数 N 为 100)
  • keys:每个 key 都是 encoder 的输出特征( content key ) + 位置编码( spatial key )构成
  • values:只有来自 encoder 的输出
class Transformer(nn.Module):def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False,return_intermediate_dec=False):super().__init__()encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)encoder_norm = nn.LayerNorm(d_model) if normalize_before else Noneself.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)decoder_norm = nn.LayerNorm(d_model)self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,return_intermediate=return_intermediate_dec)self.d_model = d_modelself.nhead = nheaddef forward(self, src, mask, query_embed, pos_embed):# flatten NxCxHxW to HWxNxCbs, c, h, w = src.shapesrc = src.flatten(2).permute(2, 0, 1)# flatten(k)表示将[k:n-1]拉平为一个维度pos_embed = pos_embed.flatten(2).permute(2, 0, 1)# sine# num_queries x hidden_dim to num_queries x N x hidden_dimquery_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)# NxHxW to NxHWmask = mask.flatten(1)# decoder embedding,初始化为全0tgt = torch.zeros_like(query_embed)# encoder特征抽取,得到memory,shape同tgtmemory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)# decoder特征交互,得到hshs = self.decoder(tgt, memory, memory_key_padding_mask=mask,pos=pos_embed, query_pos=query_embed)return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

DETR

class DETR(nn.Module):""" This is the DETR module that performs object detection """def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):super().__init__()self.num_queries = num_queries# object query numsself.transformer = transformerhidden_dim = transformer.d_model# 隐层维度self.class_embed = nn.Linear(hidden_dim, num_classes + 1)# 分类头,最后的类别为:类别数+背景self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)# 检测头,使用三层全连接层进行映射,最后投影到xywhself.query_embed = nn.Embedding(num_queries, hidden_dim)# object queryself.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)# 将得到的feature map通道归一self.backbone = backboneself.aux_loss = aux_lossdef forward(self, samples: NestedTensor):if isinstance(samples, (list, torch.Tensor)):samples = nested_tensor_from_tensor_list(samples)features, pos = self.backbone(samples)# 得到的feature可能是C3-C5几层,DETR只拿最后一层输入TRMsrc, mask = features[-1].decompose()assert mask is not None# self.transformer()[0]表示取dncoder的输出,序列1表示encoder输出hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]outputs_class = self.class_embed(hs)outputs_coord = self.bbox_embed(hs).sigmoid()out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}if self.aux_loss:out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/67891.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[羊城杯 2020] easyphp

打开题目&#xff0c;源代码 <?php$files scandir(./); foreach($files as $file) {if(is_file($file)){if ($file ! "index.php") {unlink($file);}}}if(!isset($_GET[content]) || !isset($_GET[filename])) {highlight_file(__FILE__);die();}$content $_GE…

Spring 6.0和SpringBoot 3.0新特性

目录 主要更新内容是以下几个&#xff1a; AOT编译 Spring Native GraalVM SpringBoot3生成二进制可执行文件底层流程 主要更新内容是以下几个&#xff1a; A Java 17 baselineSupport for Jakarta EE 10 with an EE 9 baselineSupport for generating native images with…

【Sentinel】核心API-Entry与Context

文章目录 一、Entry1、Entry的声明2、使用API自定义资源3、基于SentinelResource注解标记资源 二、Context1、Context介绍2、Context的初始化3、AbstractSentinelInterceptor4、ContextUtil 一、Entry 1、Entry的声明 默认情况下&#xff0c;Sentinel会将controller中的方法作…

46、TCP的“三次握手”

在上一节中&#xff0c;TCP首部常用的几个选项&#xff0c;有些选项的参数就是在通信双方在建立TCP连接的时候进行确定和协商的。所以在学习过TCP报文首部之后&#xff0c;下面我们开始学习TCP的连接建立。 TCP的一个特点是提供可靠的传输机制&#xff0c;还有一个特点就是TCP…

Spring MVC 五 - DispatcherServlet初始化过程(续)

今天的内容是SpringMVC的初始化过程&#xff0c;其实也就是DispatcherServilet的初始化过程。 Special Bean Types DispatcherServlet委托如下一些特殊的bean来处理请求、并渲染正确的返回。这些特殊的bean是Spring MVC框架管理的bean、按照Spring框架的约定处理相关请求&…

leetcode56. 合并区间(java)

合并区间 题目描述贪心算法代码演示 题目描述 难度 - 中等 leetcode56. 合并区间 以数组 intervals 表示若干个区间的集合&#xff0c;其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间&#xff0c;并返回 一个不重叠的区间数组&#xff0c;该数组需恰好…

Elasticsearch 对比传统数据库:深入挖掘 Elasticsearch 的优势

当你为项目选择数据库或搜索引擎时&#xff0c;了解每个选项的细微差别至关重要。 今天&#xff0c;我们将深入探讨 Elasticsearch 的优势&#xff0c;并探讨它与传统 SQL 和 NoSQL 数据库的比较。 1. Elasticsearch简介 Elasticsearch 以强大的 Apache Lucene 库为基础&#…

算法通关村第9关【白银】| 二分查找与搜索树高频问题

基于二分查找的拓展问题 1.山脉数组的峰顶索引 思路&#xff1a;二分查找 山峰有三种状态&#xff1a;需要注意数组边界 1.顶峰&#xff1a;arr[mid]>arr[mid1]&&arr[mid]>arr[mid-1] 2.上坡&#xff1a;arr[mid]<arr[mid1] 3.下坡&#xff1a;arr[mid]…

l8-d6 socket套接字及TCP的实现框架

一、socket套接字 /*创建套接字*/ int socket(int domain, int type, int protocol); /*绑定通信结构体*/ int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen); /*监听套接字*/ int listen(int sockfd, int backlog); /*处理客户端发起的连接&#xff0…

智能合约安全,著名的区块链漏洞:双花攻击

智能合约安全&#xff0c;著名的区块链漏洞&#xff1a;双花攻击 介绍: 区块链技术通过提供去中心化和透明的系统彻底改变了各个行业。但是&#xff0c;与任何技术一样&#xff0c;它也不能免受漏洞的影响。一个值得注意的漏洞是双花攻击。在本文中&#xff0c;我们将深入研究…

【数据结构练习】栈的面试题集锦

目录 前言&#xff1a; 1.进栈过程中可以出栈的选择题 2.将递归转化为循环 3.逆波兰表达式求值 4.有效的括号 5. 栈的压入、弹出序列 6. 最小栈 前言&#xff1a; 数据结构想要学的好&#xff0c;刷题少不了&#xff0c;我们不仅要多刷题&#xff0c;还要刷好题&#x…

16-MyCat

一 Mycat概述 1 什么是Mycat 什么是Mycat Mycat是数据库中间件&#xff0c;所谓数据库中间件是连接Java应用程序和数据库中间的软件。 为什么要用Mycat 遇到问题&#xff1a; Java与数据库的紧耦合高访问量高并发对数据库的压力读写请求数据不一致 2 Mycat与其他中间件区别 目…

【USRP】调制解调系列6:16APSK、32APSK 、基于labview的实现

APSK APSK是&#xff0c;与传统方型星座QAM&#xff08;如16QAM、64QAM&#xff09;相比&#xff0c;其分布呈中心向外沿半径发散&#xff0c;所以又名星型QAM。与QAM相比&#xff0c;APSK便于实现变速率调制&#xff0c;因而很适合目前根据信道及业务需要分级传输的情况。当然…

分布式环境下的数据同步

一般而言elasticsearch负责搜索&#xff08;查询&#xff09;&#xff0c;而sql数据负责记录&#xff08;增删改&#xff09;&#xff0c;elasticsearch中的数据来自于sql数据库&#xff0c;因此sql数据发生改变时&#xff0c;elasticsearch也必须跟着改变&#xff0c;这个就是…

jmeter调试错误大全

一、前言 在使用jmeter做接口测试的过程中大家是不是经常会遇到很多问题&#xff0c;但是无从下手&#xff0c;不知道从哪里开始找起&#xff0c;对于初学者而言这是一个非常头痛的事情。这里结合笔者的经验&#xff0c;总结出以下方法。 二、通过查看运行日志调试问题 写好…

STM32存储左右互搏 I2C总线读写FRAM MB85RC16

STM32存储左右互搏 I2C总线读写FRAM MB85RC16 在较低容量存储领域&#xff0c;除了EEPROM的使用&#xff0c;还有铁电存储器FRAM的使用&#xff0c;相对于EEPROM, 同样是非易失性存储单元&#xff0c;FRAM支持更高的访问速度&#xff0c; 其主要优点为没有EEPROM持续写操作跨页…

Python虚拟环境venv下安装playwright介绍及记录

playwright介绍 Playwright是一个用于自动化Web浏览器测试和Web数据抓取的开源库。它由Microsoft开发&#xff0c;支持Chrome、Firefox、Safari、Edge和WebKit浏览器。Playwright的一个主要特点是它能够在所有主要的操作系统&#xff08;包括Windows、Linux和macOS&#xff09…

计算机毕设 大数据商城人流数据分析与可视化 - python 大数据分析

文章目录 0 前言课题背景分析方法与过程初步分析&#xff1a;总体流程&#xff1a;1.数据探索分析2.数据预处理3.构建模型 总结 最后 0 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设题目缺少创新和亮点&#xff0c;往往达不到…

SMU200A/罗德与施瓦茨SMU200A信号发生器

181/2461/8938产品概述 R&S SMU200A信号发生器旨在满足现代通信系统研发及其生产中遇到的所有要求。R&S SMU200A矢量信号发生器不仅将多达两个独立的信号发生器组合在一个只有四个高度单位的机柜中&#xff0c;还提供无与伦比的RF和基带特性。 Rohde & Schwarz S…

Vue3数值动画(NumberAnimation)

效果如下图&#xff1a;在线预览 APIs 参数说明类型默认值必传from数值动画起始数值number0falseto数值目标值number1000falseduration数值动画持续时间&#xff0c;单位msnumber3000falseautoplay是否自动开始动画booleantruefalseprecision精度&#xff0c;保留小数点后几位…