免费阅读篇 | 芒果YOLOv8改进110:注意力机制GAM:用于保留信息以增强渠道空间互动

💡🚀🚀🚀本博客 改进源代码改进 适用于 YOLOv8 按步骤操作运行改进后的代码即可

该专栏完整目录链接: 芒果YOLOv8深度改进教程

该篇博客为免费阅读内容,直接改进即可🚀🚀🚀

文章目录

      • 1. GAM论文
      • 2. YOLOv8 核心代码改进部分
      • 2.1 核心新增代码
        • 2.2 修改部分
      • 2.3 YOLOv8-gam 网络配置文件
      • 2.4 运行代码
      • 改进说明


1. GAM论文

在这里插入图片描述

研究了多种注意力机制来提高各种计算机视觉任务的性能。然而,现有的方法忽略了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局注意力机制,通过减少信息缩减和放大全局交互式表示来提高深度神经网络的性能。我们引入了带有多层感知器的 3D 排列,用于通道注意力以及卷积空间注意力子模块。对CIFAR-100和ImageNet-1K上图像分类任务的所提机制的评估表明,我们的方法在ResNet和轻量级MobileNet上都稳定地优于最近的几种注意力机制。

在这里插入图片描述

具体细节可以去看原论文:https://arxiv.org/pdf/2112.05561v1.pdf


2. YOLOv8 核心代码改进部分

2.1 核心新增代码

首先在ultralytics/nn/modules文件夹下,创建一个 gam.py文件,新增以下代码

import numpy as np
import torch
from torch import nn
from torch.nn import initclass GAMAttention(nn.Module):#https://paperswithcode.com/paper/global-attention-mechanism-retain-informationdef __init__(self, c1, c2, group=True,rate=4):super(GAMAttention, self).__init__()self.channel_attention = nn.Sequential(nn.Linear(c1, int(c1 / rate)),nn.ReLU(inplace=True),nn.Linear(int(c1 / rate), c1))self.spatial_attention = nn.Sequential(nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3), nn.BatchNorm2d(int(c1 /rate)),nn.ReLU(inplace=True),nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3), nn.BatchNorm2d(c2))def forward(self, x):b, c, h, w = x.shapex_permute = x.permute(0, 2, 3, 1).view(b, -1, c)x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)x_channel_att = x_att_permute.permute(0, 3, 1, 2)x = x * x_channel_attx_spatial_att = self.spatial_attention(x).sigmoid()x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle out = x * x_spatial_attreturn out  def channel_shuffle(x, groups=2):B, C, H, W = x.size()out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()out=out.view(B, C, H, W) return out   
2.2 修改部分

在ultralytics/nn/modules/init.py中导入 定义在 gam.py 里面的模块

from .gam import GAMAttention'GAMAttention' 加到 __all__ = [...] 里面

第一步:
ultralytics/nn/tasks.py文件中,新增

from ultralytics.nn.modules import GAMAttention

然后在 在tasks.py中配置
找到

        elif m is nn.BatchNorm2d:args = [ch[f]]

在这句上面加一个

        elif m is GAMAttention:c1, c2 = ch[f], args[0]if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, c2, *args[1:]]

2.3 YOLOv8-gam 网络配置文件

新增YOLOv8-gam.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 3, GAMAttention, [1024]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

2.4 运行代码

直接替换YOLOv8-gam.yaml 进行训练即可

到这里就完成了这篇的改进。

改进说明

这里改进是放在了主干后面,如果想放在改进其他地方,也是可以的。直接新增,然后调整通道,配齐即可,如果有不懂的,可以添加博主联系方式,如下


🥇🥇🥇
添加博主联系方式:

友好的读者可以添加博主QQ: 2434798737, 有空可以回答一些答疑和问题

🚀🚀🚀


参考

https://github.com/ultralytics/ultralytics

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/754166.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++算法学习笔记 (7) BFS

1.走迷宫 #include <iostream> #include <algorithm> #include <queue> #include <cstring> using namespace std; typedef pair<int, int> PII; const int N 105; int n, m; int g[N][N]; // 存图 int d[N][N]; // 每个点到起点的距离 queue&…

【每日一题】134. 加油站

在一条环路上有 n 个加油站&#xff0c;其中第 i 个加油站有汽油 gas[i] 升。 你有一辆油箱容量无限的的汽车&#xff0c;从第 i 个加油站开往第 i1 个加油站需要消耗汽油 cost[i] 升。你从其中的一个加油站出发&#xff0c;开始时油箱为空。 给定两个整数数组 gas 和 cost &…

信号处理--基于正则化聚合的共空间模态(CSP)脑电信号分类

目录 理论 工具 方法实现 代码获取 参考文献 理论 传统的通用空间模式 (CSP) 是一种流行的算法,用于对脑电图 (EEG) 信号进行分类。本文主要介绍小样本设置 (SSS) 中 CSP 的正则化和聚合技术。传统的 CSP 基于样本协方差矩阵估计。如果训练样本数量较少,其脑电图分类的…

万字数据仓库面试题及参考答案

数仓架构设计的方法和原则&#xff1a; 数仓架构设计的方法主要包括需求驱动、数据驱动和技术驱动。需求驱动是指根据业务需求进行设计&#xff0c;数据驱动是指基于数据的特点和规律进行设计&#xff0c;技术驱动是指充分利用现有技术和工具进行设计。 数仓架构设计的原则包…

Java八股文(XXL-JOB)

