YOLOv5-7.0改进(四)添加EMA注意力机制

前言

关于网络中注意力机制的改进有很多种,本篇内容从EMA注意力机制开始!

往期回顾

YOLOv5-7.0改进(一)MobileNetv3替换主干网络

YOLOv5-7.0改进(二)BiFPN替换Neck网络

YOLOv5-7.0改进(三)添加损失函数EIoU、AlphaIoU、SIoU、WIoU、MPDIoU、NWD

目录

  • 一、EMA简介
  • 二、Neck端添加EMA
    • 第一步:在common.py中添加EMA模块
    • 第二步:在yolo.py中的parse_model函数加入类名
    • 第三步:制作模型配置文件
    • 第四步:验证新加入的Neck网络
  • 三、C3中添加EMA
    • 第一步:在common.py中添加EMA模块
    • 第二步:在yolo.py中的parse_model函数加入类名
    • 第三步:制作模型配置文件
    • 第四步:验证新加入的Neck网络

一、EMA简介

论文题目:Efficient Multi-Scale Attention Module with Cross-Spatial Learning

EMA注意力机制:基于跨空间学习的高效多尺度注意力机制,该模块首先将部分通道维度重塑为批量维度,以避免通用卷积进行某种形式的降维,接着在每个并行子网络中构建局部的跨通道交互,利用一种新的跨空间学习方法融合两个并行子网络的输出特征图,设计了一个多尺度并行子网络来建立长短依赖关系。

网络结构

在这里插入图片描述

二、Neck端添加EMA

第一步:在common.py中添加EMA模块

代码如下:

#EMA
class EMA(nn.Module):def __init__(self, channels, factor=8):super(EMA, self).__init__()self.groups = factor # 分组因子assert channels // self.groups > 0self.softmax = nn.Softmax(-1) #softmax操作self.agp = nn.AdaptiveAvgPool2d((1, 1)) # 1×1平均池化层self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) # X平均池化层 h=1self.pool_w = nn.AdaptiveAvgPool2d((1, None)) # Y平均池化层 w=1self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups) # 分组操作self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0) # 1×1卷积分支 self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1) # 3×3卷积分支def forward(self, x):b, c, h, w = x.size()group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,wx_h = self.pool_h(group_x) # 得到平均池化之后的hx_w = self.pool_w(group_x).permute(0, 1, 3, 2) # 得到平均池化之后的whw = self.conv1x1(torch.cat([x_h, x_w], dim=2)) # 先拼接,然后送入1×1卷积x_h, x_w = torch.split(hw, [h, w], dim=2)x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())x2 = self.conv3x3(group_x) # 3×3卷积分支x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwx21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwweights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)return (group_x * weights.sigmoid()).reshape(b, c, h, w)

插入效果:

在这里插入图片描述

第二步:在yolo.py中的parse_model函数加入类名

将EMA类名添加到注册表中,效果如下:

在这里插入图片描述

第三步:制作模型配置文件

1、复制models/yolov5s.yaml文件,并重命名

在这里插入图片描述
2、将以下代码复制到新创建的yaml文件

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 12  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, EMA, [256]],  # 加入到小目标层后[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, EMA, [512]],  # 加入到中目标层后[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[-1, 1, EMA, [1024]],  # 加入到大目标层后[[18, 22, 26], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

第四步:验证新加入的Neck网络

1、修改yolo.py中以下两个地方

(1)DetectionModel函数下的cfg

在这里插入图片描述
(2)parser = argparse.ArgumentParser()下的cfg

在这里插入图片描述
2、运行yolo.py

(1)yolov5s_EMA.yaml

在这里插入图片描述

好了,到这一步在Neck端添加EMA基本完成,接下就可以开始训练~

三、C3中添加EMA

第一步:在common.py中添加EMA模块

代码如下:

