Yolov8有效涨点,添加多种注意力机制,修改损失函数提高目标检测准确率

目录

简介

CBAM注意力机制原理及代码实现

原理

 代码实现

 GAM注意力机制

原理

代码实现

修改损失函数

YAML文件

完整代码


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxvicon-default.png?t=N7T8http://t.csdnimg.cn/sVHxv

简介

Ultralytics 推出了最新版本的 YOLO 模型。注意力机制是提高模型性能最热门的方法之一。

本次将介绍几种常见的注意力机制,这些注意力机制在大多数的数据集上均能有效的提升目标检测的精度/召回率/准确率。

CBAM注意力机制原理及代码实现
原理
CBAM注意力机制结构图

CBAM(Convolutional Block Attention Module)是一种用于卷积神经网络(CNN)的注意力机制,它能够增强网络对输入特征的关注度,提高网络性能。CBAM 主要包含两个子模块:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。

以下是CBAM注意力机制的基本原理:

1. 通道注意力模块(Channel Attention Module):
输入:经过卷积层的特征图。
处理步骤:
对每个通道进行全局平均池化,得到通道的全局平均值。
通过两个全连接层,将全局平均值映射为两个权重向量(一个用于缩放,一个用于偏置)。
将这两个权重向量与原始特征图相乘,以加权调整每个通道的重要性。

2. 空间注意力模块(Spatial Attention Module):**
输入:通道注意力模块的输出。
处理步骤:
     对每个通道的特征图进行分别的最大池化和平均池化,得到两个空间特征图。
     将这两个空间特征图相加,通过一个卷积层产生一个权重图。
     将原始特征图与权重图相乘,以加权调整每个空间位置的重要性。

3. 整合:
   将通道注意力模块和空间注意力模块的输出相乘,得到最终的注意力增强特征图。
   将这个注意力增强的特征图传递给网络的下一层进行进一步处理。

CBAM的关键优势在于它能够同时考虑通道和空间信息,有助于网络更好地理解和利用输入特征。这种注意力机制有助于提高网络在视觉任务上的性能,使其能够更有针对性地关注重要的特征。

 代码实现

路径:"./ultralytics/nn/modules/conv.py"

