YOLOv8 | 有效涨点,添加GAM注意力机制,使用Wise-IoU有效提升目标检测效果(附报错解决技巧,全网独家)

 目录

摘要

基本原理

通道注意力机制

空间注意力机制

GAM代码实现 

Wise-IoU 

WIoU代码实现

yaml文件编写

完整代码分享(含多种注意力机制)


摘要

人们已经研究了各种注意力机制来提高各种计算机视觉任务的性能。然而,现有方法忽视了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局注意力机制,通过减少信息减少和放大全局交互表示来提高深度神经网络的性能。引入了具有多层感知器的 3D 排列,用于通道注意以及卷积空间注意子模块。在 CIFAR-100 和 ImageNet-1K 上对所提出的图像分类任务机制的评估表明,我们的方法稳定优于最近使用 ResNet 和轻量级 MobileNet 的几种注意力机制。

基本原理

目标的设计是一种减少信息缩减并放大全局维度交互特征的机制。我们采用 CBAM 的顺序通道空间注意力机制并重新设计子模块。整个过程如图 所示。

GAM结构图
通道注意力机制

通道注意力子模块使用 3D 排列来保留三个维度的信息。然后,它使用两层 MLP(多层感知器)放大跨维度通道空间依赖性。 (MLP是一种编码器-解码器结构,其缩减比为r,与BAM相同。)通道注意子模块如图所示。 

通道注意力子模块
空间注意力机制

在空间注意力子模块中,为了关注空间信息,我们使用两个卷积层进行空间信息融合。我们还使用与 BAM 相同的通道注意子模块的缩减率 r。同时,最大池化会减少信息并产生负面影响。我们删除池化以进一步保留特征图。因此,空间注意力模块有时会显着增加参数的数量。为了防止参数显着增加,我们在 ResNet50 中采用带有通道洗牌的组卷积。没有组卷积的空间注意力子模块如图所示。 

空间注意力子模块
GAM代码实现 
class GAM_Attention(nn.Module):def __init__(self, c1, c2, group=True, rate=4):super(GAM_Attention, self).__init__()self.channel_attention = nn.Sequential(nn.Linear(c1, int(c1 / rate)),nn.ReLU(inplace=True),nn.Linear(int(c1 / rate), c1))self.spatial_attention = nn.Sequential(nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate),kernel_size=7,padding=3),nn.BatchNorm2d(int(c1 / rate)),nn.ReLU(inplace=True),nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2,kernel_size=7,padding=3),nn.BatchNorm2d(c2))def forward(self, x):b, c, h, w = x.shapex_permute = x.permute(0, 2, 3, 1).view(b, -1, c)x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)x_channel_att = x_att_permute.permute(0, 3, 1, 2)# x_channel_att=channel_shuffle(x_channel_att,4) #last shufflex = x * x_channel_attx_spatial_att = self.spatial_attention(x).sigmoid()x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffleout = x * x_spatial_att# out=channel_shuffle(out,4) #last shufflereturn out

以上代码添加在 ./ultralytics/nn/modules/conv.py 中

Wise-IoU 

Yolov7提出的损失函数是GIoU(Generalized Intersection over Union),能在更广义的层面上计算IoU(Intersection over Union),但是当两个预测框完全重合时,不能反映出实际情况,此时GIoU就要退化为IoU,并且GIoU对每个预测框与真实框均要计算最小外接框,故损失函数计算及收敛速度受到限制。
为了弥补这种遗憾,改进的网络中使用了WIoU(Wise-IoU)作为损失函数。WIoU v3作为边界框回归损失,包含一种动态非单调机制,并设计了一种合理的梯度增益分配,该策略减少了极端样本中出现的大梯度或有害梯度。该损失方法计算更多地关注普通质量的样本,进而提高网络模型的泛化能力和整体性能。

虽然几种主流损失函数都采用静态聚焦机制,但WIoU不仅考虑了方位角、质心距离和重叠面积,还引入了动态非单调聚焦机制。 WIoU应用合理的梯度增益分配策略来评估锚框的质量。WIoU有三个版本。 WIoU v1 设计了基于注意力的预测框损失,WIoU v2 和 WIoU v3 添加了聚焦系数。

wiou原理图

最小的包围盒(绿色)和中心点的连接(红色),其中并集的面积为 Su = wh + wgthgt − WiHi .

WIoU代码实现
def WIoU(cls, pred, target, self=None):self = self if self else cls(pred, target)dist = torch.exp(self.l2_center / self.l2_box.detach())return self._scaled_loss(dist * self.iou)

 下面的代码替换loss.py的class BboxLoss

