YOLOv8 | 注意力机制 | ShuffleAttention注意力机制 提升检测精度

YOLOv8成功添加ShuffleAttention


⭐欢迎大家订阅我的专栏一起学习⭐

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀
       YOLOv5涨点专栏:http://t.csdnimg.cn/1Aqzu

YOLOv8涨点专栏:http://t.csdnimg.cn/jMjHb

YOLOv7专栏:http://t.csdnimg.cn/yhXBl

💡魔改网络、复现论文、优化创新💡

 

目录

原理 

代码实现

yaml文件实现

完整代码分享

启动命令

注意事项


注意力机制使神经网络能够准确地关注输入的所有相关元素,已成为提高深度神经网络性能的重要组成部分。计算机视觉研究中广泛使用的注意力机制主要有两种:空间注意力和通道注意力,其目的分别是捕获像素级的成对关系和通道依赖性。虽然将它们融合在一起可能会比它们单独的实现获得更好的性能,但它不可避免地会增加计算开销。高效的ShuffleAttention(SA)模块可以解决这个问题,它采用ShuffleAttention单元来有效地结合两种类型的注意机制。具体来说,SA 首先将通道维度分组为多个子特征,然后并行处理它们。然后,对于每个子特征,SA 利用洗牌单元来描述空间和通道维度上的特征依赖性。之后,所有子特征被聚合,并采用“通道洗牌”算子来实现不同子特征之间的信息通信。 SA 模块高效且有效,例如,SA 针对主干 ResNet50 的参数和计算量分别为 300 vs. 25.56M 和 2.76e-3 GFLOPs vs. 4.12 GFLOPs,并且性能提升超过 1.34% Top-1 准确度方面。

使用 ResNets 作为主干的 ImageNet-1k 上最近的 SOTA 注意力模型(包括 SENet、CBAM、ECA-Net、SGE-Net 和 SA-Net)在准确性、网络参数和 GFLOP 方面的比较。圆圈的大小表示 GFLOP。显然,所提出的 SA-Net 实现了更高的精度,同时模型复杂度更低 

原理 

首先介绍构建SA模块的过程,该模块将输入特征图分组,并使用Shuffle Unit将通道注意力和空间注意力整合到每个组的一个块中。之后,所有子特征被聚合,并利用“通道洗牌”算子来实现不同子特征之间的信息通信。然后,我们展示如何在深度 CNN 中采用 SA。最后,我们可视化效果并验证所提出的 SA 的可靠性。 SA模块整体架构如图所示

Shuffle Attention结构图

它采用“通道分割”并行处理各组的子特征。对于通道注意力分支,使用 GAP 生成通道统计量,然后使用一对参数来缩放和移动通道向量。对于空间注意力分支,采用群范数生成空间统计量,然后创建类似于通道分支的紧凑特征。然后将两个分支连接起来。之后,所有子特征被聚合,最后我们利用“通道洗牌”运算符来实现不同子特征之间的信息通信。 

完全捕获通道依赖性的一个选项是利用SE块。然而,它会带来太多的参数,这不利于在速度和准确性之间进行权衡,设计更轻量级的注意力模块。此外,像 ECA 一样,不适合通过执行大小为 k 的更快一维卷积来生成通道权重,因为 k 往往会更大。为了改进,我们提供了一种替代方案,首先通过简单地使用全局平均池化(GAP)来嵌入全局信息,生成通道统计量 s ∈ RC/2G×1×1,可以通过空间维度缩小Xk1 来计算高×宽此外,还创建了一个紧凑的功能来指导精确和自适应的选择。这是通过带有 sigmoid 激活的简单门控机制来实现的。

与通道注意力不同,空间注意力关注的是“哪里”,是信息性的部分,与通道注意力是互补的。首先,我们在 Xk2 上使用群范数 (GN)来获得空间统计数据。然后,采用Fc(·)来增强^ Xk2的表示。

之后,所有子特征都被聚合。最后,与ShuffleNet v2 类似,我们采用“channel shuffle”算子来实现沿通道维度的跨组信息流。 SA模块的最终输出与X的大小相同,使得SA非常容易与现代架构集成

