【深度学习实战(32)】模型结构之解耦头(de-coupled head)与耦合头(coupled head)

一、传统耦合头局限性

传统的检测模型,如YOLOv3和YOLOv4,使用的是单一的检测头,它同时预测目标类别和框的位置。然而,这种设计存在一些问题。首先,将类别预测和位置预测合并在一个头中,可能导致一个任务的误差对另一个任务的影响。其次,类别预测和位置预测的问题域不同,类别预测是一个多类分类问题,而位置预测是一个回归问题。这意味着它们需要不同的损失函数和网络层。

二、解耦头优势

解耦头的设计解决了上述问题。它将类别预测和位置预测分离开来,分别使用两个独立的网络分支进行处理。其中,类别预测使用一个全连接层来输出各个类别的概率,位置预测使用一系列卷积层来生成边界框的坐标。这样做的好处是可以分别优化类别预测和位置预测的损失函数,并且能够更灵活地设计网络结构和调整超参数。

三、哪些模型使用了解耦头?

1 FCOS

在这里插入图片描述

2 YOLOX

在这里插入图片描述

3 FastestDet

在这里插入图片描述

四 代码示例

耦合头demo

import torch
import torch.nn as nn
import torchvision.models as modelsclass CouplingHead(nn.Module):def __init__(self, num_classes, num_boxes):super(CouplingHead, self).__init__()self.num_classes = num_classesself.num_boxes = num_boxes# 使用预训练的ResNet18作为基础模型self.base_model = models.resnet18(pretrained=True)# 修改最后一层的输出通道数num_ftrs = self.base_model.fc.in_featuresself.base_model.fc = nn.Conv2d(num_ftrs, num_classes + 5 * num_boxes, kernel_size=1)# 分类分支self.classification = nn.Conv2d(num_classes, num_classes, kernel_size=1)# 回归分支self.regression = nn.Conv2d(5 * num_boxes, 5 * num_boxes, kernel_size=1)def forward(self, x):x = self.base_model(x)# 目标类别预测classification = self.classification(x[:, :self.num_classes, :, :])# 目标框回归regression = self.regression(x[:, self.num_classes:, :, :])return classification, regression# 创建耦合头模型
num_classes = 10  # 类别数量
num_boxes = 4  # 每个目标的边界框数量
model = CouplingHead(num_classes, num_boxes)# 随机生成输入数据
batch_size = 8
input_size = (224, 224)
x = torch.randn(batch_size, 3, *input_size)# 前向传播
classification, regression = model(x)# 输出结果
print("分类结果尺寸:", classification.shape)
print("回归结果尺寸:", regression.shape)

解耦头demo