class BboxLoss(nn.Module):def __init__(self, reg_max, use_dfl=False):"""Initialize the BboxLoss module with regularization maximum and DFL settings."""super().__init__()self.reg_max = reg_maxself.use_dfl = use_dfldef forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):"""IoU loss."""weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)loss,iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False,type_='WIoU')loss_iou=loss.sum()/target_scores_sum# DFL lossif self.use_dfl:target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weightloss_dfl = loss_dfl.sum() / target_scores_sumelse:loss_dfl = torch.tensor(0.0).to(pred_dist.device)return loss_iou, loss_dfl
yaml文件编写
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 3, GAM_Attention, [1024]]- [-1, 1, SPPF, [1024, 5]]  # 10# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 13#- [-1, 1, GAM_Attention, [512,512]]- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 16 (P3/8-small)#- [-1, 1, GAM_Attention, [256,256]]- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 19 (P4/16-medium)#- [-1, 1, GAM_Attention, [512,512]]- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 22 (P5/32-large)#- [-1, 1, GAM_Attention, [1024,1024]]- [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)
完整代码分享(含多种注意力机制)

内涵SA,CBAM,GAM,ECA等多种注意力机制

链接: https://pan.baidu.com/s/1T9bVifTPCRMv2t7eREsuEw?pwd=nbrt 提取码: nbrt 

报错解决办法

YOLOv8 | 添加注意力机制报错KeyError:已解决,详细步骤-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/747460.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言例3-20:使用逻辑运算符的例子

代码如下&#xff1a; #include<stdio.h> int main(void) {int x3, y100;float f11.0f, f22.1f;char cd; //d(100)printf("!x 的值为&#xff1a; %d\n",!x); //0printf("x||y 的值为&#xff1a; %d\n",x||y); //1print…

ai怎么制作ppt?保姆级的ai一键生成ppt教程来了!

面对市面上多如牛毛的 ai 生成 ppt 软件&#xff0c;哪一款更适合日常使用呢&#xff1f;与此同时&#xff0c;在选定一款 ai 软件后&#xff0c;如何用 ai 制作 ppt&#xff0c;也是很多人第一次使用 pptai 工具会面临的具体问题。 就着这些问题&#xff0c;在接下来的文章中…

有哪些便宜的通配符(泛域名)证书?怎么申请?

通配符&#xff08;泛域名&#xff09;SSL证书就是用来保护一个主域名以及所有二级子域名的证书&#xff0c;相对于单域名证书更具有性价比。 主要优势在于&#xff1a; 一&#xff1a;一个整数覆盖所有子域名 仅仅用一张证书就可以保护一个主域名以及所有子域名&#xff0c;…

HPA数据库及HPAanalyze包使用

关于HPA数据库的介绍&#xff1a;Human Protein Atlas 数据库 – 王进的个人网站 (jingege.wang) The Human Protein Atlas 文献 HPAanalyze: an R package that facilitates the retrieval and analysis of the Human Protein Atlas data | BMC Bioinformatics | Full Text …

【PPO】近端策略优化【Clip版本,离散动作】

本博客代码参考了《动手学强化学习-PPO》 PPO算法是在Actor-Critic的基础上进行训练目标的调整。其改进的地方在于对每次参数更新进行了限制。 PPO 是 TRPO 的一种改进算法&#xff0c;它在实现上简化了 TRPO 中的复杂计算&#xff0c;并且它在实验中的性能大多数情况下会比 …

服务模块划分规范

一、PO :(persistant object )&#xff0c;持久对象 可以看成是与数据库中的表相映射的java对象。使用Hibernate来生成PO是不错的选择。 二、VO :(value object) &#xff0c;值对象 通常用于业务层之间的数据传递&#xff0c;和PO一样也是仅仅包含数据而已。但应是抽象出的…

功能问题:如何用Docker部署一个后端项目?

大家好&#xff0c;我是大澈&#xff01; 本文约1800字&#xff0c;整篇阅读大约需要3分钟。 关注微信公众号&#xff1a;“程序员大澈”&#xff0c;免费加入问答群&#xff0c;一起交流技术难题与未来&#xff01; 现在关注公众号&#xff0c;免费送你 ”前后端入行大礼包…

SwiftU的组件 - TabView

