基于transformer的解码decode目标检测框架(修改DETR源码)

提示:transformer结构的目标检测解码器,包含loss计算,附有源码

文章目录

  • 前言
  • 一、main函数代码解读
    • 1、整体结构认识
    • 2、main函数代码解读
    • 3、源码链接
  • 二、decode模块代码解读
    • 1、decoded的TransformerDec模块代码解读
    • 2、decoded的TransformerDecoder模块代码解读
    • 3、decoded的DecoderLayer模块代码解读
  • 三、decode模块训练demo代码解读
    • 1、解码数据输入格式
    • 2、解码训练demo代码解读
  • 四、decode模块预测demo代码解读
    • 1、预测数据输入格式
    • 2、解码预测demo代码解读
  • 五、losses模块代码解读
    • 1、matcher初始化
    • 2、二分匹配matcher代码解读
    • 3、num_classes参数解读
    • 4、losses的demo代码解读


前言

最近重温DETR模型,越发感觉detr模型结构精妙之处,不同于anchor base 与anchor free设计,直接利用100框给出预测结果,使用可学习learn query深度查找,使用二分匹配方式训练模型。为此,我基于detr源码提取解码decode、loss计算等系列模块,并重构、修改、整合一套解码与loss实现的框架,该框架可适用任何backbone特征提取接我框架,实现完整训练与预测,我也有相应demo指导使用我的框架。那么,接下来,我将完整介绍该框架源码。同时,我将此源码进行开源,并上传github中,供读者参考。


一、main函数代码解读

1、整体结构认识

在介绍main函数代码前,我先说下整体框架结构,该框架包含2个文件夹,一个losses文件夹,用于处理loss计算,一个是obj_det文件,用于transformer解码模块,该模块源码修改于detr模型,也包含main.py,该文件是整体解码与loss计算demo示意代码,如下图。

在这里插入图片描述

2、main函数代码解读

该代码实际是我随机创造了标签target数据与backbone特征提取数据及位置编码数据,使其能正常运行的demo,其代码如下:

import torch
from obj_det.transformer_obj import TransformerDec
from losses.matcher import HungarianMatcher
from losses.loss import SetCriterionif __name__ == '__main__':Model = TransformerDec(d_model=256, output_intermediate_dec=True, num_classes=4)num_classes = 4   #  类别+1matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)  # 二分匹配不同任务分配的权重losses = ['labels', 'boxes', 'cardinality']  # 计算loss的任务weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}  # 为dert最后一个设置权重criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=0.1, losses=losses)# 下面使用iter,我构造了虚拟模型编码数据与数据加载标签数据src = torch.rand((391, 2, 256))pos_embed = torch.ones((391, 1, 256))# 创造真实target数据target1 = {'boxes':torch.rand((5,4)),'labels':torch.tensor([1,3,2,1,2])}target2 = {'boxes': torch.rand((3, 4)), 'labels': torch.tensor([1, 1, 2])}target = [target1, target2]res = Model(src, pos_embed)losses = criterion(res, target)print(losses)

如下图:

在这里插入图片描述

3、源码链接

源码链接:点击这里

二、decode模块代码解读

该模块主要是使用transform方式对backbone提取特征的解码,主要使用learn query等相关trike与transform解码方式内容。
我主要介绍TransformerDec、TransformerDecoder、DecoderLayer模块,为依次被包含关系,或说成后者是前者组成部分。

1、decoded的TransformerDec模块代码解读

该类大意是包含了learn query嵌入、解码transform模块调用、head头预测logit与boxes等内容,是实现解码与预测内容,该模块参数或解释已有注释,读者可自行查看,其代码如下:

