YOLOv8改进 | 注意力机制 | 增强模型在图像分类和目标检测BAM注意力【小白必备 + 附完整代码】

秋招面试专栏推荐 :深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转


💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡


专栏目录 :《YOLOv8改进有效涨点》专栏介绍 & 专栏目录 | 目前已有50+篇内容,内含各种Head检测头、损失函数Loss、Backbone、Neck、NMS等创新点改进——点击即可跳转


近期深度神经网络的发展已经通过架构搜索实现了更强的表征能力。提出了一种名为瓶颈注意力模块(BAM)的简单而有效的注意力模块。它可以与任何前馈卷积神经网络集成。我们的模块沿着两个独立的路径(通道和空间)推断注意力图。我们将模块放置在模型的每个瓶颈处,即特征图降采样发生的地方。我们的模块在瓶颈处构建了一个分层的注意力,并且具有可训练的参数,可以与任何前馈模型进行端到端的联合训练。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址YOLOv8改进——更新各种有效涨点方法——点击即可跳转

目录

1.原理

2. 将BAM添加到YOLOv8中

2.1 BAM代码实现

2.2 更改init.py文件

2.3 添加yaml文件

2.4 在task.py中进行注册

2.5 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6. 总结


1.原理

论文地址:BAM: Bottleneck Attention Module——点击即可跳转

官方代码:官方代码仓库——点击即可跳转

BAM(瓶颈注意模块)的主要原理如下:

1. 模块概述

BAM是一种简单且有效的注意力模块,旨在提高深度神经网络的表征能力。它可以与任何前馈卷积神经网络(CNN)集成使用,主要在网络的瓶颈处(即特征图下采样的位置)进行操作。

2. 结构设计

BAM模块通过两个独立的路径推断出一个注意力图:通道路径和空间路径。它的具体结构如下:

通道注意力分支

  1. 全局平均池化:对输入特征图 F \in \mathbb{R}^{C \times H \times W}进行全局平均池化,得到通道向量 F_c \in \mathbb{R}^{C \times 1 \times 1}

  2. 多层感知机(MLP):使用一个带有一个隐藏层的MLP来估计通道间的注意力。为了减少参数开销,隐藏层的激活大小设置为 \frac{C}{r} \times 1 \times 1 ,其中 r 是缩减比。

  3. 批量归一化:在MLP之后添加批量归一化层 BN以调整与空间分支输出的规模。

  4. 通道注意力:计算通道注意力 M_c(F)

空间注意力分支

  1. 空间注意力:通过卷积操作生成空间注意力图 M_s(F)

3. 注意力融合

BAM模块将通道和空间注意力进行融合,得到最终的3D注意力图 M(F) : M(F) = \sigma(M_c(F) + M_s(F)) 其中( \sigma ) 是sigmoid函数。融合后的注意力图用于增强输入特征图 ( F ),计算公式为: F' = F + F \otimes M(F) 其中 \otimes表示逐元素乘法。

4. 优势和应用

BAM模块具有轻量级设计,参数和计算开销很小,可以在多个基准测试(如CIFAR-100、ImageNet-1K、VOC 2007和MS COCO)上验证其有效性。它可以显著提高分类和检测性能,并能构建层次化的注意力机制,从低级特征(如背景纹理)逐渐聚焦到高级语义目标。

2. 将BAM添加到YOLOv8中

2.1 BAM代码实现

关键步骤一: 将下面代码粘贴到在/ultralytics/ultralytics/nn/modules/block.py中,并在该文件的__all__中添加“BAMBlock”

