【深度学习实战(32)】模型结构之解耦头(de-coupled head)与耦合头(coupled head)

一、传统耦合头局限性

传统的检测模型,如YOLOv3和YOLOv4,使用的是单一的检测头,它同时预测目标类别和框的位置。然而,这种设计存在一些问题。首先,将类别预测和位置预测合并在一个头中,可能导致一个任务的误差对另一个任务的影响。其次,类别预测和位置预测的问题域不同,类别预测是一个多类分类问题,而位置预测是一个回归问题。这意味着它们需要不同的损失函数和网络层。

二、解耦头优势

解耦头的设计解决了上述问题。它将类别预测和位置预测分离开来,分别使用两个独立的网络分支进行处理。其中,类别预测使用一个全连接层来输出各个类别的概率,位置预测使用一系列卷积层来生成边界框的坐标。这样做的好处是可以分别优化类别预测和位置预测的损失函数,并且能够更灵活地设计网络结构和调整超参数。

三、哪些模型使用了解耦头?

1 FCOS

在这里插入图片描述

2 YOLOX

在这里插入图片描述

3 FastestDet

在这里插入图片描述

四 代码示例

耦合头demo

import torch
import torch.nn as nn
import torchvision.models as modelsclass CouplingHead(nn.Module):def __init__(self, num_classes, num_boxes):super(CouplingHead, self).__init__()self.num_classes = num_classesself.num_boxes = num_boxes# 使用预训练的ResNet18作为基础模型self.base_model = models.resnet18(pretrained=True)# 修改最后一层的输出通道数num_ftrs = self.base_model.fc.in_featuresself.base_model.fc = nn.Conv2d(num_ftrs, num_classes + 5 * num_boxes, kernel_size=1)# 分类分支self.classification = nn.Conv2d(num_classes, num_classes, kernel_size=1)# 回归分支self.regression = nn.Conv2d(5 * num_boxes, 5 * num_boxes, kernel_size=1)def forward(self, x):x = self.base_model(x)# 目标类别预测classification = self.classification(x[:, :self.num_classes, :, :])# 目标框回归regression = self.regression(x[:, self.num_classes:, :, :])return classification, regression# 创建耦合头模型
num_classes = 10  # 类别数量
num_boxes = 4  # 每个目标的边界框数量
model = CouplingHead(num_classes, num_boxes)# 随机生成输入数据
batch_size = 8
input_size = (224, 224)
x = torch.randn(batch_size, 3, *input_size)# 前向传播
classification, regression = model(x)# 输出结果
print("分类结果尺寸:", classification.shape)
print("回归结果尺寸:", regression.shape)

解耦头demo