代码实现
class ShuffleAttention(nn.Module):def __init__(self, channel=512, reduction=16, G=8):super().__init__()self.G = Gself.channel = channelself.avg_pool = nn.AdaptiveAvgPool2d(1)self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sigmoid = nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)@staticmethoddef channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return xdef forward(self, x):b, c, h, w = x.size()# group into subfeaturesx = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w# channel_splitx_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w# channel attentionx_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1x_channel = x_0 * self.sigmoid(x_channel)# spatial attentionx_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,wx_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,wx_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w# concatenate along channel axisout = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,wout = out.contiguous().view(b, -1, h, w)# channel shuffleout = self.channel_shuffle(out, 2)return out
yaml文件实现
# Ultralytics YOLO  , GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 6  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [-1, 3, ShuffleAttention, [1024]]- [[15, 18, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)
完整代码分享

链接: https://pan.baidu.com/s/1NPb6C6svuNGqyZIYUVgsVw?pwd=yjnw 提取码: yjnw

启动命令
yolo detect train model=/path/yolov8_ShuffleAttention.yaml data=/path/coco128.com
注意事项

如果报错,查看这篇文章

YOLOv8 | 添加注意力机制报错KeyError:已解决,详细步骤_yolov8 keyerroe-CSDN博客

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/765400.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql字段多个值,mybatis/mybatis-plus匹配查询

mysql中有一个字段是字符串类型的,category字段值有多个用逗号分割的,例如:娱乐,时尚美妆,美食 。现在想实现这么一个功能, 前端传参 字符串,美食,娱乐。现在想在mybatis的xml中实现,查询,能查到…

软件测试 -- Selenium常用API全面解答(java)

写在前面 // 如果文章有问题的地方, 欢迎评论区或者私信指正 目录 什么是Selenium 一个简单的用例 元素定位 id定位 xpath定位 name定位 tag name 定位和class name 定位 操作元素 click send_keys submit text getAttribute 添加等待 显示等待 隐式等待 显示等…

【wubuntu】披着Win11皮肤主题的Ubuntu系统

wubuntu - 一款外观类似于 Windows 的 Linux 操作系统,没有任何硬件限制。以下是官方的描述 Wubuntu is an operating system based on Ubuntu LTS that has a similar appearance to Windows using the open-source themes. Wubuntu also comes with a set of adva…

JavaScript 权威指南第七版(GPT 重译)(二)

第四章:表达式和运算符 本章记录了 JavaScript 表达式以及构建许多这些表达式的运算符。表达式 是 JavaScript 的短语,可以 评估 以产生一个值。在程序中直接嵌入的常量是一种非常简单的表达式。变量名也是一个简单表达式,它评估为分配给该变…

2024 Mazing 3 中文版新功能介绍Windows and macOS

iMazing 3中文版(ios设备管理软件)是一款管理苹果设备的软件, Windows 平台上的一款帮助用户管理 IOS 手机的应用程序。iMazing中文版与苹果设备连接后,可以轻松传输文件,浏览保存信息等,软件功能非常强大,界面简洁明晰…

【运维】MacOS Wifi热点设置

目录 打开热点 配置共享网段 打开热点 打开macOS设置,进入通用->共享 点击如下图标进行配置, 会进入如下界面(⚠️目前是打开共享状态,无法修改配置,只有在未打开状态才能进入配置) 配置完成后&#x…

2024.3.23

具体是哪些参数 每个点是如何投票的 删除那里不懂 是先给了,再重新分组,具体怎么再重新分组 聚的簇 真实的标签 为什么没有PI 加上其他的 基因调控网络