class TransformerDec(nn.Module):'''d_model=512, 使用多少维度表示,实际为编码输出表达维度nhead=8, 有多少个头num_queries=100, 目标查询数量,可学习querynum_decoder_layers=6, 解码循环层数dim_feedforward=2048, 类似FFN的2个nn.Linear变化dropout=0.1,activation="relu",normalize_before=False,解码结构使用2种方式,默认False使用post解码结构output_intermediate_dec=False, 若为True保存中间层解码结果(即:每个解码层结果保存),若False只保存最后一次结果,训练为True,推理为Falsenum_classes: num_classes数量与数据格式有关,若类别id=1表示第一类,则num_classes=实际类别数+1,若id=0表示第一个,则num_classes=实际类别数额外说明,coco类别id是1开始的,假如有三个类,名称为[dog,cat,pig],batch=2,那么参数num_classes=4,表示3个类+1个背景,模型输出src_logits=[2,100,5]会多出一个预测,target_classes设置为[2,100],其值为4(该值就是背景,而有类别值为123),那么target_classes中没有值为0,我理解模型不对0类做任何操作,是个无效值,模型只对1234进行loss计算,然4为背景会比较多,作者使用权重0.1避免其背景过度影响。forward return: 返回字典,包含{'pred_logits':[],  # 为列表,格式为[b,100,num_classes+2]'pred_boxes':[],  # 为列表,格式为[b,100,4]'aux_outputs'[{},...] # 为列表,元素为字典,每个字典为{'pred_logits':[],'pred_boxes':[]},格式与上相同}'''def __init__(self, d_model=512, nhead=8, num_queries=100, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False, output_intermediate_dec=False, num_classes=1):super().__init__()self.num_queries = num_queriesself.query_embed = nn.Embedding(num_queries, d_model)  # 与编码输出表达维度一致self.output_intermediate_dec = output_intermediate_decdecoder_layer = DecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)decoder_norm = nn.LayerNorm(d_model)self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/128270.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《现代C++语言核心特性解析》笔记(一)

一、新基础类型(C11~C20) C基础类型回顾一览表 1. 整数类型 long long 我们知道long通常表示一个32位整型,而long long则是用来表示一个64位的整型。不得不说,这种命名方式简单粗暴。不仅写法冗余,而且表…

【多线程面试题二十二】、 说说你对读写锁的了解

文章底部有个人公众号:热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享? 踩过的坑没必要让别人在再踩,自己复盘也能加深记忆。利己利人、所谓双赢。 面试官:说说你对读写锁的了解 …

华纳云:centos系统中怎么查看cpu信息?

在CentOS系统中,我们可以使用一些命令来查看CPU的详细信息。下面介绍几个常用的命令: 1. lscpu lscpu命令可以显示CPU的架构、型号、核心数、线程数、频率等信息。 # lscpu 执行以上命令后,会输出类似以下内容: 2. cat /proc/…

【C/C++】继承中同名成员处理方式

