Transformer实战-系列教程16:DETR 源码解读3(DETR类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)

4、DETR类

位置:models/detr.py/DETR类

4.1 构造函数

class DETR(nn.Module):def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):super().__init__()self.num_queries = num_queriesself.transformer = transformerhidden_dim = transformer.d_modelself.class_embed = nn.Linear(hidden_dim, num_classes + 1)self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)self.query_embed = nn.Embedding(num_queries, hidden_dim)self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)self.backbone = backboneself.aux_loss = aux_loss
  1. DETR类继承torch nn.Module
  2. 构造函数,传入5个参数:
    • backbone:CNN骨架网络,用于特征提取
    • transformer:Transformer模型,用于处理序列数据
    • num_classes:目标类别的数量
    • num_queries:解码器初始化生成的100个向量的个数,num_queries=100
    • aux_loss:一个布尔值,指示是否使用辅助损失来帮助训练
  3. 初始化
  4. num_queries
  5. transformer
  6. hidden_dim ,Transformer中的隐藏层维度
  7. class_embed ,类别预测的输出层,这个全连接层是接Transformer的输出,类别加1是额外的无类别对象
  8. bbox_embed,一个MLP,也是接Transformer的输出,边界框的四个坐标的回归
  9. query_embed ,解码器的初始100个向量
  10. input_proj ,一个1x1的二维卷积,使得backbone的输出通道数映射到与Transformer隐藏层维度相同
  11. backbone,一个预训练的卷积神经网络,主要作用是提取图像的特征,它的输出经过input_proj 处理后作为Transformer的输入
  12. aux_loss,保存是否使用辅助损失的标志

这里包含了几个自定义函数和类:
nested_tensor_from_tensor_list函数:将不同尺寸处理的图像Tensor转换为一个嵌套Tensor
MLP类:边界框的四个坐标的回归
transformer类:构建transformer架构
backbone:用于提取图像特征的CNN

4.2 前向传播

    def forward(self, samples: NestedTensor):if isinstance(samples, (list, torch.Tensor)):samples = nested_tensor_from_tensor_list(samples)features, pos = self.backbone(samples)src, mask = features[-1].decompose()assert mask is not Nonehs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]outputs_class = self.class_embed(hs)outputs_coord = self.bbox_embed(hs).sigmoid()out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}if self.aux_loss:out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)return out    
  1. 前向传播函数,输入为samples=NestedTensor{mask={Tensor(2,771,911)},tensors={Tensor(2,3,771,911)}}
  2. 检查samples是否为列表或Tensor类型
  3. samples ,如果,使用nested_tensor_from_tensor_list函数转换为NestedTensor
  4. features, pos,图像特征图列表和对应的位置编码列表,backbone实际上一个现在的resnet
  5. src, mask,解构最后一层的特征,获取源数据和掩码,src:torch.Size([2, 2048, 21, 18]),mask torch.Size([2, 21, 18]),2是batch,2048是特征维度,后面两个是图像长宽
  6. 确保掩码不为空
  7. 将数据通过Transformer处理,获取序列输出,torch.Size([6, 2, 100, 256]),6是Transformer的堆叠层数,2是batch,100是生成100个目标预测,256是每个目标预测的维度
  8. outputs_class ,获取类别预测
  9. outputs_coord ,获取边界框坐标预测,并使用sigmoid函数将输出值限制在0到1之间
  10. out ,将类别预测结果和 边界框坐标预测结果做成一个字典
  11. 如果启用了辅助损失
  12. 通过辅助函数_set_aux_loss计算辅助损失
  13. 返回out

4.3 辅助函数_set_aux_loss()

@torch.jit.unuseddef _set_aux_loss(self, outputs_class, outputs_coord):return [{'pred_logits': a, 'pred_boxes': b}for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
  1. @torch.jit.unused:一个装饰器,指示当使用TorchScript编译模型时,该方法不应被编译。这是因为辅助损失的计算可能不兼容TorchScript的静态图特性
  2. 定义函数,接收类别预测和边界框坐标作为输入
  3. 返回一个列表,将每一个类别预测和边界框坐标都封装成一个字典,这样,训练过程中可以计算每一层的损失,从而实现辅助损失的目的

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/680119.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL】操作库 —— 库的操作 -- 详解

一、增删数据库 1、创建数据库 create database db_name; 本质就是在 /var/lib/mysql 创建一个目录。 说明: 大写的表示关键字。[ ] 是可选项。CHARACTER SET:指定数据库采用的字符集。COLLATE:指定数据库字符集的校验规则。 2、数据库删除…

Linux第51步_移植ST公司的linux内核第3步_添加修改设备树

1、设备树文件的路径 1)、创建linux中的设备树头文件 在“my_linux/linux-5.4.31/arch/arm/boot/dts/”目录中,以“stm32mp15xx-edx.dtsi”为蓝本,复制一份,并命名为 “stm32mp157d-atk.dtsi”,这就是我们开发板的设备树头文件。…

【stomp实战】Springboot+Stomp协议实现聊天功能

本示例实现一个功能,前端通过websocket发送消息给后端服务,后端服务接收到该消息时,原样将消息返回给前端。前端技术栈htmlstomp.js,后端SpringBoot 前端代码 关于stomp客户端的开发,如果不清楚的,可以看…

机器学习10-特征缩放

特征缩放的目的是确保不同特征的数值范围相近,使得模型在训练过程中更加稳定,加速模型收敛,提高模型性能。具体而言,零均值和单位方差的目标有以下几点好处: 1. 均值为零(Zero Mean)&#xff1a…

