YOLOv8改进 | 注意力机制 | 增强模型在图像分类和目标检测BAM注意力【小白必备 + 附完整代码】

秋招面试专栏推荐 :深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转


💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡


专栏目录 :《YOLOv8改进有效涨点》专栏介绍 & 专栏目录 | 目前已有50+篇内容,内含各种Head检测头、损失函数Loss、Backbone、Neck、NMS等创新点改进——点击即可跳转


近期深度神经网络的发展已经通过架构搜索实现了更强的表征能力。提出了一种名为瓶颈注意力模块(BAM)的简单而有效的注意力模块。它可以与任何前馈卷积神经网络集成。我们的模块沿着两个独立的路径(通道和空间)推断注意力图。我们将模块放置在模型的每个瓶颈处,即特征图降采样发生的地方。我们的模块在瓶颈处构建了一个分层的注意力,并且具有可训练的参数,可以与任何前馈模型进行端到端的联合训练。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址YOLOv8改进——更新各种有效涨点方法——点击即可跳转

目录

1.原理

2. 将BAM添加到YOLOv8中

2.1 BAM代码实现

2.2 更改init.py文件

2.3 添加yaml文件

2.4 在task.py中进行注册

2.5 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6. 总结


1.原理

论文地址:BAM: Bottleneck Attention Module——点击即可跳转

官方代码:官方代码仓库——点击即可跳转

BAM(瓶颈注意模块)的主要原理如下:

1. 模块概述

BAM是一种简单且有效的注意力模块,旨在提高深度神经网络的表征能力。它可以与任何前馈卷积神经网络(CNN)集成使用,主要在网络的瓶颈处(即特征图下采样的位置)进行操作。

2. 结构设计

BAM模块通过两个独立的路径推断出一个注意力图:通道路径和空间路径。它的具体结构如下:

通道注意力分支

  1. 全局平均池化:对输入特征图 F \in \mathbb{R}^{C \times H \times W}进行全局平均池化,得到通道向量 F_c \in \mathbb{R}^{C \times 1 \times 1}

  2. 多层感知机(MLP):使用一个带有一个隐藏层的MLP来估计通道间的注意力。为了减少参数开销,隐藏层的激活大小设置为 \frac{C}{r} \times 1 \times 1 ,其中 r 是缩减比。

  3. 批量归一化:在MLP之后添加批量归一化层 BN以调整与空间分支输出的规模。

  4. 通道注意力:计算通道注意力 M_c(F)

空间注意力分支

  1. 空间注意力:通过卷积操作生成空间注意力图 M_s(F)

3. 注意力融合

BAM模块将通道和空间注意力进行融合,得到最终的3D注意力图 M(F) : M(F) = \sigma(M_c(F) + M_s(F)) 其中( \sigma ) 是sigmoid函数。融合后的注意力图用于增强输入特征图 ( F ),计算公式为: F' = F + F \otimes M(F) 其中 \otimes表示逐元素乘法。

4. 优势和应用

BAM模块具有轻量级设计,参数和计算开销很小,可以在多个基准测试(如CIFAR-100、ImageNet-1K、VOC 2007和MS COCO)上验证其有效性。它可以显著提高分类和检测性能,并能构建层次化的注意力机制,从低级特征(如背景纹理)逐渐聚焦到高级语义目标。

2. 将BAM添加到YOLOv8中

2.1 BAM代码实现

关键步骤一: 将下面代码粘贴到在/ultralytics/ultralytics/nn/modules/block.py中,并在该文件的__all__中添加“BAMBlock”