#EMA
class EMA(nn.Module):def __init__(self, channels, factor=8):super(EMA, self).__init__()self.groups = factor # 分组率assert channels // self.groups > 0self.softmax = nn.Softmax(-1) # Softmaxself.agp = nn.AdaptiveAvgPool2d((1, 1)) # 平均池化层self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) # x平均池化层 h=1self.pool_w = nn.AdaptiveAvgPool2d((1, None)) # y平均池化层 w=1self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups) # 分组操作self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0) # 1×1卷积分支self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1) # 3×3卷积分支def forward(self, x):b, c, h, w = x.size()group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,wx_h = self.pool_h(group_x)x_w = self.pool_w(group_x).permute(0, 1, 3, 2)hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))x_h, x_w = torch.split(hw, [h, w], dim=2)x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())x2 = self.conv3x3(group_x)x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwx21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwweights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)return (group_x * weights.sigmoid()).reshape(b, c, h, w)class C3_EMA3(nn.Module):# CSP Bottleneck with 3 convolutionsdef __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansionsuper().__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c1, c_, 1, 1)self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))self.m1 = nn.ModuleList([EMA(2 * c_)])  # 添加在最后一个卷积之前def forward(self, x):return self.cv3(self.m1[0](torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)))class C3_EMA2(nn.Module):# CSP Bottleneck with 3 convolutionsdef __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansionsuper().__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c1, c_, 1, 1)self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))self.m1 = nn.ModuleList([EMA(c1)])  # 添加在最后一个卷积之前def forward(self, x):return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(self.m1[0](x))), 1))class C3_EMA1(nn.Module):# CSP Bottleneck with 3 convolutionsdef __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansionsuper().__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c1, c_, 1, 1)self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))self.m1 = nn.ModuleList([EMA(c_)])  # 添加在最后一个卷积之前def forward(self, x):return self.cv3(torch.cat((self.m(self.m1[0](self.cv1(x))), self.cv2(x)), 1))

效果如下:

在这里插入图片描述

第二步:在yolo.py中的parse_model函数加入类名

将以下类名添加到注册表中

EMA, C3_EMA1, C3_EMA2, C3_EMA3

效果如下:
在这里插入图片描述

第三步:制作模型配置文件

将以下代码复制到yaml文件中

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 12  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3_EMA1, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

第四步:验证新加入的Neck网络

1、运行yolo.py

在这里插入图片描述
接下来也是对这个模型进行训练,需要注意的是这是在主干网络部分改进~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/10508.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【智能算法】鹭鹰优化算法(SBOA)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献5.代码获取 1.背景 2024年,Y Fu受到自然界中鹭鹰生存行为启发,提出了鹭鹰优化算法(Secretary Bird Optimization Algorithm, SBOA)。 2.算法原理 2.1算法思想…

【声呐仿真】学习记录2-运行ROV(带camera、sonar、dvl等传感器)例程

【声呐仿真】学习记录2-运行ROV(带camera、sonar、dvl等传感器)例程 前言第一阶段-学习Gazebo第二阶段-学习URDF、xacro第三阶段-寻找例程跑一个rexrov示例程序1.uuvsimulator quick_start2.能键盘控制的示例程序(失败)3.能键盘控…

Excel如何设置密码保护【图文详情】

文章目录 前言一、Excel如何设置密码保护?二、Excel如何取消密码保护?总结 前言 在软件项目开发过程中,会输出很多技术文档,其中也包括保密级别很高的服务器账号Excel文档。为了确保服务器账号相关的Excel文档的安全性&#xff0…

CSS 网格布局一行X个排列

