Transformer实战-系列教程15:DETR 源码解读2(整体架构:DETR类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类)
DETR 源码解读2(DETR类)
DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)
DETR 源码解读4(BackboneBase类/Backbone类)
DETR 源码解读5(Transformer类)
DETR 源码解读6(编码器:TransformerEncoder类/TransformerEncoderLayer类)
DETR 源码解读7(解码器:TransformerDecoder类/TransformerDecoderLayer类)
DETR 源码解读8(训练函数/损失函数)

4、DETR类

位置:models/detr.py/DETR类

4.1 构造函数

class DETR(nn.Module):def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):super().__init__()self.num_queries = num_queriesself.transformer = transformerhidden_dim = transformer.d_modelself.class_embed = nn.Linear(hidden_dim, num_classes + 1)self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)self.query_embed = nn.Embedding(num_queries, hidden_dim)self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)self.backbone = backboneself.aux_loss = aux_loss
  1. DETR类继承torch nn.Module
  2. 构造函数,传入5个参数:
    • backbone:CNN骨架网络,用于特征提取
    • transformer:Transformer模型,用于处理序列数据
    • num_classes:目标类别的数量
    • num_queries:解码器初始化生成的100个向量的个数,num_queries=100
    • aux_loss:一个布尔值,指示是否使用辅助损失来帮助训练
  3. 初始化
  4. num_queries
  5. transformer
  6. hidden_dim ,Transformer中的隐藏层维度
  7. class_embed ,类别预测的输出层,这个全连接层是接Transformer的输出,类别加1是额外的无类别对象
  8. bbox_embed,一个MLP,也是接Transformer的输出,边界框的四个坐标的回归
  9. query_embed ,解码器的初始100个向量
  10. input_proj ,一个1x1的二维卷积,使得backbone的输出通道数映射到与Transformer隐藏层维度相同
  11. backbone,一个预训练的卷积神经网络,主要作用是提取图像的特征,它的输出经过input_proj 处理后作为Transformer的输入
  12. aux_loss,保存是否使用辅助损失的标志

这里包含了几个自定义函数和类:
nested_tensor_from_tensor_list函数:将不同尺寸处理的图像Tensor转换为一个嵌套Tensor
MLP类:边界框的四个坐标的回归
transformer类:构建transformer架构
backbone:用于提取图像特征的CNN

4.2 前向传播

    def forward(self, samples: NestedTensor):if isinstance(samples, (list, torch.Tensor)):samples = nested_tensor_from_tensor_list(samples)features, pos = self.backbone(samples)src, mask = features[-1].decompose()assert mask is not Nonehs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]outputs_class = self.class_embed(hs)outputs_coord = self.bbox_embed(hs).sigmoid()out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}if self.aux_loss:out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)return out    
  1. 前向传播函数,输入为samples=NestedTensor{mask={Tensor(2,771,911)},tensors={Tensor(2,3,771,911)}}
  2. 检查samples是否为列表或Tensor类型
  3. samples ,如果,使用nested_tensor_from_tensor_list函数转换为NestedTensor
  4. features, pos,图像特征图对应的位置编码,backbone实际上是一个resnet,features和pos是一个list结构,保存了各层的输出
  5. src, mask,解构最后一层的特征,获取源数据和掩码,src:torch.Size([2, 2048, 21, 18]),mask torch.Size([2, 21, 18]),2是batch,2048是特征维度,后面两个是图像长宽,这里的features[-1]表示在backbone中有多层都有输出,features保存了各层的输出,这里-1就表示最后的输出
  6. 确保掩码不为空
  7. 将数据通过Transformer处理,获取序列输出,torch.Size([6, 2, 100, 256]),6是Transformer的堆叠层数,2是batch,100是生成100个目标预测,256是每个目标预测的维度,Transformer模块有两个返回值,只取第一个返回值
  8. outputs_class ,获取类别预测
  9. outputs_coord ,获取边界框坐标预测,并使用sigmoid函数将输出值限制在0到1之间
  10. out ,将类别预测结果和 边界框坐标预测结果做成一个字典
  11. 如果启用了辅助损失
  12. 通过辅助函数_set_aux_loss计算辅助损失
  13. 返回out

4.3 辅助函数_set_aux_loss()