class Flatten(nn.Module):def forward(self, x):return x.view(x.shape[0], -1)class ChannelAttention(nn.Module):def __init__(self, channel, reduction=16, num_layers=3):super().__init__()self.avgpool = nn.AdaptiveAvgPool2d(1)gate_channels = [channel]gate_channels += [channel // reduction] * num_layersgate_channels += [channel]self.ca = nn.Sequential()self.ca.add_module('flatten', Flatten())for i in range(len(gate_channels) - 2):self.ca.add_module('fc%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))self.ca.add_module('bn%d' % i, nn.BatchNorm1d(gate_channels[i + 1]))self.ca.add_module('relu%d' % i, nn.SiLU())self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))def forward(self, x):res = self.avgpool(x)res = self.ca(res)res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)return resclass SpatialAttention(nn.Module):def __init__(self, channel, reduction=16, num_layers=3, dia_val=2):super().__init__()self.sa = nn.Sequential()self.sa.add_module('conv_reduce1',nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=channel // reduction))self.sa.add_module('bn_reduce1', nn.BatchNorm2d(channel // reduction))self.sa.add_module('relu_reduce1', nn.SiLU())for i in range(num_layers):self.sa.add_module('conv_%d' % i, nn.Conv2d(kernel_size=3, in_channels=channel // reduction,out_channels=channel // reduction,padding=autopad(3, None, dia_val), dilation=dia_val))self.sa.add_module('bn_%d' % i, nn.BatchNorm2d(channel // reduction))self.sa.add_module('relu_%d' % i, nn.SiLU())self.sa.add_module('last_conv', nn.Conv2d(channel // reduction, 1, kernel_size=1))def forward(self, x):res = self.sa(x)res = res.expand_as(x)return resclass BAMBlock(nn.Module):def __init__(self, channel=512, reduction=16, dia_val=2):super().__init__()self.ca = ChannelAttention(channel=channel, reduction=reduction)self.sa = SpatialAttention(channel=channel, reduction=reduction, dia_val=dia_val)self.sigmoid = nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()sa_out = self.sa(x)ca_out = self.ca(x)weight = self.sigmoid(sa_out + ca_out)out = (1 + weight) * xreturn out

BAM (Bottleneck Attention Module) 处理图片的主要流程如下:

  1. 输入图像: 首先输入图像到卷积神经网络(CNN)中。

  2. 初始卷积处理: 使用一系列卷积层提取图像的特征,这些卷积层能够捕捉不同尺度和复杂度的特征信息。

  3. 特征提取和降维: 在初始特征提取之后,经过多个卷积层的处理,将特征图传递到 BAM 中。在这里,首先会进行特征图的降维处理,以减少计算量。

  4. 计算通道注意力: 通过全局平均池化 (Global Average Pooling, GAP) 和全局最大池化 (Global Max Pooling, GMP) 计算特征图在通道维度上的注意力分布。

  5. 计算空间注意力: 使用卷积操作计算特征图在空间维度上的注意力分布,捕捉图像中重要的空间位置。

  6. 融合注意力权重: 将通道注意力和空间注意力相结合,得到综合的注意力权重。这些权重用于调整原始特征图中的特征响应。

  7. 特征增强: 使用融合后的注意力权重对初始特征图进行加权调整,增强重要特征,抑制无关或冗余特征。

  8. 输出特征图: 最终输出经过 BAM 处理后的特征图,送入后续的网络层进行进一步的处理和分类等任务。

这个流程通过在特征提取过程中引入注意力机制,有效地增强了卷积神经网络对图像重要信息的捕捉能力,提升了模型的整体性能。

2.2 更改init.py文件

关键步骤二:修改modules文件夹下的__init__.py文件,先导入函数

然后在下面的__all__中声明函数

2.3 添加yaml文件

关键步骤三:在/ultralytics/ultralytics/cfg/models/v8下面新建文件yolov8_BAM.yaml文件,粘贴下面的内容

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8-SPPCSPC.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, BAMBlock, [256]]- [-1, 1, Conv, [512, 3, 2]]  # 6-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, BAMBlock, [512]]- [-1, 1, Conv, [1024, 3, 2]]  # 9-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 11- [-1, 1, BAMBlock, [1024]]# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 8], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 15- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 5], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 18 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 15], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 21 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 24 (P5/32-large)- [[18, 21, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

2.4 在task.py中进行注册

关键步骤四:在task.py的parse_model函数中进行注册, ​​

elif m in {BAMBlock}:c1, c2 = ch[f], args[0]if c2 != nc:  # if not outputc2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, c2, *args[1:]]

2.5 执行程序

关键步骤五:在ultralytics文件中新建train.py,将model的参数路径设置为yolov8_BAM.yaml的路径即可

如果你的添加导致报错请看: YOLOv8报错 | 添加注意力机制报错 | ValueError: Expected more than 1 value per channel when training, got input

from ultralytics import YOLO# Load a model
# model = YOLO('yolov8n.yaml')  # build a new model from YAML
# model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)model = YOLO(r'/projects/ultralytics/ultralytics/cfg/models/v8/yolov8_BAM.yaml')  # build from YAML and transfer weights# Train the model
model.train(batch=16)

 🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/18gR-Da_hD3BZZt-jCM5IeQ?pwd=hkxo 