import torch.nn as nn
import torch# 定义解耦头模型
class DecouplingHeader(nn.Module):def __init__(self, num_classes=20):super(CouplingHeader, self).__init__()self.num_classes = num_classes# 分类模块self.classification = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, num_classes, kernel_size=1))# 回归模块self.regression = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 4, kernel_size=1))def forward(self, x):classification = self.classification(x)regression = self.regression(x)return classification, regression# 创建ResNet18主干网络
def resnet18():model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, padding=1),nn.Sequential(BasicBlock(64, 64, stride=1),BasicBlock(64, 64, stride=1)),nn.Sequential(BasicBlock(64, 128, stride=2),BasicBlock(128, 128, stride=1)),nn.Sequential(BasicBlock(128, 256, stride=2),BasicBlock(256, 256, stride=1)),nn.Sequential(BasicBlock(256, 512, stride=2),BasicBlock(512, 512, stride=1)),nn.AvgPool2d(7, stride=1),nn.Flatten())return model# 定义BasicBlock模块
class BasicBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.stride != 1:identity = self.downsample(x)out += identityout = self.relu(out)return out# 创建一个输入样本进行测试
input_sample = torch.randn(1, 3, 224, 224)# 创建ResNet18主干网络实例
backbone = resnet18()# 创建解耦头模型实例
header = DecouplingHeader()# 将输入样本通过主干网络和解耦模型进行前向传播
features = backbone(input_sample)
classification, regression = header(features)# 打印输出结果的形状
print("Classification output shape:", classification.shape)
print("Regression output shape:", regression.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/6494.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Wireshark CLI | 过滤包含特定字符串的流

问题背景 源自于和朋友的一次技术讨论,关于 Wireshark 如何查找特定字符串所在的 TCP 流,原始问题如下: 仔细琢磨了下,基于我对 Wireshark 的使用经验,感觉一步到位实现比较困难,所以想着说用 Wireshark C…

旅游系列之:庐山美景

旅游系列之:庐山美景 一、路线二、住宿二、庐山美景 一、路线 庐山北门乘坐大巴上山,住在上山的酒店东线大巴游览三叠泉,不需要乘坐缆车,步行上下三叠泉即可,线路很短 二、住宿 长江宾馆庐山分部 二、庐山美景

Photoshop中图像编辑的基本操作

Photoshop中图像编辑的基本操作 Photoshop中调整图像窗口大小Photoshop中辅助工具的使用网格的使用标尺的使用注释工具的使用 Photoshop中置入嵌入式对象Photoshop中图像与画布的调整画布大小的修改画布的旋转图像尺寸的修改 Photoshop中撤销与还原采用快捷键进行撤销与还原采用…

机器学习之基于Jupyter多种混合模型的糖尿病预测

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景 随着现代生活方式的改变,糖尿病的患病率在全球范围内呈现上升趋势。糖尿病是一种慢性代谢…

上位机图像处理和嵌入式模块部署(树莓派4b使用lua)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 lua是一个脚本语言,比c语言开发容易,也没有python那么重,整体使用还是非常方便的。一般当成胶水语言进行开发&a…

【Hadoop】--基于hadoop和hive实现聊天数据统计分析,构建聊天数据分析报表[17]

目录 一、需求分析 1、背景介绍 2、目标 3、需求 4、数据内容 5、建库建表 二、ETL数据清洗 1、数据问题 2、需求 3、实现 4、扩展概念:ETL 三、指标计算 1、指标1:统计今日消息总量 2、指标2:统计每小时消息量、发送量和接收用…

哥白尼高程Copernicus DEM下载(CSDN_20240505)

哥白尼数字高程模型(Copernicus DEM, COP-DEM)由欧洲航天局(European Space Agency, 简称ESA或欧空局)发布,全球范围免费提供30米和90米分辨率DEM。COP-DEM是数字表面模型(DSM),它表示地球表面(包括建筑物、基础设施和植被)的高程。COP-DEM是经过编辑的D…

循环神经网络模块介绍(Pytorch 12)

到目前为止,我们遇到过两种类型的数据:表格数据和图像数据。对于图像数据,我们设计了专门的卷积神经网络架构(cnn)来为这类特殊的数据结构建模。换句话说,如果我们拥有一张图像,我们 需要有效地利用其像素位置&#xf…

算法课程笔记——蓝桥云课第六次直播

(只有一个数,或者因子只有一个)先自己打表,找找规律函数就是2的n次方 异或前缀和 相等就抵消 先前缀和再二分

【Python】机器学习之Sklearn基础教程大纲

机器学习之Sklearn基础教程大纲 1. 引言 机器学习简介Scikit-learn(Sklearn)库介绍安装和配置Sklearn 2. 数据预处理 2.1 数据加载与查看 - 加载CSV、Excel等格式的数据- 查看数据的基本信息(如形状、数据类型等)2.2 数据清洗…

本地部署大模型ollama+docker+open WebUI/Lobe Chat

文章目录 大模型工具Ollama下载安装运行Spring Ai 代码测试加依赖配置写代码 ollama的web&Desktop搭建部署Open WebUI有两种方式Docker DesktopDocker部署Open WebUIDocker部署Lobe Chat可以配置OpenAI的key也可以配置ollama 大模型的选择 本篇基于windows环境下配置 大模型…

翔云优配恒生指数涨1.85%、恒生科技指数涨3.74% 小鹏汽车涨超8%

5月3日港股开盘,恒生指数涨1.85%,报18543.3点,恒生科技指数涨3.74%,报4009.96点,国企指数涨2.23%,报6580.81点, 翔云优配是一家领先的在线投资平台,提供全球范围内的股票、期货、基金等交易服务…

小程序引入 Vant Weapp 极简教程

一切以 Vant Weapp 官方文档 为准 Vant Weapp 官方文档 - 快速入手 1. 安装nodejs 前往官网下载安装即可 nodejs官网 安装好后 在命令行(winr,输入cmd)输入 node -v若显示版本信息,即为安装成功 2. 在 小程序根目录 命令行/终端…

C++类的小结

1、类定义 使用class关键字定义类。 类名通常以大写字母开头,以符合命名规范。 类包含成员变量(也称为属性或数据成员)和成员函数(也称为方法或行为)。 class MyClass { public: int x; // 数据成员 void setX…

【Gateway远程开发】0.5GB of free space is necessary to run the IDE.

【Gateway远程开发】0.5GB of free space is necessary to run the IDE. 报错 0.5GB of free space is necessary to run the IDE. Make sure that there’s enough space in following paths: /root/.cache/JetBrains /root/.config/JetBrains 原因 下面两个路径的空间不…

【OpenNJet下一代云原生之旅】

OpenNJet下一代云原生之旅 1、OpenNJet的定义OpenNJet架构图 2、OpenNJet的特点性能无损动态配置灵活的CoPilot框架支持HTTP/3支持国密企业级应用高效安全 3、OpenNJet的功能特性4、OpenNJet的安装使用编译安装配置yum源创建符号连接修改配置编译 5、通过 OpenNJet 部署 WEB SE…

基于OpenCv的图像特征点检测

⚠申明: 未经许可,禁止以任何形式转载,若要引用,请标注链接地址。 全文共计3077字,阅读大概需要3分钟 🌈更多学习内容, 欢迎👏关注👀【文末】我的个人微信公众号&#xf…

【设计模式】函数式编程范式工厂模式(Factory Method Pattern)

目录标题 定义函数式接口函数式接口实现类工厂类封装实际应用总结 定义函数式接口 ISellIPad.java /*** 定义一个函数式接口* param <T>*/ FunctionalInterface public interface ISellIPad<T> {T getSellIPadInfo();}函数式接口实现类 HuaWeiSellIPad.java pu…

rust数据类型转换,as和TryInto使用

Rust 是类型安全的语言&#xff0c;因此在 Rust 中做类型转换不是一件简单的事&#xff0c;这一章节我们将对 Rust 中的类型转换进行详尽讲解。 as转换 先来看一段代码&#xff1a; fn main() {let a: i32 10;let b: u16 100;if a < b {println!("Ten is less tha…

无U盘基于本地硬盘无损制作虚拟U盘(Windows、Linux系统安装启动盘)

知识点 实验环境 名称版本使用平台Win11本地硬盘格式GPT待安装镜像deepin-desktop-community-20.9-amd64.iso 文中工具下载链接&#xff1a; https://download.csdn.net/download/xzzteach/89263714 deepin-desktop-community-20.9-amd64.iso 文件结构如下&#xff1a; 在Li…