【前端寻宝之路】学习和总结有无序列表的实现和样式修改

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法|MySQL| ​💫个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-WuZk6y8cqVpDsE8W {font-family:"trebuchet ms",verdana,arial,sans-serif;f…

深度学习之本地部署大模型ChatGLM3-6B【大模型】【报错】

文章目录 0.前言1.模型下载2.配置环境2.1 下载项目文件2.2 配置环境 3.开始推理4.总结 0.前言 本博客将介绍ChatGLM3-6B大模型在Ubuntu上的本地部署教程 1.模型下载 由于毛毛张的服务器服务无法科学上网,所以模型的相关文件必须现在本地下载好,再上传到…

【C语言】数据在内存中的存储(包含大小端字节序问题)~

一、前言 我们在刚开始学习C语言的时候,就接触到了很多数据的不同类型。我们也知道,数据是存储在一块内存空间的,且我们只知道数据的类型决定着,该数据在内存中所占内存空间的大小,且超过一个字节的数据在内存中存储的…

HTTP --- 上

目录 1. HTTP协议 2. 认识URL 2.1. URL中的四个主要字段 2.2. URL Encode && URL Decode 3. HTTP 协议格式 3.1. 快速构建 HTTP 请求和响应的报文格式 3.1.1. HTTP 请求格式 3.1.2. HTTP 响应格式 3.1.3. 关于 HTTP 请求 && 响应的宏观理解 3.2. 实现…

SOPHON算能服务器SDK环境配置和相关库安装

目录 1 SDK大包下载 2 安装libsophon 2.1 安装依赖 1.2 安装libsophon 2 安装 sophon-mw 参考文献: 1 SDK大包下载 首先需要根据之前的博客,下载SDK大包:SOPHON算能科技新版SDK环境配置以及C demo使用过程_sophon sdk yolo-CSDN博客 …

计算机三级——网络技术(综合题第四题)

综合题第四题考点总结: 1.DSN域名解析 2.ICMP控制报文协议 3.TCP三次握手 4.HTTP超文本传输协议 5.FTP文件传输协议 DNS域名解析 域名系统(英文:Domain Name system,缩写:DNS)是互联网的一项服务。它作为将…

数据中台:如何构建企业核心竞争力_光点科技

在当今信息化快速发展的商业环境下,“数据中台”已经成为构建企业核心竞争力的关键步骤。数据中台不仅是数据集成与管理的平台,更是企业智能化转型的加速器。本文将深入探讨数据中台的定义、特点、构建方法及其在企业中的作用。 数据中台的定义 数据中台…

8-深度学习

声明 本文章基于哔哩哔哩付费课程《小白也能听懂的人工智能原理》。仅供学习记录、分享,严禁他用!!如有侵权,请联系删除 目录 一、知识引入 (一)深度学习 (二)Tensorflo…

嵌入式Linux内核启动过程详解(第二阶段:C语言部分)

目录 概述 1 启动内核第二阶段流程图 2 嵌入式Linux内核启动分析(C语言部分) 2.1 start_kernel()函数 2.2 rest_init()函数 2.3 kernel_init()函数 2.4 kernel_init_freeable()函数 概述 本文主要介绍linux内核启动(内核版本&#xff…

FPGA学习_时序分析

文章目录 前言一、组合逻辑与时序逻辑二、建立时间和保持时间三、建立时间和保持时间 前言 心中有电路,下笔自然神!!! 一、组合逻辑与时序逻辑 组合逻辑:没有时钟控制的数字电路,代码里的判断逻辑都是组…

先进电机技术 —— 何为轴电压?

一、特定场景举例 长线驱动指的是在电动机与变频器之间存在较长的连接电缆,尤其是在大型工业应用中,电机可能远离变频器几十米甚至上百米。这种情况下,变频器通过长线向电动机传输功率时,可能会加剧电机轴电压和轴电流的产生&…

《明日边缘2》AI制作电影宣传片

《明日边缘2》AI制作电影宣传片 In the dawn of a new war, the cycle of death and rebirth begins. 在新战争的曙光中,生死轮回的循环悄然开启。 Each repetition brings a new horror, but also a chance for redemption. 每一次循环都带来新的恐怖,却…

第六十二回 宋江兵打大名城 关胜议取梁山泊-飞桨ONNX推理部署初探

石秀和卢俊义在城内走投无路,又被抓住。梁中书把他两个人押入死牢。蔡福把他俩关在一处,好酒好菜照顾着,没让两人吃苦。 第二天就接到城外梁山泊的帖子,说大军已经来到,要替天行道,让他放人,并…