提取码:hkxo  

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的YOLOv8nGFLOPs

改进后的GFLOPs

5. 进阶

可以结合损失函数或者卷积模块进行多重改进

6. 总结

Bottleneck Attention Module (BAM) 的主要原理是通过结合通道注意力和空间注意力两种机制来提高卷积神经网络(CNN)的特征提取能力。BAM 的设计灵感来自于人类视觉系统的“是什么”(通道)和“在哪里”(空间)路径,这两条路径共同作用于视觉信息的处理。

BAM 的主要结构包括两个分支:通道注意力分支和空间注意力分支。通道注意力分支通过全局平均池化来捕捉每个通道上的全局信息,然后使用多层感知器(MLP)计算每个通道的注意力权重。空间注意力分支则通过卷积操作来捕捉特征图在空间维度上的注意力分布。两者的输出通过逐元素加法(element-wise summation)进行融合,形成最终的注意力图。这个注意力图用于加权调整原始特征图,以增强重要特征和抑制无关或冗余特征。

具体来说,给定输入特征图F \in \mathbb{R}^{C \times H \times W},BAM 计算出一个 3D 注意力图 M(F) \in \mathbb{R}^{C \times H \times W}。经过 BAM 处理后的特征图 \(F'\) 通过以下公式计算得到:
F' = F + F \otimes M(F)
其中,\otimes表示逐元素相乘。

在通道注意力分支中,首先对特征图 \(F\) 进行全局平均池化,得到通道向量 F_c \in \mathbb{R}^{C \times 1 \times 1}。然后通过一个多层感知器(MLP)计算通道注意力M_c(F)。空间注意力分支则通过卷积操作计算空间注意力图 M_s(F) \in \mathbb{R}^{H \times W}。最终注意力图 M(F) 通过以下公式计算得到:
M(F) = \sigma(M_c(F) + M_s(F)) 
其中,\sigma 是 sigmoid 函数。

BAM 的设计不仅提升了模型的性能,而且只增加了极少的计算开销。这使得 BAM 特别适合在资源受限的移动设备和嵌入式系统中应用。通过在网络中的瓶颈处(bottleneck)插入 BAM,可以显著提升模型的特征提取能力和整体性能 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44046.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

极狐GitLab 17.0 重磅发布,100+ DevSecOps功能更新来啦~【四】

GitLab 是一个全球知名的一体化 DevOps 平台,很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab :https://gitlab.cn/install?channelcontent&utm_sourcecsdn 是 GitLab 在中国的发行版,专门为中国程序员服务。可以一键式部署…

安防管理平台LntonCVS视频汇聚融合云平台智慧火电厂安全生产管理应用方案

中国的电力产业作为国民经济发展的重要能源支柱,被视为国民经济的基础产业之一。目前,我国主要依赖火力发电,主要燃料包括煤炭、石油和天然气等,通过燃烧转化为动能,再转变为电能输送至全国各地。火力发电量占全国发电…

【软件测试】 1+X初级 功能测试试题

【软件测试】 1X初级 功能测试试题 普通员工登录系统,在“个人信息维护”模块,可以查看和维护个人信息。个人信息维护需求包括用户(UI)页面、业务规则两部分。 UI 界面 个人信息维护 修改基本信息 业务规则 1. 个人信息维护页面…

环形链表1-2 js 快慢指针

环形链表1: 设置两个指针, 慢指针一次走一步,快指针一次走两步, 如果 fast null 或者 fast.next null 不存在环, 如果存在环,两个指针进入环中,是一个追及问题,一定会相遇 var h…

CB-LLM 可信大模型,让大模型可解释

CB-LLM 可信大模型,让大模型可解释 提出背景解法拆解目的问题框架图第1步:概念生成第2步:自动概念评分(ACS)第3步:训练概念瓶颈层(CBL)第4步:学习预测器 例子&#xff1a…

图片批量重命名bat,一个脚本快速搞定图片批量重命名

BAT 批处理 是一种在 Microsoft Windows 操作系统中使用的脚本语言,用于自动执行一系列预定义的命令或任务。这些命令集合通常存储在一个文本文件中,文件扩展名为 .bat 或 .cmd。批处理脚本可以包含简单的命令,如文件复制、移动、删除&#x…

单片机中有FLASH为啥还需要EEROM?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「单片机的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!! 一是EEPROM操作简单&…

FPGA设计之跨时钟域(CDC)设计篇(2)----如何科学地设计复位信号?

1、复位是干嘛的? 时钟信号和复位信号应该是一个数字系统最重要和最常用的两个信号了。时钟的重要性大家都懂,没有时钟整个系统就无法同步,自然也就谈不上运行了。那么复位(reset)到底是干嘛的? 所有的数字系统在上电的时候都会进行复位,这样才能确保该系统的初始运行状…

WebRTC API接口教程:实现高效会议的步骤?

WebRTC api接口教程如何使用?WebRTC api接口的功能? WebRTC无需中间服务器即可传输音视频流,为视频会议、在线教育等应用提供了强大的支持。AokSend将详细介绍如何利用WebRTC API接口实现高效会议的步骤。 WebRTC API接口教程:获…

Python 上位机开发

Python 上位机开发 第一节:入门介绍 在这第一节中,我们将对 Python 上位机开发进行一个初步的了解和探索。 首先,什么是上位机?上位机通常是指可以与下位机(如单片机、传感器等硬件设备)进行通信和交互,实现数据采集、控制指令发送以及数据处理和展示的计算机程序。 Pyt…

随着人工智能和机器学习的发展,如何在 C# 中有效地集成深度学习框架,以实现复杂的模型训练和预测功能,并且能够在不同的平台上进行部署和优化?

在C#中集成深度学习框架并实现复杂的模型训练和预测功能可以通过以下步骤进行: 选择适合的深度学习框架:目前在C#中可用的深度学习框架有多种选择,如TensorFlow.NET、CNTK、ML.NET等。根据具体需求选择一个适合的框架。 安装和配置深度学习框…

2024年福州延安中学夏季拿云杯拔尖创新人才素养测试

1、选择题 那么,mn的值是( ) A、1243 B、1343 C、4029 D、4049 2、填空题 一副扑克牌共54张,其中1到13点各有 4张,每个数字黑色红色各两张,还有两张王牌,至少要取出( )…

存储产品选型策略 OSS生命周期管理与运维

最近在看阿里云的 云存储通关实践认证训练营这个课程还是不错的。 存储产品选型策略、对象存储OSS入门、基于对象存储OSS快速搭建网盘、 如何做好权限控制、如何做好数据安全、如何做好数据管理、涉及对象存储OSS的权限控制、使用OSS完成静态网站托管、对OSS中存储的数据进行分…

论项目管理工作中的成本管理(20240528)

论项目管理工作中的成本管理 20240528 随着《“十四五”智能制造发展规划》的发布及其提出的2025发展目标及2035远景规划,国家对智能制造发展的重视程度进一步提升。生产制造企业对于智能制造转型的需求愈加迫切。2023年2月,XX电器制造企业为了解决企业…

C++设计模式---备忘录模式

1、介绍 备忘录模式是一种行为型设计模式,它允许在不破坏封装性的前提下,捕获一个对象的内部状态,并在该对象之外保存这个状态,以便以后将对象恢复到原先保存的状态。 该模式主要涉及三个角色: (1&#xf…

前端直连小票打印机,前端静默打印,js静默打印解决方案

最近公司开发了一个vue3收银系统,需要使用小票打印机打印小票,但是又不想结账的时候弹出打印预览,找了很多方案,解决不了js打印弹出的打印预览窗口! 没办法,自己写了一个winform版本的静默打印软件&#xf…

面试真题 | 操作系统中断知识

操作系统中断知识 什么是中断?在嵌入式系统中,为什么中断很重要? 参考答案 中断是计算机系统中的一种机制,用于在当前执行的程序或任务被中断处理程序(Interrupt Service Routine,ISR)中断执行时…

【鸿蒙学习笔记】Stage模型

官方文档:Stage模型开发概述 目录标题 Stage模型好处Stage模型概念图ContextAbilityStageUIAbility组件和ExtensionAbility组件WindowStage Stage模型-组件模型Stage模型-进程模型Stage模型-ArkTS线程模型和任务模型关于任务模型,我们先来了解一下什么是…

鸿蒙语言基础类库:【@ohos.util.ArrayList (线性容器ArrayList)】

线性容器ArrayList 说明: 本模块首批接口从API version 8开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。开发前请熟悉鸿蒙开发指导文档:gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或者复制转到。 …