15 ABC基于状态机的按键消抖原理与状态转移图

1. 基于状态机的按键消抖 1.1 什么是按键? 从按键结构图10-1可知,按键按下时,接点(端子)与导线接通,松开时,由于弹簧的反作用力,接点(端子)与导线断开。 从…

【开源】SpringBoot框架开发天沐瑜伽馆管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 瑜伽课程模块2.3 课程预约模块2.4 系统公告模块2.5 课程评价模块2.6 瑜伽器械模块 三、系统设计3.1 实体类设计3.1.1 瑜伽课程3.1.2 瑜伽课程预约3.1.3 系统公告3.1.4 瑜伽课程评价 3.2 数据库设计3.2.…

牛客周赛 Round 32 F.小红的矩阵修改【三进制状态压缩dp】

原题链接:https://ac.nowcoder.com/acm/contest/75174/F 时间限制:C/C 1秒,其他语言2秒 空间限制:C/C 262144K,其他语言524288K 64bit IO Format: %lld 题目描述 小红拿到了一个字符矩阵,矩阵中仅包含&q…

java 执行方式和类加载过程

java默认属于混合执行: 编译和解释并存 java先进行解释执行,遇到多次重复的代码会把它编程成可执行文件,方便下次直接执行。 可以通过VM参数来修改执行方式。 类加载过程

Nacos、Eureka、Zookeeper、Consul对比

开发中,经常需要对微服务进行管理,所以需要引入一些服务治理的中间件,用于注册、发现服务,常见的服务治理中间件为 服务治理中间件 【1】Nacos 【2】Eureka 【3】Zookeeper 【4】Consul(Consul 所在的 HashiCorp 公司…

从完成[flutter竖向显示文字]到对实现方式[Rich Text和Text Span]的一些整理

前言 完成的需求是竖向显示文字,而已有的RotatedBox虽然可以让文字内部控件进行指定角度的旋转,但是不能保持文字仍正常显示(它会因为旋转横着),遂尝试Rich Text和Text Span的方式,这两个我曾在android有略…

红队笔记Day2 -->上线不出网机器

今天就来讲一下在企业攻防中如何上线不出网的机器!! 1.基本网络拓扑 基本的网络拓扑就是这样 以下是对应得的P信息,其中的52网段充当一个内网的网段,而111充当公网网段 先ping一下,确保外网ping不通内网,内…

文档类图像的智能识别,百度、阿里、华为腾讯开放接口

文档类图像的智能识别是指利用人工智能技术对文档图像进行自动识别和信息提取。在我国,百度、阿里、华为和腾讯等科技巨头都提供了相应的开放接口,方便开发者集成和使用文档类图像识别功能。以下是这些公司提供的相关开放接口: 1. 百…

微信小程序(四十一)wechat-http的使用

注释很详细,直接上代码 上一篇 新增内容: 1.模块下载 2.模块的使用 在终端输入npm install wechat-http 没有安装成功vue的先看之前的一篇 微信小程序(二十)Vant组件库的配置- 如果按以上的成功配置出现如下报错先输入以下语句 …

leetcode 24

24. 两两链表交换链表中的节点 已经给出了链表节点结构类: public class ListNode {int val;ListNode next;ListNode() {}ListNode(int val) { this.val val; }ListNode(int val, ListNode next) { this.val val; this.next next; }} 简而言之,我们…

DS:单链表实现队列

创作不易,友友们来个三连支持吧! 一、队列的概念 队列:是只允许在一端进行插入数据操作,在另一端进行删除数据操作的特殊线性表,队列具有先进先出FIFO(First In First Out)的特点。 入队列:进行插入操作…

leetcode题目记录

文章目录 单调栈[127. 单词接龙](https://leetcode.cn/problems/word-ladder/)[139. 单词拆分](https://leetcode.cn/problems/word-break/)[15. 三数之和](https://leetcode.cn/problems/3sum/)[140. 单词拆分 II](https://leetcode.cn/problems/word-break-ii/)[113. 路径总和…

《数字孪生城市建设指引报告(2023年)》指引智慧城市行动方向

2023年12月27日,中国信息通信研究院(简称“中国信通院”)产业与规划研究所、中国互联网协会数字孪生技术应用工作委员会和苏州工业园区数字孪生创新坊联合发布《数字孪生城市建设指引报告(2023年)》。该报告提出了三大…

PostgreSQL的学习心得和知识总结(一百二十八)|构建 PostgreSQL 负载测试器

目录结构 注:提前言明 本文借鉴了以下博主、书籍或网站的内容,其列表如下: 1、参考书籍:《PostgreSQL数据库内核分析》 2、参考书籍:《数据库事务处理的艺术:事务管理与并发控制》 3、PostgreSQL数据库仓库链接,点击前往 4、日本著名PostgreSQL数据库专家 铃木启修 网站…

Linux:docker在线仓库(docker hub 阿里云)基础操作

把镜像放到公网仓库,这样可以方便大家一起使用,当需要时直接在网上拉取镜像,并且你可以随时管理自己的镜像——删除添加或者修改。 1.docker hub仓库 2.阿里云加速 3.阿里云仓库 由于docker hub是国外的网站,国内的对数据的把控…

.NET Core 3 foreach中取索引index

for和foreach 循环是 C# 开发人员工具箱中最有用的构造之一。 在我看来,迭代一个集合比大多数情况下更方便。 它适用于所有集合类型,包括不可索引的集合类型(如 ,并且不需要通过索引访问当前元素)。 但有时&#xf…