class Flatten(nn.Module):def forward(self, x):return x.view(x.shape[0], -1)class ChannelAttention(nn.Module):def __init__(self, channel, reduction=16, num_layers=3):super().__init__()self.avgpool = nn.AdaptiveAvgPool2d(1)gate_channels = [channel]gate_channels += [channel // reduction] * num_layersgate_channels += [channel]self.ca = nn.Sequential()self.ca.add_module('flatten', Flatten())for i in range(len(gate_channels) - 2):self.ca.add_module('fc%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))self.ca.add_module('bn%d' % i, nn.BatchNorm1d(gate_channels[i + 1]))self.ca.add_module('relu%d' % i, nn.SiLU())self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))def forward(self, x):res = self.avgpool(x)res = self.ca(res)res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)return resclass SpatialAttention(nn.Module):def __init__(self, channel, reduction=16, num_layers=3, dia_val=2):super().__init__()self.sa = nn.Sequential()self.sa.add_module('conv_reduce1',nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=channel // reduction))self.sa.add_module('bn_reduce1', nn.BatchNorm2d(channel // reduction))self.sa.add_module('relu_reduce1', nn.SiLU())for i in range(num_layers):self.sa.add_module('conv_%d' % i, nn.Conv2d(kernel_size=3, in_channels=channel // reduction,out_channels=channel // reduction,padding=autopad(3, None, dia_val), dilation=dia_val))self.sa.add_module('bn_%d' % i, nn.BatchNorm2d(channel // reduction))self.sa.add_module('relu_%d' % i, nn.SiLU())self.sa.add_module('last_conv', nn.Conv2d(channel // reduction, 1, kernel_size=1))def forward(self, x):res = self.sa(x)res = res.expand_as(x)return resclass BAMBlock(nn.Module):def __init__(self, channel=512, reduction=16, dia_val=2):super().__init__()self.ca = ChannelAttention(channel=channel, reduction=reduction)self.sa = SpatialAttention(channel=channel, reduction=reduction, dia_val=dia_val)self.sigmoid = nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()sa_out = self.sa(x)ca_out = self.ca(x)weight = self.sigmoid(sa_out + ca_out)out = (1 + weight) * xreturn out

BAM (Bottleneck Attention Module) 处理图片的主要流程如下:

  1. 输入图像: 首先输入图像到卷积神经网络(CNN)中。

  2. 初始卷积处理: 使用一系列卷积层提取图像的特征,这些卷积层能够捕捉不同尺度和复杂度的特征信息。

  3. 特征提取和降维: 在初始特征提取之后,经过多个卷积层的处理,将特征图传递到 BAM 中。在这里,首先会进行特征图的降维处理,以减少计算量。

  4. 计算通道注意力: 通过全局平均池化 (Global Average Pooling, GAP) 和全局最大池化 (Global Max Pooling, GMP) 计算特征图在通道维度上的注意力分布。

  5. 计算空间注意力: 使用卷积操作计算特征图在空间维度上的注意力分布,捕捉图像中重要的空间位置。

  6. 融合注意力权重: 将通道注意力和空间注意力相结合,得到综合的注意力权重。这些权重用于调整原始特征图中的特征响应。

  7. 特征增强: 使用融合后的注意力权重对初始特征图进行加权调整,增强重要特征,抑制无关或冗余特征。

  8. 输出特征图: 最终输出经过 BAM 处理后的特征图,送入后续的网络层进行进一步的处理和分类等任务。

这个流程通过在特征提取过程中引入注意力机制,有效地增强了卷积神经网络对图像重要信息的捕捉能力,提升了模型的整体性能。

2.2 更改init.py文件

关键步骤二:修改modules文件夹下的__init__.py文件,先导入函数

然后在下面的__all__中声明函数

2.3 添加yaml文件

关键步骤三:在/ultralytics/ultralytics/cfg/models/v8下面新建文件yolov8_BAM.yaml文件,粘贴下面的内容

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8-SPPCSPC.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, BAMBlock, [256]]- [-1, 1, Conv, [512, 3, 2]]  # 6-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, BAMBlock, [512]]- [-1, 1, Conv, [1024, 3, 2]]  # 9-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 11- [-1, 1, BAMBlock, [1024]]# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 8], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 15- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 5], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 18 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 15], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 21 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 24 (P5/32-large)- [[18, 21, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

2.4 在task.py中进行注册

关键步骤四:在task.py的parse_model函数中进行注册, ​​

elif m in {BAMBlock}:c1, c2 = ch[f], args[0]if c2 != nc:  # if not outputc2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, c2, *args[1:]]

2.5 执行程序

关键步骤五:在ultralytics文件中新建train.py,将model的参数路径设置为yolov8_BAM.yaml的路径即可

如果你的添加导致报错请看: YOLOv8报错 | 添加注意力机制报错 | ValueError: Expected more than 1 value per channel when training, got input

from ultralytics import YOLO# Load a model
# model = YOLO('yolov8n.yaml')  # build a new model from YAML
# model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)model = YOLO(r'/projects/ultralytics/ultralytics/cfg/models/v8/yolov8_BAM.yaml')  # build from YAML and transfer weights# Train the model
model.train(batch=16)

 🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/18gR-Da_hD3BZZt-jCM5IeQ?pwd=hkxo 

提取码:hkxo  

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的YOLOv8nGFLOPs