class ChannelAttention(nn.Module):"""Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""def __init__(self, channels: int) -> None:super().__init__()self.pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)self.act = nn.Sigmoid()def forward(self, x: torch.Tensor) -> torch.Tensor:return x * self.act(self.fc(self.pool(x)))class SpatialAttention(nn.Module):"""Spatial-attention module."""def __init__(self, kernel_size=7):"""Initialize Spatial-attention module with kernel size argument."""super().__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.act = nn.Sigmoid()def forward(self, x):"""Apply channel and spatial attention on input for feature recalibration."""return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))class CBAM(nn.Module):"""Convolutional Block Attention Module."""def __init__(self, c1, kernel_size=7):  # ch_in, kernelssuper().__init__()self.channel_attention = ChannelAttention(c1)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):"""Applies the forward pass through C1 module."""return self.spatial_attention(self.channel_attention(x))

添加完代码以后需要在"./ultralytics/nn/tasks.py"进行注册

 GAM注意力机制
原理

目标的设计是一种减少信息缩减并放大全局维度交互特征的机制。我们采用 CBAM 的顺序通道空间注意力机制并重新设计子模块。整个过程如图所示。

GAM结构图


通道注意力机制
通道注意力子模块使用 3D 排列来保留三个维度的信息。然后,它使用两层 MLP(多层感知器)放大跨维度通道空间依赖性。 (MLP是一种编码器-解码器结构,其缩减比为r,与BAM相同。)通道注意子模块如图所示。 

通道注意力子模块


空间注意力机制
在空间注意力子模块中,为了关注空间信息,我们使用两个卷积层进行空间信息融合。我们还使用与 BAM 相同的通道注意子模块的缩减率 r。同时,最大池化会减少信息并产生负面影响。我们删除池化以进一步保留特征图。因此,空间注意力模块有时会显着增加参数的数量。为了防止参数显着增加,我们在 ResNet50 中采用带有通道洗牌的组卷积。没有组卷积的空间注意力子模块如图所示。 

空间注意力子模块
代码实现

代码添加在 ./ultralytics/nn/modules/conv.py 中,同样需要在task.py中注册

class GAM_Attention(nn.Module):def __init__(self, c1, c2, group=True, rate=4):super(GAM_Attention, self).__init__()self.channel_attention = nn.Sequential(nn.Linear(c1, int(c1 / rate)),nn.ReLU(inplace=True),nn.Linear(int(c1 / rate), c1))self.spatial_attention = nn.Sequential(nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate),kernel_size=7,padding=3),nn.BatchNorm2d(int(c1 / rate)),nn.ReLU(inplace=True),nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2,kernel_size=7,padding=3),nn.BatchNorm2d(c2))def forward(self, x):b, c, h, w = x.shapex_permute = x.permute(0, 2, 3, 1).view(b, -1, c)x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)x_channel_att = x_att_permute.permute(0, 3, 1, 2)# x_channel_att=channel_shuffle(x_channel_att,4) #last shufflex = x * x_channel_attx_spatial_att = self.spatial_attention(x).sigmoid()x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffleout = x * x_spatial_att# out=channel_shuffle(out,4) #last shufflereturn out
修改损失函数

WIoU是一种新型的损失函数,代码实现

def WIoU(cls, pred, target, self=None):self = self if self else cls(pred, target)dist = torch.exp(self.l2_center / self.l2_box.detach())return self._scaled_loss(dist * self.iou)

 这个其实就是修改了loss.py中的BboxLoss,在本段代码的第十二行,将type改成了WIoU

class BboxLoss(nn.Module):def __init__(self, reg_max, use_dfl=False):"""Initialize the BboxLoss module with regularization maximum and DFL settings."""super().__init__()self.reg_max = reg_maxself.use_dfl = use_dfldef forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):"""IoU loss."""weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)loss,iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False,type_='WIoU')loss_iou=loss.sum()/target_scores_sum# DFL lossif self.use_dfl:target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weightloss_dfl = loss_dfl.sum() / target_scores_sumelse:loss_dfl = torch.tensor(0.0).to(pred_dist.device)return loss_iou, loss_dfl
YAML文件
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 9  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, GAM_Attention, [512,512]]- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 16 (P3/8-small)- [-1, 1, GAM_Attention, [256,256]]- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 20 (P4/16-medium)- [-1, 1, GAM_Attention, [512,512]]- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 24 (P5/32-large)- [-1, 1, GAM_Attention, [1024,1024]]- [[17, 21, 25], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在head部分,可以将GAM_attention改成不同的注意力机制,来改变网络结构,从而提升目标检测 的精度

完整代码

链接: https://pan.baidu.com/s/1IDnEZxpcaEgBowlTxX2iNA?pwd=vdrs 提取码: vdrs 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/723009.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mol2文件处理-拆分、合并、提取名称、计数与格式转换

欢迎浏览我的CSND博客! Blockbuater_drug …点击进入 文章目录 前言一、Mol2文件合并二、Mol2文件拆分为含有单个分子的文件三、Mol2文件分子名称修改与提取3.1 分子名称修改去除空格3.2 文件名称提取 四、Mol2文件包含分子计数4.1 Mol2文件中分子计数4.2 分子计数传…

Python——与Matlab对应的Python版本

参考资料: Python——与Matlab对应的Python版本

Rust 开发的高性能 Python 包管理工具,可替换 pip、pip-tools 和 virtualenv

最近,我在 Python 潮流周刊 中分享了一个超级火爆的项目,这还不到一个月,它在 Github 上已经拿下了 8K star 的亮眼成绩,可见其受欢迎程度极高!国内还未见有更多消息,我趁着周末把一篇官方博客翻译出来了&a…

请说明Vue中的解耦能力

Vue中的解耦能力是指在Vue框架中,我们能够有效地将代码分离成独立的组件或模块,使得这些组件之间的依赖关系减少,实现高内聚、低耦合的设计目标。利用Vue中的组件化开发,可以让不同的模块之间更容易地通信和协作,提高代…

【小白学机器学习7】相关系数R,决定系数R2和SST=SSR+SSE, 离差,偏差,方差,标准差,编译系数,标志误。

目录 1 各种数据指标,分类整理 1.0 关于数据/值有3种 1.1 第1类:描述一堆数据特征的指标:集中度,离散度,形状特征 1.2 第2类:判断预测y值和观测值差距的指标 1.3 第3类:描述误差的各种指标…

无线地勘答题模板

(三)无线网络配置 CII集团公司拟投入13万元(网络设备采购部分),项目要求重点覆盖楼层、走廊和办公室。平面布局如图1所示。 图1 平面布局图 1.绘制AP点位图(包括:AP型号、编号、信道等信息,其中信道采用2.4G的1、6、11三个信道进行规划)。 2.使用无线地勘软件,输出…

html标签之表格标签,程序员必看

突破困境: 1. 提升学历 前端找工作,学历重要吗? 重要。谁要是告诉你不重要那一定是在骗你。现实情况是大专吃紧,本科够用,硕士占优,大专以下找到工作靠运气和 戳这里领取完整开源项目:【一线大…

【力扣经典面试题】14. 最长公共前缀

目录 一、题目描述 二、解题思路 三、解题步骤 四、代码实现(C版详细注释) 五、总结 欢迎点赞关注哦!创作不易,你的支持是我的不竭动力,更多精彩等你哦。 一、题目描述 编写一个函数来查找字符串数组中的最长公共前缀。…

微软研究深度报告:Sora文转视频AI模型全景剖析及未来展望

论文由微软研究团队撰写,这篇论文深入探讨了Sora的发展背景、核心技术、新兴应用场景、现有的局限性以及未来的发展机会,基于公开资料和团队自行进行的逆向工程分析。文中详尽且逻辑清晰,建议细读全文以获得深入了解。 原文:Sora…

第四节 JDBC简单示例代码

本文章教程中将演示如何创建一个简单的JDBC应用程序的示例。 这将显示如何打开数据库连接,执行SQL查询并显示结果。 这个示例代码中涉及所有步骤,一些步骤将在本教程的后续章节中进行说明。 创建JDBC应用程序 构建JDBC应用程序涉及以下六个步骤 - 导…

Java并发编程-进程和线程

一、进程和线程 1. 进程 什么是进程? 简单来说,进程就是程序的一次启动和执行。进程是操作系统中的一个概念,它代表正在运行的程序的实例。每个进程都有自己的内存空间、代码和数据,以及其他操作系统资源,如文件和设备…

Git分布式管理-头歌实验远程版本库

Git的一大特点就是,能为不同系统下的开发者提供了一个协作开发的平台。而团队如果要基于Git进行协同开发,就必须依赖远程版本库。远程版本库允许,我们将本地版本库保存在远端服务器,而且,不同的开发者也是基于远程版本…

力扣hot100:560.和为K的子数组(前缀和+哈希表)

分析: 这个题目乍一看,数据大小用暴力解法大概率会超时,可能想用双指针,但是问题出现在 可能存在负数,也就是说即使是找到了一个答案,后面也可能存在负数和正数抵消,又是答案,因此不…

SpringBoot集成Logback

logback logback-core:其它两个模块的基础模块。logback-classic:它是log4j的一个改良版本,同时它完整实现了slf4j API使你可以很方便地更换成其它日志系统如log4j或JDK14 Logging。logback-access:访问模块与Servlet容器集成提供…

08-prometheus监控的告警通知-alertmanager组件工具

一、概述 prometheus通过规则文件对比抓取到的数据,来判断是否触发告警,我们通过配置告警的工具altermanager进行告警通知; 规则文件,写的就是,当我们获取到的PromeQL的值到达一个设置的规则后,触发告警&am…

刷题笔记day27-回溯算法3

39. 组合总和 var path []int var tmp []int var result [][]int// 还是需要去重复,题目中要求的是至少一个数字备选的数量不同。 // 所以需要剪枝操作,右边的要比左边的> func combinationSum(candidates []int, target int) [][]int {// 组合问题pa…

白皮书发布|超融合运行 K8s 的场景、功能与优势

目前,不少企业都使用虚拟化/超融合运行 Kubernetes 和容器化应用。一些用户可能会有疑惑:既然 Kubernetes 可以部署在裸金属上,使用虚拟化不是“多此一举”吗? 在电子书《IT 基础架构团队的 Kubernetes 管理:从入门到…

详细分析Vue中的$refs用法

目录 1. 基本知识2. Demo 1. 基本知识 在Vue.js中,$refs是一个特殊的属性,用于在组件内部直接访问子组件或者DOM元素 作用: 访问DOM元素: 直接访问模板中的DOM元素,以便执行DOM操作,如聚焦、改变样式等 访…

[极客大挑战 2020]Roamphp1-Welcome ---不会编程的崽

buuctf上的题难度适中。越到后边会越难&#xff0c;但也有例外 页面报错了。报错的原因可能有很多种猜想。所以有没有一种可能是故意这么设计的。先抓包吧 发现是GET请求。修改请求方法再试试呢&#xff1f; <?php error_reporting(0); if ($_SERVER[REQUEST_METHOD] ! P…

Android Studio开发(一) 构建项目

1、项目创建测试 1.1 前言 Android Studio 是由 Google 推出的官方集成开发环境&#xff08;IDE&#xff09;&#xff0c;专门用于开发 Android 应用程序。 基于 IntelliJ IDEA: Android Studio 是基于 JetBrains 的 IntelliJ IDEA 开发的&#xff0c;提供了丰富的功能和插件…