Java八股文のXXL-JOB XXL-JOB XXL-JOB xxl-job 是什么&#xff1f;它的主要作用是什么&#xff1f; xxl-job 是一款分布式任务调度平台&#xff0c;用于解决分布式系统中的定时任务和异步任务调度问题。 它提供了任务的注册、调度、执行和监控等功能&#xff0c;能够帮助开发者…

MindGraph:文字生成知识图

欢迎来到MindGraph&#xff0c;这是一个概念验证、开源的、以API为先的基于图形的项目&#xff0c;旨在通过自然语言的交互&#xff08;输入和输出&#xff09;来构建和定制CRM解决方案。该原型旨在便于集成和扩展。以下是关于X的公告&#xff0c;提供更多背景信息。开始之前&a…

Python错题集-9PermissionError:[Errno 13] (权限错误)

1问题描述 Traceback (most recent call last): File "D:\pycharm\projects\5-《Python数学建模算法与应用》程序和数据\02第2章 Python使用入门\ex2_38_1.py", line 9, in <module> fpd.ExcelWriter(data2_38_3.xlsx) #创建文件对象 File "D:…

[Vue]路由

Vue路由 Vue中的路由&#xff1a;路径和组件的映射关系 路由基本使用 下载 VueRouter 模块到当前工程&#xff0c;版本3.6.5 (vue2) npm i vue-router3.6.5 main.js中引入VueRouter import VueRouter from vue-router 注册插件 App.use(VueRouter) 创建路由对象 const rou…

机器学习----特征缩放

目录 一、什么是特征缩放&#xff1a; 二、为什么要进行特征缩放&#xff1f; 三、如何进行特征缩放&#xff1a; 1、归一化&#xff1a; 2、均值归一化&#xff1a; 3、标准化&#xff08;数据需要符合正态分布&#xff09;&#xff1a; 一、什么是特征缩放&#xff1a; 通…

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之二 素描画风格效果

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之二 素描画风格效果 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之二 素描画风格效果 一、简单介绍 二、素描画风格效果实现原理 三、案例简单实现步骤 一、简单介绍 Python是一种跨…

react native 实现自定义底部导航与路由文件配置

首先先把需要的一些库引入 yarn install react-navigation/native yarn install react-native-screens react-native-safe-area-context yarn install react-navigation/native-stack yarn add react-navigation/bottom-tabs 创建路由文件及四个底部导航页面 router文件下的bot…

opengl使用着色器的示例程序

使用了glew库和freeglut库 #include <GL/glew.h> #include <GL/freeglut.h> #include <iostream>// 窗口大小 const GLint WIDTH = 800, HEIGHT = 600;

python 深度学习 记录遇到的报错问题12

本篇继python 深度学习 记录遇到的报错问题11_undefined symbol: __nvjitlinkadddata_12_1, version-CSDN博客 目录 一、AttributeError: module ‘tensorflow‘ has no attribute ‘app‘ 二、AttributeError: module tensorflow has no attribute placeholder 三、Attribu…

pytorch升级打怪(八)

保存模型和加载已有模型 保存并加载模型保存加载 保存并加载模型 在本节中&#xff0c;我们将研究如何通过保存、加载和运行模型预测来保持模型状态。 import torch import torchvision.models as models保存 PyTorch模型将学习的参数存储在内部状态字典中&#xff0c;称为s…

掘根宝典之C++RTTI和类型转换运算符

什么是RTTI RTTI是运行阶段类型识别的简称。 哪些是RTTI? C有3个支持RTTI的元素。 1.dynamic_cast运算符将使用一个指向基类的指针来生成一个指向派生类的指针&#xff0c;否则该运算符返回0——空指针。 2.typeid运算符返回一个指出对象类型的信息 3.type_info结构存储…

后端配置拦截器的一个问题【问题】

后端配置拦截器的一个问题【问题】 前言版权后端配置拦截器的一个问题问题解决 最后 前言 2024-3-14 00:07:28 以下内容源自《【问题】》 仅供学习交流使用 版权 禁止其他平台发布时删除以下此话 本文首次发布于CSDN平台 作者是CSDN日星月云 博客主页是https://jsss-1.blog…

el-form 的表单校验,如何验证某一项或者多项;validateField 的使用

通常对form表单的校验都是整体校验&#xff1a; this.$refs.form.validate( valid > {if (valid) {// 校验通过&#xff0c;业务逻辑代码...} }); 如果需要对表单里的特定一项或几项进行校验&#xff0c;应该如何实现&#xff1f; 业务场景&#xff1a;下图点探测按钮时…

Python 井字棋游戏

井字棋是一种在3 * 3格子上进行的连珠游戏&#xff0c;又称井字游戏。井字棋的游戏有两名玩家&#xff0c;其中一个玩家画圈&#xff0c;另一个玩家画叉&#xff0c;轮流在3 * 3格子上画上自己的符号&#xff0c;最先在横向、纵向、或斜线方向连成一条线的人为胜利方。如图1所示…

【数据可信流通,从运维信任到技术信任】

1. 数据可信流通体系 信任的基石&#xff1a; 身份的可确认利益可依赖能力有预期行为有后果 2.内循环——>外循环 内循环&#xff1a;数据持有方在自己的运维安全域内队自己的数据使用和安全拥有全责。 外循环&#xff1a;数据要素在离开持有方安全域后&#xff0c;持有方…

蓝桥杯刷题(六)

[蓝桥杯 2022 省 A] 求和 题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示代码题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示代码 题目描述 给定 n n n 个整数 a 1 , a 2 , ⋯ , a n a_{1}, a_{2}, \cdots, a_{n} a1​,a2​,⋯,an​, 求它们两两…