改进后的GFLOPs

5. 进阶

可以结合损失函数或者卷积模块进行多重改进

6. 总结

Bottleneck Attention Module (BAM) 的主要原理是通过结合通道注意力和空间注意力两种机制来提高卷积神经网络(CNN)的特征提取能力。BAM 的设计灵感来自于人类视觉系统的“是什么”(通道)和“在哪里”(空间)路径,这两条路径共同作用于视觉信息的处理。

BAM 的主要结构包括两个分支:通道注意力分支和空间注意力分支。通道注意力分支通过全局平均池化来捕捉每个通道上的全局信息,然后使用多层感知器(MLP)计算每个通道的注意力权重。空间注意力分支则通过卷积操作来捕捉特征图在空间维度上的注意力分布。两者的输出通过逐元素加法(element-wise summation)进行融合,形成最终的注意力图。这个注意力图用于加权调整原始特征图,以增强重要特征和抑制无关或冗余特征。

具体来说,给定输入特征图F \in \mathbb{R}^{C \times H \times W},BAM 计算出一个 3D 注意力图 M(F) \in \mathbb{R}^{C \times H \times W}。经过 BAM 处理后的特征图 \(F'\) 通过以下公式计算得到:
F' = F + F \otimes M(F)
其中,\otimes表示逐元素相乘。

在通道注意力分支中,首先对特征图 \(F\) 进行全局平均池化,得到通道向量 F_c \in \mathbb{R}^{C \times 1 \times 1}。然后通过一个多层感知器(MLP)计算通道注意力M_c(F)。空间注意力分支则通过卷积操作计算空间注意力图 M_s(F) \in \mathbb{R}^{H \times W}。最终注意力图 M(F) 通过以下公式计算得到:
M(F) = \sigma(M_c(F) + M_s(F)) 
其中,\sigma 是 sigmoid 函数。

BAM 的设计不仅提升了模型的性能,而且只增加了极少的计算开销。这使得 BAM 特别适合在资源受限的移动设备和嵌入式系统中应用。通过在网络中的瓶颈处(bottleneck)插入 BAM,可以显著提升模型的特征提取能力和整体性能 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44046.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

安防管理平台LntonCVS视频汇聚融合云平台智慧火电厂安全生产管理应用方案

中国的电力产业作为国民经济发展的重要能源支柱,被视为国民经济的基础产业之一。目前,我国主要依赖火力发电,主要燃料包括煤炭、石油和天然气等,通过燃烧转化为动能,再转变为电能输送至全国各地。火力发电量占全国发电…

【软件测试】 1+X初级 功能测试试题

【软件测试】 1X初级 功能测试试题 普通员工登录系统,在“个人信息维护”模块,可以查看和维护个人信息。个人信息维护需求包括用户(UI)页面、业务规则两部分。 UI 界面 个人信息维护 修改基本信息 业务规则 1. 个人信息维护页面…

CB-LLM 可信大模型,让大模型可解释

CB-LLM 可信大模型,让大模型可解释 提出背景解法拆解目的问题框架图第1步:概念生成第2步:自动概念评分(ACS)第3步:训练概念瓶颈层(CBL)第4步:学习预测器 例子&#xff1a…

图片批量重命名bat,一个脚本快速搞定图片批量重命名

BAT 批处理 是一种在 Microsoft Windows 操作系统中使用的脚本语言,用于自动执行一系列预定义的命令或任务。这些命令集合通常存储在一个文本文件中,文件扩展名为 .bat 或 .cmd。批处理脚本可以包含简单的命令,如文件复制、移动、删除&#x…

单片机中有FLASH为啥还需要EEROM?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「单片机的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!! 一是EEPROM操作简单&…

WebRTC API接口教程:实现高效会议的步骤?

WebRTC api接口教程如何使用?WebRTC api接口的功能? WebRTC无需中间服务器即可传输音视频流,为视频会议、在线教育等应用提供了强大的支持。AokSend将详细介绍如何利用WebRTC API接口实现高效会议的步骤。 WebRTC API接口教程:获…

2024年福州延安中学夏季拿云杯拔尖创新人才素养测试

1、选择题 那么,mn的值是( ) A、1243 B、1343 C、4029 D、4049 2、填空题 一副扑克牌共54张,其中1到13点各有 4张,每个数字黑色红色各两张,还有两张王牌,至少要取出( )…

存储产品选型策略 OSS生命周期管理与运维

最近在看阿里云的 云存储通关实践认证训练营这个课程还是不错的。 存储产品选型策略、对象存储OSS入门、基于对象存储OSS快速搭建网盘、 如何做好权限控制、如何做好数据安全、如何做好数据管理、涉及对象存储OSS的权限控制、使用OSS完成静态网站托管、对OSS中存储的数据进行分…

论项目管理工作中的成本管理(20240528)

论项目管理工作中的成本管理 20240528 随着《“十四五”智能制造发展规划》的发布及其提出的2025发展目标及2035远景规划,国家对智能制造发展的重视程度进一步提升。生产制造企业对于智能制造转型的需求愈加迫切。2023年2月,XX电器制造企业为了解决企业…

前端直连小票打印机,前端静默打印,js静默打印解决方案

最近公司开发了一个vue3收银系统,需要使用小票打印机打印小票,但是又不想结账的时候弹出打印预览,找了很多方案,解决不了js打印弹出的打印预览窗口! 没办法,自己写了一个winform版本的静默打印软件&#xf…

【鸿蒙学习笔记】Stage模型

官方文档:Stage模型开发概述 目录标题 Stage模型好处Stage模型概念图ContextAbilityStageUIAbility组件和ExtensionAbility组件WindowStage Stage模型-组件模型Stage模型-进程模型Stage模型-ArkTS线程模型和任务模型关于任务模型,我们先来了解一下什么是…

鸿蒙语言基础类库:【@ohos.util.ArrayList (线性容器ArrayList)】

线性容器ArrayList 说明: 本模块首批接口从API version 8开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。开发前请熟悉鸿蒙开发指导文档:gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或者复制转到。 …

基于Java中的SSM框架实现疫情冷链追溯系统项目【项目源码+论文说明】

基于Java中的SSM框架实现疫情冷链追溯系统演示 摘要 近几年随着城镇化发展和居民消费水平的不断提升,人们对健康生活方式的追求意识逐渐加强,生鲜食品逐渐受到大众青睐,诸如盒马鲜生、7-fresh等品牌生鲜超市,一时间如雨后春笋般迅…

合合信息大模型加速器重磅上线,释放智能文档全新可能

目录 0 写在前面1 高速文档解析引擎:拓宽大模型认知边界2 文本嵌入模型acge:克服大模型感知缺陷3 行业赋能:以百川智能为例总结 0 写在前面 随着人工智能技术的飞速发展,大模型以强大的数字处理能力和深度学习能力,不…

Canvas:掌握图像变换合成与裁剪状态像素操作

想象一下,用几行代码就能创造出如此逼真的图像和动画,仿佛将艺术与科技完美融合,前端开发的Canvas技术正是这个数字化时代中最具魔力的一环,它不仅仅是网页的一部分,更是一个无限创意的画布,一个让你的想象…

java使用poi-tl模版引擎导出word之if判断条件的使用

文章目录 模版中if语句条件的使用1.数据为False或空集合2.非False或非空集合 模版中if语句条件的使用 如果区块对的值是 null 、false 或者空的集合,位于区块中的所有文档元素将不会显示,这就等同于if语句的条件为 false。语法示例:{{?stat…

视图库对接系列(GA-T 1400)十四、视图库对接系列(本级)新增、修改订阅

说明 之前我们已经对接的设备,设备的话比较简单,是设备主动推送数据到平台的。 相信大家已经会了,那今天开始的话,我们来做对接平台,相对难点点。 但搞懂了核心的订阅流程的话,其实就不难了。 对接平台 订阅接口 订阅接口的话,有几个,添加、查询、更新、删除、取消…

Linux镜像源设置不再难:一键脚本,新手也能成为优化高手(一键切换镜像源/Docker一键安装脚本)

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 更换镜像源 📒📝 一键切换软件源📝 Docker一键安装脚本⚓️ 相关链接 ⚓️📖 介绍 📖 在国内,Linux系统用户经常会遇到下载软件包时速度慢的问题,这通常是因为默认的镜像源并不总是最优选择。对于新手来说,手动设置…

亚马逊速卖通卖家必看:自养号测评策略,下单高效防关联全攻略

在跨境电商的激烈竞争中,自养号测评策略已成为众多卖家追求低成本、高效推广的优选路径。然而,其成功实施离不开一系列精心策划与严格执行的关键要素。以下是对这些核心条件的深入剖析,旨在指导您安全、有效地构建并运营自养号测评体系。 一、…