YOLO算法改进5【中阶改进篇】:添加SENet注意力机制

在这里插入图片描述

SE-Net 是 ImageNet 2017(ImageNet 收官赛)的冠军模型,是由WMW团队发布。具有复杂度低,参数少和计算量小的优点。且SENet 思路很简单,很容易扩展到已有网络结构如 Inception 和 ResNet 中。
我们可以看到,已经有很多工作在空间维度上来提升网络的性能。那么很自然想到,网络是否可以从其他层面来考虑去提升性能,比如考虑特征通道之间的关系?作者基于这一点并提出了Squeeze-and-Excitation Networks(简称SE-Net)。在该结构中,SqueezeExcitation是两个非常关键的操作,所以以此来命名。作者出发点是希望建立特征通道之间的相互依赖关系。并未引入一个新的空间维度来进行特征通道间的融合,而是采用了一种全新的“特征重标定”策略。具体来说,就是通过学习的方式来自动获取到每个特征通道的重要程度,然后依照这个重要程度去提升有用的特征并抑制对当前任务用处不大的特征。

一、不改变原网络深度的改进方法

在这里插入图片描述
首先是打开models/yolov5s.yaml文件,我们在backbone中的SPPF之前增加SENet。增添位置如下,是将backbone中第4个C3模块替换为SE_Block,如上图。需要注意的是通道数要匹配,SENet并不改变通道数,由于原C3的输出通道数为1024*0.5=512,所以我们这里的写的是1024,这里的1024是传入到上面我们定义的Class SE_Block(nn.Moudel)中的c2参数,c1参数是由上一层的输出通道数控制的。参考链接

1.添加SENet.yaml文件
添加至/models/文件中

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2  conv1(3,32,k=6,s=2,p=2)[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4  conv2(32,64,k=3,s=2,p=1)[-1, 3, C3, [128]],  # C3_1 有Bottleneck[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8  conv3(64,128,k=3,s=2,p=1)[-1, 6, C3, [256]], # C3_2 Bottleneck重复两次[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16  conv4(128,256,k=3,s=2,p=1)[-1, 9, C3, [512]], # C3_3 Bottleneck重复三次 输出256通道[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32   Conv5(256,512,k=3,s=2,p=1)#[-1, 3, C3, [1024]],  # C3_4 Bottleneck重复1次  输出512通道[-1, 1, SE_Block, [1024]],  # 增加通道注意力机制 输出为512通道[-1, 1, SPPF, [1024, 5]],  # 9  每个都是K为5的池化]# YOLOAir v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

2.common配置
在models/common.py文件中增加以下代码
在这里插入图片描述

  • 上图是作者提出的SE模块的示意图。给定一个输入 x x x,其特征通道数为 c 1 c_1 c1,通过一系列卷积变换后得到一个特征通道数为 c 2 c_2 c2的特征。与传统的CNN不一样的是,接下来将通过三个操作来重标定前面得到的特征。
  • 首先是Squeeze操作,顺着空间维度来进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。它表征着在特征通道上响应的全局分布,而且使得靠近输入的层也可以获得全局的感受野,这一点在很多任务中都是非常有用的。
  • 其次是Excitation操作,它是一个类似于循环神经网络中门的机制。通过参数来为每个特征通道生成权重,其中参数被学习用来显式地建模特征通道间的相关性。
  • 最后是一个Reweight的操作,我们将Excitation的输出的权重看做是进过特征选择后的每个特征通道的重要性,然后通过乘法逐通道加权到先前的特征上,完成在通道维度上的对原始特征的重标定。
    ——————————————————————————————————————————
    图2 SE模块应用举例
  • 这里的注意力机制想法非常简单,即针对每一个 channel 进行池化处理,就得到了 channel
    个元素,通过两个全连接层,得到输出的这个向量。值得注意的是,第一个全连接层的节点个数等于 channel 个数的 1 4 \frac{1}{4} 41论文作者发现如果将第一个全连接层的节点个数替换成原来的 1 4 \frac{1}{4} 41,可以在参数数量适度增加的情况下提高准确性,而且并没有明显的延迟。),然后第二个全连接层的节点就和channel 保持一致。这个得到的输出就相当于对原始的特征矩阵的每个 channel 分析其重要程度,越重要的赋予越大的权重,越不重要的就赋予越小的权重。
  • 就拿上图来说,首先对四个通道进行平均池化得到四个值,然后经过两个全连接层之后得到通道权重的输出。等权重输出以后,则将对应通道的权重乘以原来的特征矩阵就得到了新的特征矩阵,以上便是SE模块的详细实现过程。