SwiftU的组件 - TabView 记录一下SwiftU的组件 - TabView的两种style分别的使用方式 import SwiftUIstruct TabViewBootCamp: View {State var selectedIndex 0var body: some View {NavigationView {TabView(selection: $selectedIndex) {HomeView(selectedIndex: $selected…

基于python的《彩图版飞机大战》程序使用说明(附源码下载)

在PyCharm中运行《彩图版飞机大战》即可进入如图1所示的游戏界面。 图1 游戏主界面 具体的操作步骤如下&#xff1a; &#xff08;1&#xff09;玩游戏。在游戏主界面中&#xff0c;从屏幕的顶部不断出现下落的敌机&#xff0c;玩家按下键盘上的↑、↓、←、→方向键移动飞机…

Android 深入Http(2)加密与编码

可以对二进制数据&#xff08;比如图片、视频&#xff09; 经典算法&#xff1a; DES&#xff08;密钥短被弃用了&#xff09; AES &#xff08;密钥很长 很顶&#xff09; 速度快&#xff0c;效率高 IDEA 3DES&#xff08;三重DES&#xff0c;听起来就很慢和重 &#xf…

VGG论文学习笔记

题目&#xff1a;VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION 论文下载地址&#xff1a;VGG论文 摘要 目的&#xff1a;研究深度对精度的影响 方法&#xff1a;使用3*3滤波器不断增加深度&#xff0c;16和19效果显著 成绩&#xff1a;在ImageNet 20…

搭建知识管理系统并不复杂,这篇教程来帮你

许多人都有这样的体验&#xff1a;我们抓住的想法和知识总在不经意间溜走&#xff0c;我们想要的信息总是一时无法找到。因此&#xff0c;搭建一个能够系统化、分类和索引存储这些知识的“知识管理系统”是必要的。听上去很专业&#xff0c;其实并不复杂&#xff0c;让我们一步…

mysql: 如何开启慢查询日志?

1 确认慢查询日志功能已开启 执行以下sql语句&#xff0c;查看慢查询功能是否开启&#xff1a; show VARIABLES like slow_query_log;如果为ON&#xff0c;表示打开&#xff1b;如果为OFF&#xff0c;表示没有打开&#xff0c;需要开启慢查询功能。 执行以下sql语句&#xff0…

修改 MySQL update_time 默认值的坑

由于按规范需要对 update_time 字段需要对它做默认值的设置 现在有一个原始的表是这样的 CREATE TABLE test_up (id bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 主键id,update_time datetime default null COMMENT 操作时间,PRIMARY KEY (id) ) ENGINEInnoDB DEF…

MapStruct代替BeanUtils.copyProperties ()使用

1.为什么MapStruct代替BeanUtils.copyProperties () 第一&#xff1a;因为BeanUtils 采用反射的机制动态去进行拷贝映射&#xff0c;特别是Apache的BeanUtils的性能很差&#xff0c;而且并不支持所有数据类型的拷贝&#xff0c;虽然使用较为方便&#xff0c;但是强烈不建议使用…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:NavRouter)

导航组件&#xff0c;默认提供点击响应处理&#xff0c;不需要开发者自定义点击事件逻辑。 说明&#xff1a; 该组件从API Version 9开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 必须包含两个子组件&#xff0c;其中第二个子组…

分析型数据库的主要使用场景有哪些?

如今数据已经成为了企业和组织的核心资产。如何有效地管理和利用这些数据&#xff0c;成为了决定竞争力的关键。分析型数据库作为数据处理领域的重要工具&#xff0c;为各行各业提供了强大的数据分析和洞察能力。基于分析型数据库&#xff08;Apache Doris &#xff09;构建的现…

当模型足够大时,Bias项不会有什么特别的作用

问题来源&#xff1a; 阅读OLMo论文时&#xff0c;发现有如下一段话&#xff1a; 加上前面研究llama和mistral结构时好奇为什么都没有偏置项了 偏置项的作用&#xff1a; 回到第一性原理来分析&#xff0c;为什么要有偏置项的存在呢&#xff1f; 在神经网络中&#xff0c;…

跨境热点!TikTok直播网络要求是什么?

TikTok直播作为一种互动性强、实时性要求高的社交媒体形式&#xff0c;对网络环境有着一系列特定的需求。了解并满足这些需求&#xff0c;对于确保用户体验、提高直播质量至关重要。本文将深入探讨TikTok直播对网络环境的要求以及如何优化网络设置以满足这些要求。 TikTok直播的…

mac启动elasticsearch

1.首先下载软件&#xff0c;然后双击解压&#xff0c;我用的是7.17.3的版本 2.然后执行如下命令 Last login: Thu Mar 14 23:14:44 on ttys001 diannao1xiejiandeMacBook-Air ~ % cd /Users/xiejian/local/software/elasticsearch/elasticsearch-7.17.3 diannao1xiejiandeMac…