<div class"icon-box"><divv-for"(item,index) in icon" :key"index" class"icon"style"cursor: pointer">{{item}}</div></div>.icon-box{display: grid; /**网格布局*/grid-template-columns: r…

NSS题目练习2

[LitCTF 2023]我Flag呢&#xff1f; 打开题目后查看源码即可发现flag [第五空间 2021]WebFTP 看到提示&#xff0c;首先想到用dirsearch扫描链接&#xff0c;看是否存在git泄露 发现存在git泄露&#xff0c;用githack解决 克隆提示目录为空&#xff0c;说明不正确&#xff0c…

容器监控与日志管理

前言&#xff1a;本博客仅作记录学习使用&#xff0c;部分图片出自网络&#xff0c;如有侵犯您的权益&#xff0c;请联系删除 一、Docker监控工具 二、容器日志工具docker logs 三、第三方日志工具 四、容器日志驱动 五、示例 5.1、查看容器中运行的进程的信息 5.2、查看…

液晶显示模块强光实验类目及太阳光模拟器

科技日新月异&#xff0c;液晶显示模块运用得也越来越广泛&#xff0c;用户在购买和使用时&#xff0c;都希望能买到显示效果好&#xff0c;性价比高的产品。本文主要介绍LCM&#xff0f;LED模块在光学方面主要测试项目类别及实验仪器。 测试项目类别 1. 透过率 透过率是指透…

英语学习笔记7——Are you a teacher?

Are you a teacher? 你是教师吗&#xff1f; 词汇 Vocabulary name /neɪm/ n. 名字&#xff0c;名声 英文名字构成&#xff1a; 名 字 姓      given name family name  也叫做&#xff1a;first name last name      例&#xff1a;Yanyan Gao 例句&#xff1…

学习网络安全现在还有前景吗?行业分析报告

如果你现阶段选择入行网络安全&#xff0c;就相当于10年前学IT&#xff0c;当它发展起来的时候&#xff0c;你刚好遇到行业红利期。 网络安全这个职业完全可以改变很多人的人生轨迹。 因为它是个不需要你有多强大的情商&#xff0c;不需要你去学习更多复杂的职场和人际关系技…

halcon获取Licenses--每月一换

转到https://www.51halcon.com/ 点击授权&#xff0c;根据你的版本选择progress或者steady进行下载 记住每月一换哦

2024年小程序视频如何下载到电脑上

随着2024年的到来&#xff0c;将小程序视频无缝下载到电脑上&#xff0c;从此让精彩内容触手可及&#xff0c;不受时间和网络的限制&#xff0c;随时随地启发你的生活和工作。 小程序视频我已经打包好了&#xff0c;有需要的自己下载 小程序视频下载工具打包链接&#xff1a;…

AI数据中心网络技术选型,InfiniBand与RoCE对比分析

InfiniBand与RoCE对比分析&#xff1a;AI数据中心网络选择指南 随着 AI 技术的蓬勃发展&#xff0c;其对数据中心网络的要求也日益严苛。低延迟、高吞吐量的网络对于处理复杂的数据密集型工作负载至关重要。本文分析了 InfiniBand 和 RoCE 两种数据中心网络技术&#xff0c;帮助…

付费文章合集第二期

☞☞付费文章合集第一期 感谢大家一年来的陪伴与支持&#xff01; 对于感兴趣的文章点标题能跳转原文阅读啦~~ 21、Matlab信号处理——基于LSB和DCB音频水印嵌入提取算法 22、CV小目标识别——AITOD数据集&#xff08;已处理&#xff09; 23、Matlab信号发生器——三角波、…

【Redis】Redis 事务

Redis 的事务的本质是一组命令的批处理。这组命令在执行过程中会被顺序地、一次性 全部执行完毕&#xff0c;只要没有出现语法错误&#xff0c;这组命令在执行期间不会被中断 1.事务特性 仅保证了数据的一致性 这组命令中的某些命令的执行失败不会影响其它命令的执行&#xff…

【JVM】ASM开发

认识ASM ASM是一个Java字节码操纵框架&#xff0c;它能被用来动态生成类或者增强既有类的功能。 ASM可以直接产生二进制class文件&#xff0c;也可以在类被加载入虚拟机之前动态改变类行为&#xff0c;ASM从类文件中读入信息后能够改变类行为&#xff0c;分析类信息&#xff…

课程设计 大学生竞赛系统

课程设计 大学生竞赛系统 wx:help-assignment 学生用户&#xff1a; wx:help-assignment 首页&#xff1a;推荐一些竞赛&#xff0c;热门活动等&#xff1b; 广场&#xff1a;用户可以通过广场来发表动态&#xff0c;同时也可以查看别人发布的动态&#xff0c;并且可以 关注…

解决常见的Android问题

常见问题&#xff1a; 1、查杀&#xff1a; 查杀一般分为两个方向一种是内存不足的查杀&#xff0c;一种的是因为温度限频查杀&#xff0c;统称为内存查杀&#xff0c;两个问题的分析思路不同 1、内存不足查杀&#xff1a; 主要是因为当用户出现后台运行多个APP或者是相机等…

汇昌联信科技:拼多多可以做无货源吗?

在探讨电商平台的经营模式时&#xff0c;"无货源"这一概念经常被提及。它指的是卖家在不需要事先囤积大量商品的情况下&#xff0c;通过与供应商的合作&#xff0c;直接将订单信息传递给他们&#xff0c;由供应商完成发货的过程。针对“拼多多可以做无货源吗?”这一…

内网渗透之如何批量PTH获取主机权限?

—— 利用CrakMapExec工具进行全网段批量PTH CrackMapExec&#xff08;CME&#xff09;是一款后渗透利用工具&#xff0c;可帮助自动化大型活动目录(AD)网络安全评估任务。其缔造者byt3bl33d3r称&#xff0c;该工具的生存概念是&#xff0c;“利用AD内置功能/协议达成其功能&…

【练习2】

1.汽水瓶 ps:注意涉及多个输入&#xff0c;我就说怎么老不对&#xff0c;无语~ #include <cmath> #include <iostream> using namespace std;int main() {int n;int num,flag,kp,temp;while (cin>>n) {flag1;num0;temp0;kpn;while (flag1) {if(kp<2){if(…