import torch.nn as nn
import torch# 定义解耦头模型
class DecouplingHeader(nn.Module):def __init__(self, num_classes=20):super(CouplingHeader, self).__init__()self.num_classes = num_classes# 分类模块self.classification = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, num_classes, kernel_size=1))# 回归模块self.regression = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 4, kernel_size=1))def forward(self, x):classification = self.classification(x)regression = self.regression(x)return classification, regression# 创建ResNet18主干网络
def resnet18():model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, padding=1),nn.Sequential(BasicBlock(64, 64, stride=1),BasicBlock(64, 64, stride=1)),nn.Sequential(BasicBlock(64, 128, stride=2),BasicBlock(128, 128, stride=1)),nn.Sequential(BasicBlock(128, 256, stride=2),BasicBlock(256, 256, stride=1)),nn.Sequential(BasicBlock(256, 512, stride=2),BasicBlock(512, 512, stride=1)),nn.AvgPool2d(7, stride=1),nn.Flatten())return model# 定义BasicBlock模块
class BasicBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.stride != 1:identity = self.downsample(x)out += identityout = self.relu(out)return out# 创建一个输入样本进行测试
input_sample = torch.randn(1, 3, 224, 224)# 创建ResNet18主干网络实例
backbone = resnet18()# 创建解耦头模型实例
header = DecouplingHeader()# 将输入样本通过主干网络和解耦模型进行前向传播
features = backbone(input_sample)
classification, regression = header(features)# 打印输出结果的形状
print("Classification output shape:", classification.shape)
print("Regression output shape:", regression.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/6494.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习小tip

有监督学习 有监督学习是通过现有训练数据集进行建模,再用模型对新的数据样本进行分类或者回归分析的机器学习 方法。 无监督学习 而无监督学习,或者说非监督式学习,则是在没有训练数据集的情况下,对没有标 签的数据进行分析并…

Wireshark CLI | 过滤包含特定字符串的流

问题背景 源自于和朋友的一次技术讨论,关于 Wireshark 如何查找特定字符串所在的 TCP 流,原始问题如下: 仔细琢磨了下,基于我对 Wireshark 的使用经验,感觉一步到位实现比较困难,所以想着说用 Wireshark C…

Mybatis Interview Question Summary

1. In best practice, usually an Xml mapping file will write a Dao interface corresponding to it. What is the working principle of the Dao interface? Can the methods in the Dao interface be overloaded when the parameters are different? Answer: The Dao in…

旅游系列之:庐山美景

旅游系列之:庐山美景 一、路线二、住宿二、庐山美景 一、路线 庐山北门乘坐大巴上山,住在上山的酒店东线大巴游览三叠泉,不需要乘坐缆车,步行上下三叠泉即可,线路很短 二、住宿 长江宾馆庐山分部 二、庐山美景

Photoshop中图像编辑的基本操作

Photoshop中图像编辑的基本操作 Photoshop中调整图像窗口大小Photoshop中辅助工具的使用网格的使用标尺的使用注释工具的使用 Photoshop中置入嵌入式对象Photoshop中图像与画布的调整画布大小的修改画布的旋转图像尺寸的修改 Photoshop中撤销与还原采用快捷键进行撤销与还原采用…

机器学习之基于Jupyter多种混合模型的糖尿病预测

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景 随着现代生活方式的改变,糖尿病的患病率在全球范围内呈现上升趋势。糖尿病是一种慢性代谢…

上位机图像处理和嵌入式模块部署(树莓派4b使用lua)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 lua是一个脚本语言,比c语言开发容易,也没有python那么重,整体使用还是非常方便的。一般当成胶水语言进行开发&a…

【Hadoop】--基于hadoop和hive实现聊天数据统计分析,构建聊天数据分析报表[17]

目录 一、需求分析 1、背景介绍 2、目标 3、需求 4、数据内容 5、建库建表 二、ETL数据清洗 1、数据问题 2、需求 3、实现 4、扩展概念:ETL 三、指标计算 1、指标1:统计今日消息总量 2、指标2:统计每小时消息量、发送量和接收用…

哥白尼高程Copernicus DEM下载(CSDN_20240505)

哥白尼数字高程模型(Copernicus DEM, COP-DEM)由欧洲航天局(European Space Agency, 简称ESA或欧空局)发布,全球范围免费提供30米和90米分辨率DEM。COP-DEM是数字表面模型(DSM),它表示地球表面(包括建筑物、基础设施和植被)的高程。COP-DEM是经过编辑的D…

循环神经网络模块介绍(Pytorch 12)

到目前为止,我们遇到过两种类型的数据:表格数据和图像数据。对于图像数据,我们设计了专门的卷积神经网络架构(cnn)来为这类特殊的数据结构建模。换句话说,如果我们拥有一张图像,我们 需要有效地利用其像素位置&#xf…

指针,解引用,空指针,野指针,常量指针(const+指针),指针常量(const+常量)

指针变量通过*操作符,操作指针变量指向的内存空间,被称为解引用。 所有指针类型在32位操作系统下是4个字节,64位是8个字节。 int a 10;int *p; // 定义了一个整型指针p。 *表示p是一个指针 p &a; // 将变量a的地址赋给指针p。 &是取地址运算符…

算法课程笔记——蓝桥云课第六次直播

(只有一个数,或者因子只有一个)先自己打表,找找规律函数就是2的n次方 异或前缀和 相等就抵消 先前缀和再二分

斐波那契数列,Java版本实现

斐波那契数列是一个著名的数列,其中每个数字(从第三个开始)是前两个数字的和。数列的前几个数字是 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, … 等。下面是一个用Java实现的斐波那契数列的详细版本,包括递归方法、迭代方法以及一个优化的…

[HDLBits] Simple wire

Create a module with one input and one output that behaves like a wire. module top_module( input in, output out );assign out in; endmodule

【Python】机器学习之Sklearn基础教程大纲

机器学习之Sklearn基础教程大纲 1. 引言 机器学习简介Scikit-learn(Sklearn)库介绍安装和配置Sklearn 2. 数据预处理 2.1 数据加载与查看 - 加载CSV、Excel等格式的数据- 查看数据的基本信息(如形状、数据类型等)2.2 数据清洗…

本地部署大模型ollama+docker+open WebUI/Lobe Chat

文章目录 大模型工具Ollama下载安装运行Spring Ai 代码测试加依赖配置写代码 ollama的web&Desktop搭建部署Open WebUI有两种方式Docker DesktopDocker部署Open WebUIDocker部署Lobe Chat可以配置OpenAI的key也可以配置ollama 大模型的选择 本篇基于windows环境下配置 大模型…

翔云优配恒生指数涨1.85%、恒生科技指数涨3.74% 小鹏汽车涨超8%

5月3日港股开盘,恒生指数涨1.85%,报18543.3点,恒生科技指数涨3.74%,报4009.96点,国企指数涨2.23%,报6580.81点, 翔云优配是一家领先的在线投资平台,提供全球范围内的股票、期货、基金等交易服务…

小程序引入 Vant Weapp 极简教程

一切以 Vant Weapp 官方文档 为准 Vant Weapp 官方文档 - 快速入手 1. 安装nodejs 前往官网下载安装即可 nodejs官网 安装好后 在命令行(winr,输入cmd)输入 node -v若显示版本信息,即为安装成功 2. 在 小程序根目录 命令行/终端…

C++类的小结

1、类定义 使用class关键字定义类。 类名通常以大写字母开头,以符合命名规范。 类包含成员变量(也称为属性或数据成员)和成员函数(也称为方法或行为)。 class MyClass { public: int x; // 数据成员 void setX…

yolov5网络结构图要点和难点实际案例和代码解析

YOLOv5网络结构图主要可以分为四个部分:输入端(Input)、Backbone(主干网络)、Neck(颈部)和Prediction(输出端)。以下是对这四个部分的简要说明: 输入端(Input): 数据增强:YOLOv5在输入端使用了Mosaic数据增强技术,这是一种将四张训练图像混合成一张的方式,可以…