一、继承中同名成员处理方式 问题:当子类与父类出现同名的成员,如何通过子类对象,访问到子类或父类中同名的数据呢? 访问子类同名成员 直接访问即可访问父类同名成员 需要加作用域 示例: class Base { public:Base…

在IDEA运行spark程序(搭建Spark开发环境)

建议大家写在Linux上搭建好Hadoop的完全分布式集群环境和Spark集群环境,以下在IDEA中搭建的环境仅仅是在window系统上进行spark程序的开发学习,在window系统上可以不用安装hadoop和spark,spark程序可以通过pom.xml的文件配置,添加…

Java 的高性能缓存库-caffeine!

在项目中用到的除了分布式缓存,还有本地缓存,例如:Guava、Encache,使用本地缓存能够很大程度上提升程序性能,本地缓存是直接从本地内存中读取,没有网络开销。 今天给大家介绍一个高性能的 Java 缓存库 – …

Angular组件生命周期详解

当 Angular 实例化组件类 并渲染组件视图及其子视图时,组件实例的生命周期就开始了。生命周期一直伴随着变更检测,Angular 会检查数据绑定属性何时发生变化,并按需更新视图和组件实例。当 Angular 销毁组件实例并从 DOM 中移除它渲染的模板时…

Ubuntu 使用 nginx 搭建 https 文件服务器

Ubuntu 使用 nginx 搭建 https 文件服务器 搭建步骤安装 nginx生成证书修改 config重启 nginx 搭建步骤 安装 nginx生成证书修改 config重启 nginx 安装 nginx apt 安装: sudo apt-get install nginx生成证书 使用 openssl 生成证书: 到对应的路径…

案例精选|聚铭综合日志分析系统夯实徐州公交集团网络环境基础

徐州市公共交通集团有限公司成立于1960年,现隶属徐州市交通控股集团有限公司,下辖7家运营公司,1家站务公司,8家直属单位及13个职能部室。运营车辆2364辆,线路177条,线路长度3560公里,日发送班次…

云尘-Node1 js代码

继续做题 拿到就是基本扫一下 nmap -sP 172.25.0.0/24 nmap -sV -sS -p- -v 172.25.0.13 然后顺便fscan扫一下咯 nmap: fscan: 还以为直接getshell了 老演员了 其实只是302跳转 所以我们无视 只有一个站 直接看就行了 扫出来了两个目录 但是没办法 都是要跳转 说明还是需要…

@Tag和@Operation标签失效问题。SpringDoc 2.2.0(OpenApi 3)和Spring Boot 3.1.1集成

问题 Tag和Operation标签失效 但是Schema标签有效 pom依赖 <!-- 接口文档--><!--引入openapi支持--><dependency><groupId>org.springdoc</groupId><artifactId>springdoc-openapi-starter-webmvc-ui</artifactId><vers…

51单片机电子钟闹钟温度LCD1602液晶显示设计( proteus仿真+程序+原理图+设计报告+讲解视频)

51单片机电子钟闹钟温度液晶显示设计( proteus仿真程序原理图设计报告讲解视频&#xff09; 1.主要功能&#xff1a;2.仿真3. 程序代码4. 原理图5. 设计报告6. 设计资料内容清单&&下载链接资料下载链接&#xff08;可点击&#xff09;&#xff1a; &#x1f31f;51单片…

Python+requests+exce接口自动化测试框架

一、接口自动化测试框架 二、工程目录 三、Excel测试用例设计 四、基础数据base 封装post/get&#xff1a;runmethod.py #!/usr/bin/env python3 # -*-coding:utf-8-*- # __author__: hunterimport requests import jsonclass RunMain:def send_get(self, url, data):res req…

Day 5 登录页及路由 (三) 基于axios的API调用

系列文章目录 本系列记录一下通过Abp搭建后端&#xff0c;VueElement UI Plus搭建前端&#xff0c;实现一个小型项目的过程。 Day 1 Vue 页面框架Day 2 Abp框架下&#xff0c;MySQL数据迁移时&#xff0c;添加表和字段注释Day 3 登录页以及路由 (一&#xff09;Day 4 登录页以…

viteePress搭建组件文档

目录 安装vitepress 目录结构 文档首页 Home Page Hero 部分 Features 部分 导航栏配置 logo 导航链接 socialLinks 侧边栏 基本使用 多个侧边栏 使用组件 在 markdown 中导入组件 在 theme 中注册全局组件 部署到Github Pages 前提 第一步 第二步 …

大模型问答助手前端实现打字机效果 | 京东云技术团队

1. 背景 随着现代技术的快速发展&#xff0c;即时交互变得越来越重要。用户不仅希望获取信息&#xff0c;而且希望以更直观和实时的方式体验它。这在聊天应用程序和其他实时通信工具中尤为明显&#xff0c;用户习惯看到对方正在输入的提示。 ChatGPT&#xff0c;作为 OpenAI …

[ poi-表格导出 ] java.lang.NoClassDefFoundError: org/apache/poi/POIXMLTypeLoader

解决报错&#xff1a; org.springframework.web.util.NestedServletException: Handler dispatch failed; nested exception is java.lang.NoClassDefFoundError: org/apache/poi/POIXMLTypeLoader 报错描述&#xff1a; 表格导出本来使用正常&#xff0c;偶然就报了以上错误…

项目实战:编辑页面加载库存信息

1、前端编辑页面加载水果库存信息逻辑edit.js let queryString window.location.search.substring(1) if(queryString){var fid queryString.split("")[1]window.onloadfunction(){loadFruit(fid)}loadFruit function(fid){axios({method:get,url:edit,params:{fi…

(四)docker:为mysql和java jar运行环境创建同一网络,容器互联

看了很多资料&#xff0c;说做互联的一个原因是容器内ip不固定&#xff0c;关掉重启后如果有别的容器启动&#xff0c;之前的ip会被占用&#xff0c;所以做互联创建一个网络&#xff0c;让几个容器处于同一个网络&#xff0c;就可以互联还不受关闭再启动ip会改变的影响&#xf…

opencv复习(简短的一次印象记录)

2-高斯与中值滤波_哔哩哔哩_bilibili 1、均值滤波 2、高斯滤波 3、中值滤波 4、腐蚀操作 卷积核不都是255就腐蚀掉 5、膨胀操作 6、开运算 先腐蚀再膨胀 7、闭运算 先膨胀再腐蚀 8、礼帽 原始数据-开运算结果 9、黑帽 闭运算结果-原始数据 10、Sobel算子 左-右&#x…