@torch.jit.unuseddef _set_aux_loss(self, outputs_class, outputs_coord):return [{'pred_logits': a, 'pred_boxes': b}for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
  1. @torch.jit.unused:一个装饰器,指示当使用TorchScript编译模型时,该方法不应被编译。这是因为辅助损失的计算可能不兼容TorchScript的静态图特性
  2. 定义函数,接收类别预测和边界框坐标作为输入
  3. 返回一个列表,将每一个类别预测和边界框坐标都封装成一个字典,这样,训练过程中可以计算每一层的损失,从而实现辅助损失的目的

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类)
DETR 源码解读2(DETR类)
DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)
DETR 源码解读4(BackboneBase类/Backbone类)
DETR 源码解读5(Transformer类)
DETR 源码解读6(编码器:TransformerEncoder类/TransformerEncoderLayer类)
DETR 源码解读7(解码器:TransformerDecoder类/TransformerDecoderLayer类)
DETR 源码解读8(训练函数/损失函数)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/689965.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是485远程水表?

485远程水表是一种利用RS485通信协议进行数据传输的智能水表,它具有远程读数、实时监控、数据存储等功能,为水资源管理和居民用水提供了便捷。在我国,随着物联网、大数据等技术的发展,485远程水表得到了广泛的应用,为智…

引领企业服务新篇章,纷享销客揽获4项大奖

近日,连接型CRM的开创者纷享销客,凭借其卓越的整体实力,分别荣获《互联网周刊》&eNet研究院“2023年度最佳企业服务产品奖”、携手盈建科荣获中国工业报社“数字化转型优秀案例”、入选产业家“2023产业数字化金铲奖”以及KVBrand“2023年…

BUUCTF第二十二、二十三题解题思路