class SE_Block(nn.Module):def __init__(self, c1, c2):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 平均池化self.fc = nn.Sequential(nn.Linear(c1, c2 // 16, bias=False),nn.ReLU(inplace=True),nn.Linear(c2 // 16, c2, bias=False),nn.Sigmoid())def forward(self, x):# 添加注意力模块b, c, _, _ = x.size()  # 分别获取batch_size,channely = self.avg_pool(x).view(b, c)  # y的shape为【batch_size, channels】y = self.fc(y).view(b, c, 1, 1)  # shape为【batch_size, channels, 1, 1】out = x * y.expand_as(x)  # shape 为【batch, channels,feature_w, feature_h】return out

3.yolo.py配置
找到 models/yolo.py 文件中 parse_model() 类,在列表中添加SE_Block,这样可以获得我们要传入的参数。

if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, SE_Block]:

二、改变原网络深度的改进方法

在这里插入图片描述
在这里插入图片描述
比如我要在第一个C3后面加一个SE。yaml的修改如下。接下来稍微麻烦一点了【需要你了解v5的每层结构】,由于我们在backbone中加入了一层,也就是相当于后面的网络与之前相比都往后移动了一层,那么在后面的Concat部分中融合的特征层的索引也会收到影响,因此我们需要的是修改Concat层的from参数。参考链接

1.添加SENet.yaml文件
添加至/models/文件中

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2  conv1(3,32,k=6,s=2,p=2)[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4  conv2(32,64,k=3,s=2,p=1)[-1, 3, C3, [128]],  # C3_1 有Bottleneck[-1, 1, SE_Block, [128]],  # 增加通道注意力机制 输出为512通道[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8  conv3(64,128,k=3,s=2,p=1)[-1, 6, C3, [256]], # C3_2 Bottleneck重复两次[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16  conv4(128,256,k=3,s=2,p=1)[-1, 9, C3, [512]], # C3_3 Bottleneck重复三次 输出256通道[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32   Conv5(256,512,k=3,s=2,p=1)[-1, 3, C3, [1024]],  # C3_4 Bottleneck重复1次  输出512通道[-1, 1, SPPF, [1024, 5]],  # 9  每个都是K为5的池化]
"""可以看到实际就是每个Concat也后面移动一层,因此yaml修改为一下。最终的Detect的from也需要修改。""
# YOLOAir v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],  # conv1(512,256,1,1)[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 7], 1, Concat, [1]],  # cat backbone P4  将C3_3与SPPF出来后的上采样拼接 拼接后的通道为512[-1, 3, C3, [512, False]],  # 13  conv(256,256,k=1,s=1)  没有残差边[-1, 1, Conv, [256, 1, 1]], # conv2(256,128,1,1)[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 5], 1, Concat, [1]],  # cat backbone P3  与C3_2拼接,输出256通道[-1, 3, C3, [256, False]],  # 17 (P3/8-small) conv3(128,128,1,1)[-1, 1, Conv, [256, 3, 2]],# conv4(128,128,3,2,1)[[-1, 15], 1, Concat, [1]],  # cat head P4  拼接后256通道[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)  conv5(256,256,1,1)[-1, 1, Conv, [512, 3, 2]],# conv6(256,256,3,2,1)[[-1, 11], 1, Concat, [1]],  # cat head P5  拼接后是512[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

2.common配置
在models/common.py文件中增加以下代码

class SE_Block(nn.Module):def __init__(self, c1, c2):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 平均池化self.fc = nn.Sequential(nn.Linear(c1, c2 // 16, bias=False),nn.ReLU(inplace=True),nn.Linear(c2 // 16, c2, bias=False),nn.Sigmoid())def forward(self, x):# 添加注意力模块b, c, _, _ = x.size()  # 分别获取batch_size,channely = self.avg_pool(x).view(b, c)  # y的shape为【batch_size, channels】y = self.fc(y).view(b, c, 1, 1)  # shape为【batch_size, channels, 1, 1】out = x * y.expand_as(x)  # shape 为【batch, channels,feature_w, feature_h】return out

3.yolo.py配置
找到 models/yolo.py 文件中 parse_model() 类,在列表中添加SE_Block,这样可以获得我们要传入的参数。

if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, SE_Block]:

4.训练模型

python train.py --cfg SENet.yaml

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/130082.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java配置GDAL

<gdal.version>3.7.0</gdal.version><!-- gdal--><dependency><groupId>org.gdal</groupId><artifactId>gdal</artifactId><version>${gdal.version}</version></dependency> GDAL环境安装 downlo…

PHP进销存ERP系统源码

PHP进销存ERP系统源码 系统介绍&#xff1a; 扫描入库库存预警仓库管理商品管理供应商管理。 1、电脑端手机端&#xff0c;手机实时共享&#xff0c;手机端一目了然。 2、多商户Saas营销版 无限开商户&#xff0c;用户前端自行注册&#xff0c;后台管理员审核开通 3、管理…

HarmonyOS鸿蒙原生应用开发设计- 元服务(原子化服务)图标

HarmonyOS设计文档中&#xff0c;为大家提供了独特的元服务图标&#xff0c;开发者可以根据需要直接引用。 开发者直接使用官方提供的元服务图标内容&#xff0c;既可以符合HarmonyOS原生应用的开发上架运营规范&#xff0c;又可以防止使用别人的元服务图标侵权意外情况等&…

ROS学习笔记(4):ROS架构和通讯机制

前提 前4篇文章以及帮助大家快速入门ROS了&#xff0c;而从第5篇开始我们会更加注重知识积累。同时我强烈建议配合B站大学的视频一起服用。 1.ROS架构三层次&#xff1a; 1.基于Linux系统的OS层&#xff1b; 2.实现ROS核心通信机制以及众多机器人开发库的中间层&#xff1b…

提升ChatGPT答案质量和准确性的方法Prompt engineering

文章目录 怎么获得优质的答案设计一个优质prompt的步骤:Prompt公式:示例怎么获得优质的答案 影响模型回答精确度的因素 我们应该知道一个好的提示词,要具备一下要点: 清晰简洁,不要有歧义; 有明确的任务/问题,任务如果太复杂,需要拆分成子任务分步完成; 确保prompt中…

ElasticSearch集群环境搭建

1、准备三台服务器 这里准备三台服务器如下: IP地址主机名节点名192.168.225.65linux1node-1192.168.225.66linux2node-2192.168.225.67linux3node-3 2、准备elasticsearch安装环境 (1)编辑/etc/hosts&#xff08;三台服务器都执行&#xff09; vim /etc/hosts 添加如下内…

硬盘坏道检测修复工具下载,仅支持机械盘

硬盘坏道检测修复工具下载&#xff0c;仅支持机械盘 下载路径&#xff0c;最下方官网——软件下载——常用工具下载——硬盘坏道修复工具硬盘检测修复工具 【软件试用版下载、软件资讯或技术支持服务可点击文章最下方官网】

NLP 模型中的偏差和公平性检测

一、说明 近年来&#xff0c;自然语言处理 &#xff08;NLP&#xff09; 模型广受欢迎&#xff0c;彻底改变了我们与文本数据交互和分析的方式。这些基于深度学习技术的模型在广泛的应用中表现出了卓越的能力&#xff0c;从聊天机器人和语言翻译到情感分析和文本生成。然而&…

谷歌推出基于AI的产品图像生成工具;[微软免费课程:12堂课入门生成式AI

&#x1f989; AI新闻 &#x1f680; 谷歌推出基于AI的产品图像生成工具&#xff0c;帮助商家提升广告创意能力 摘要&#xff1a;谷歌推出了一套基于AI的产品图像生成工具&#xff0c;使商家能够利用该工具免费创建新的产品图像。该工具可以帮助商家进行简单任务&#xff08;…

MySQL---搜索引擎

MySQL的存储引擎是什么 MySQL当中数据用各种不同的技术存储在文件中&#xff0c;每一种技术都使用不同的存储机制&#xff0c;索引技巧 锁定水平&#xff0c;以及最终提供的不同的功能和能力&#xff0c;这些就是我们说的存储引擎。 MySQL存储引擎的功能 1.MySQL将数据存储在文…

【leetcode】88. 合并两个有序数组(图解)

目录 1. 思路&#xff08;图解&#xff09;2. 代码 题目链接&#xff1a;leetcode 88. 合并两个有序数组 题目描述&#xff1a; 1. 思路&#xff08;图解&#xff09; 思路一&#xff1a;&#xff08;不满足题目要求&#xff09; 1. 创建一个大小为nums1和nums2长度之和的…

leetCode 494. 目标和 + 动态规划 + 记忆化搜索 + 递推 + 空间优化

关于本题我的往期文章&#xff1a; LeetCode 494.目标和 &#xff08;动态规划 性能优化&#xff09;二维数组 压缩成 一维数组_呵呵哒(&#xffe3;▽&#xffe3;)"的博客-CSDN博客https://heheda.blog.csdn.net/article/details/133253822 给你一个非负整数数组 nums…

mysql:B+树/事务

B树 : 为了数据库量身定做的数据结构 我们当前这里的讨论都是围绕 mysql 的 innodb 这个存储引擎来讨论的 其他存储引擎可能会用到hash 作为索引,此时就只能应对这种精准匹配的情况了 要了解 B树 我们先了解 B树, B树 是 B树 的改进 B树 有时候会写作 B-树 (这里的" -…

axios 实现请求重试

前景提要&#xff1a; ts 简易封装 axios&#xff0c;统一 API 实现在 config 中配置开关拦截器 请求重试的核心是可以重放请求&#xff0c;具体实现就是在 axios 中&#xff0c;拿到当前请求的 config 对象&#xff0c;再用 axios 实例&#xff0c;就能重放请求。 在无感刷新…

【WinForm详细教程七】WinForm中的DataGridView控件

文章目录 1.主要属性DataSource行&#xff08;Row 相关属性&#xff09;列&#xff08;Column 相关属性&#xff09;单元格&#xff08;Cell 相关属性&#xff09;逻辑删除AllowUserToAddRowsAllowUserToDeleteRowsAllowUserToOrderColumns其他布局和行为属性 2.控件中的行、列…

PHP foreach 循环跳过本次循环

$a [[id>1],[id>2],[id>3],[id>4],[id>5],[id>6],[id>7],[id>18],];foreach($a as $v){if($v[id] 5){continue;}$b[] $v[id];}return show_data(,$b); 结果&#xff1a;

ASTM F963-23美国玩具安全新标准发布

新标准发布 2023年10月13日&#xff0c;美国材料与试验协会&#xff08;ASTM&#xff09;发布了新版玩具安全标准ASTM F963-23。 主要更新内容 与ASTM F963-17相比&#xff0c;此次更新包括&#xff1a;单独描述了基材重金属元素的豁免情况&#xff0c;更新了邻苯二甲酸酯的管控…

上班族必备:制作电子宣传册的网站

​对于上班族来说&#xff0c;制作电子宣传册是一项非常重要的技能。因为宣传册是展示公司形象、产品特点、服务优势的重要工具&#xff0c;也是与客户沟通交流的重要手段。那么&#xff0c;如何制作一份高质量的电子宣传册呢&#xff1f;今天就为大家推荐几个制作电子宣传册的…

如何让 Bean 深度感知 Spring 容器

Spring 有一个特点&#xff0c;就是创建出来的 Bean 对容器是无感的&#xff0c;一个 Bean 是怎么样被容器从一个 Class 整成一个 Bean 的&#xff0c;对于 Bean 本身来说是不知道的&#xff0c;当然也不需要知道&#xff0c;也就是 Bean 对容器的存在是无感的。 但是有时候我…

【ChatGLM2-6B】P-Tuning训练微调

机器配置 阿里云GPU规格ecs.gn6i-c4g1.xlargeNVIDIA T4显卡*1GPU显存16G*1 准备训练数据 进入/ChatGLM-6B/ptuningmkdir AdvertiseGencd AdvertiseGen上传 dev.json 和 train.json内容都是 {"content": "你是谁", "summary": "你好&…