第二十二题[WUSTCTF2020]level1 查壳 64位ELF文件,用64位IDA打开。 在函数界面可以看到一个“flag”,跟进该函数。 int __cdecl main(int argc, const char **argv, const char **envp) {int i; // [rsp4h] [rbp-2Ch]FILE *stream; // [rsp8h] [rbp-2…

Java项目中,值的对应问题

数据库表 实体类&#xff08;对应数据库的字段&#xff0c;可以驼峰命名&#xff09; 封装的查询方法sql List<Student> getAllStudents(String name,String studentId,Integer classId,String className); 这里的值一一对应。 在多表查询时&#xff0c;查询到的指定字段…

备战蓝桥杯 Day3

目录 搜索与回溯 1222&#xff1a;放苹果 1221&#xff1a;分成互质组 1218&#xff1a;取石子游戏 数组 1126&#xff1a;矩阵转置 1127&#xff1a;图像旋转 1128&#xff1a;图像模糊处理 1120&#xff1a;同行列对角线的格 string 2046&#xff1a;【例5.15】替换…

JAVA高并发——核心知识点

文章目录 1、重要概念1.1、同步(Synchronous)和异步(Asynchronous)1.2、并发(Concurrency)和并行(Parallelism)1.3、临界区1.4、阻塞(Blocking)和非阻塞(Non-Blocking)1.5、死锁(Deadlock)、饥饿(Starvation)和活锁(Livelock)1.6、并发级别1.6.1、阻塞1.6.2、无饥饿(Starvation…

2011-2022年上市公司ESG表现、制造业高质量发展与数字化转型原始数据计算结果do代码

2011-2022年上市公司ESG表现、制造业高质量发展与数字化转型 原始数据(exceldta)计算结果do代码 参照王丹&#xff08;2023&#xff09;的做法&#xff0c;对来自统计与决策《ESG表现、制造业高质量发展与数字化转型》一文中的基准回归部分进行复刻&#xff1a; 1、数据时间&a…

java 单例模式

单例模式是最简单的设计模式之一。即一个类负责创建自己的对象&#xff0c;同时确保只有单个对象被创建&#xff0c;提供一种访问其唯一的对象的方式&#xff0c;可以直接访问&#xff0c;不需要实例化该类的对象。 1、懒汉式&#xff0c;线程不安全 public class Singleton …

两个发散级数的和是否发散?

1、两个发散级数的和可能是收敛的也可能是发散的。 例子&#xff1a; 发散级数 ∑ 1 n \sum\frac{1}{n} ∑n1​和发散级数 ∑ ( 1 n 2 − 1 n ) \sum(\frac{1}{n^{2}}-\frac{1}{n}) ∑(n21​−n1​)的和是收敛级数&#xff1b; 发散级数∑(1/n) 和发散级数 ∑(1/n1/n) 的和是…

为什么你用的redis没有出现雪崩,击穿,穿透

一、前言 在大规模并发访问系统中&#xff0c;如果你的系统用到redis&#xff0c;在面试的时候面试官往往会问你的系统有没有出现雪崩&#xff0c;击穿&#xff0c;穿透这样的场景&#xff0c;然后是怎样解决的。博主也经常反复温习redis的特性&#xff0c;总是被雪崩&#xf…

不懂咱就学,记不住多看几遍(二)

一、Redis分布式锁中加锁与解锁、过期如何续命 实现要点&#xff1a; 互斥性&#xff0c;同一时刻&#xff0c;只能有一个客户端持有锁。防止死锁发生&#xff0c;如果持有锁的客户端因崩溃而没有主动释放锁&#xff0c;也要保证锁可以释放并且其他客户端可以正常加锁。加锁和…

请解释 C++ 中的析构函数,并说明它们的作用。

请解释 C 中的析构函数&#xff0c;并说明它们的作用。 在C中&#xff0c;析构函数&#xff08;Destructor&#xff09;是一种特殊类型的成员函数&#xff0c;用于在对象被销毁时执行特定的清理工作。析构函数的名称与类名相同&#xff0c;前面加上一个波浪号&#xff08;~&am…

【C++开篇 -- 入门语法篇】

C学习笔记---001 C知识开篇1、介绍C的背景以及与C语言的区别1.1、什么是C?1.2、C的背景 2、C与C语言的区别3、C优化命名空间3.1、C中的问题3.2、命名空间的应用 4、总结 C知识开篇 前言&#xff1a; 首先&#xff0c;C兼容C&#xff0c;C在C语言范畴上增添了一些优化&#xf…

WPF中样式

WPF中样式:类似于winform中控件的属性 <Grid><!-- Button属性 字体大小 字体颜色 内容 控件宽 高 --><Button FontSize="20" Foreground="Blue" Content="Hello" Width="100" Height="40"/></Grid&g…

proteus8.15图文安装教程

proteus8.15版本可以用STM32系列单片机来进行仿真设计&#xff0c;比7.8版本方便多了&#xff0c;有需要的朋友们可以在公众号后台回复 proteus8.15 获取软件包。 1、下载好软件包&#xff0c;解压如下&#xff0c;右键proteus8.15.sp1以管理员身份运行。 2、第一次安装&#x…

UE5 动态加载资源和类

// Called when the game starts or when spawned void AMyActor::BeginPlay() {Super::BeginPlay();if (MyActor){UE_LOG(LogTemp,Warning,TEXT("MyActor is %s"),*MyActor->GetName());}//动态加载资源UStaticMesh* MyTmpStaticMesh LoadObject<UStaticMesh…

什么时候会触发FullGC?描述一下JVM加载class文件的原理机制?

什么时候会触发 FullGC&#xff1f; 除直接调用 System.gc 外&#xff0c;触发 Full GC 执行的情况有如下四种。 1. 旧生代空间不足 旧生代空间只有 在新生代对象转入及创建为大对象、大数组时才会出现不足的现象&#xff0c;当执行 Full GC 后空间仍然不 足&#xff0c;则…

ALINX黑金AXU3EGB 开发板用户手册RS485通信接口图示DI RO信号方向标识错误说明

MAX3485这类RS485芯片&#xff0c;DI是TTL信号输入&#xff0c;RO是TTL信号输出 如下图是MAX3485手册规格书。 因此 ALINX黑金AXU3EGB 用户手册 Page 43页 图 3-11-1 PL 端 485 通信的连接示意图&#xff0c;MAX3485芯片的DI RO信号输入输出标识方向是错误的&#xff0c;应为蓝…

Redis 只会用缓存?16种妙用让同事直呼牛X

1、缓存2、数据共享分布式3、分布式锁4、全局ID5、计数器6、限流7、位统计8、购物车9、用户消息时间线timeline10、消息队列11、抽奖12、点赞、签到、打卡13、商品标签14、商品筛选15、用户关注、推荐模型16、排行榜图片 1、缓存 String类型 例如:热点数据缓存(例如报表、明…

Spring Farmework,Spring Boot,Spring MVC 分别是什么?它们的关系又是什么?

Spring Framework是一个综合性的Java开发框架&#xff0c;提供了一系列的模块和功能来简化企业级应用程序的开发。Spring框架包括IoC&#xff08;Inversion of Control&#xff09;容器、AOP&#xff08;Aspect-Oriented Programming&#xff09;支持、数据访问、事务管理、模型…