YOLOv8改进算法之添加CA注意力机制

1. CA注意力机制

CA(Coordinate Attention)注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息,以便模型可以更好地理解不同位置之间的关系。如下图:

1. 输入特征: CA 注意力机制的输入通常是一个特征图,它通常是卷积神经网络(CNN)中的某一层的输出,具有以下形状:[C, H, W],其中:

  • C 是通道数,表示特征图中的不同特征通道。
  • H 是高度,表示特征图的垂直维度。
  • W 是宽度,表示特征图的水平维度。

2. 全局平均池化: CA 注意力机制首先对输入特征图进行两次全局平均池化,一次在宽度方向上,一次在高度方向上。这两次操作分别得到两个特征映射:

  • 在宽度方向上的平均池化得到特征映射 [C, H, 1]
  • 在高度方向上的平均池化得到特征映射 [C, 1, W]

这两个特征映射分别捕捉了在宽度和高度方向上的全局特征。

3. 合并宽高特征: 将上述两个特征映射合并,通常通过简单的堆叠操作,得到一个新的特征层,形状为 [C, 1, H + W],其中 H + W 表示在宽度和高度两个方向上的维度合并在一起。

4. 卷积+标准化+激活函数: 对合并后的特征层进行卷积操作,通常是 1x1 卷积,以捕捉宽度和高度维度之间的关系。然后,通常会应用标准化(如批量标准化)和激活函数(如ReLU)来进一步处理特征,得到一个更加丰富的表示。

5. 再次分开: 分别从上述特征层中分离出宽度和高度方向的特征:

  • 一个分支得到特征层 [C, 1, H]
  • 另一个分支得到特征层 [C, 1, W]

6. 转置: 对分开的两个特征层进行转置操作,以恢复宽度和高度的维度,得到两个特征层分别为 [C, H, 1][C, 1, W]

7. 通道调整和 Sigmoid: 对两个分开的特征层分别应用 1x1 卷积,以调整通道数,使其适应注意力计算。然后,应用 Sigmoid 激活函数,得到在宽度和高度维度上的注意力分数。这些分数用于指示不同位置的重要性。

8. 应用注意力: 将原始输入特征图与宽度和高度方向上的注意力分数相乘,得到 CA 注意力机制的输出。

2. YOLOv8添加CA注意力机制

加入注意力机制,在ultralytics包中的nn包的modules里添加CA注意力模块,我这里选择在conv.py文件中添加CA注意力机制。

CA注意力机制代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass h_sigmoid(nn.Module):def __init__(self, inplace=True):super(h_sigmoid, self).__init__()self.relu = nn.ReLU6(inplace=inplace)def forward(self, x):return self.relu(x + 3) / 6class h_swish(nn.Module):def __init__(self, inplace=True):super(h_swish, self).__init__()self.sigmoid = h_sigmoid(inplace=inplace)def forward(self, x):return x * self.sigmoid(x)class CoordAtt(nn.Module):def __init__(self, inp, reduction=32):super(CoordAtt, self).__init__()self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))mip = max(8, inp // reduction)self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.act = h_swish()self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)def forward(self, x):identity = xn, c, h, w = x.size()x_h = self.pool_h(x)x_w = self.pool_w(x).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)y = self.conv1(y)y = self.bn1(y)y = self.act(y)x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)a_h = self.conv_h(x_h).sigmoid()a_w = self.conv_w(x_w).sigmoid()out = identity * a_w * a_hreturn out

CA注意力机制的注册和引用如下:

 ultralytics/nn/modules/_init_.py文件中:

  ultralytics/nn/tasks.py文件夹中:

 在tasks.py中的parse_model中添加如下代码:

        elif m in {CoordAtt}:args=[ch[f],*args]

新建相应的yolov8s-CA.yaml文件,代码如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1,1,CoordAtt,[]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1,1,CoordAtt,[]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1,1,CoordAtt,[]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 8], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 5], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 15], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[18, 21, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在main.py文件中进行训练:

if __name__ == '__main__':# 使用yaml配置文件来创建模型,并导入预训练权重.model = YOLO('ultralytics/cfg/models/v8/yolov8s-CA.yaml')# model.load('yolov8n.pt')model.train(**{'cfg': 'ultralytics/cfg/default.yaml', 'data': 'dataset/data.yaml'})

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/91532.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Error: Activity class {xxx.java} does not exist

git切换到不同的branch之后,报下面的错误: Error: Activity class {xxx.java} does not exist 解决方案: 首先clean 然后会删除build目录 然后点击:Invalidate Caches Android Studio重启,然后重新build即可。

数据链路层 MTU 对 IP 协议的影响

在介绍主要内容之前,我们先来了解一下数据链路层中的"以太网" 。 “以太网”不是一种具体的网络,而是一种技术标准;既包含了数据链路层的内容,也包含了一些物理层的内容。 下面我们再来了解一下以太网数据帧&#xff…

【Java 进阶篇】MySQL 事务详解

在数据库管理中,事务是一组SQL语句的执行单元,它们被视为一个整体。事务的主要目标是保持数据库的一致性和完整性,即要么所有SQL语句都成功执行,要么所有SQL语句都不执行。在MySQL中,事务起到了非常重要的作用&#xf…

Linux文件查找,别名,用户组综合练习

1.文件查看: 查看/etc/passwd文件的第5行 [rootserver ~]# head -5 /etc/passwd root:x:0:0:root:/root:/bin/bash bin:x:1:1:bin:/bin:/sbin/nologin daemon:x:2:2:daemon:/sbin:/sbin/nologin adm:x:3:4:adm:/var/adm:/sbin/nologin lp:x:4:7:lp:/var/spool/lpd:/sbin/nologi…

【实践成果】Splunk 9.0 Configuration Change Tracking

Splunk 9.0 引入了新的功能,一个很重要的一个,就是跟踪conguration 文件的变化: 这个很重要的特性,在splunk 9.0 以后才引入,就看server.conf 配置中,9.0 以后的版本才有: server.conf - Splu…

数据集笔记:纽约花旗共享单车od数据

花旗共享单车公布的其共享单车轨迹数据,包括2013年-2021年曼哈顿、布鲁克林、皇后区和泽西城大约14500辆自行车和950个站点的共享单车轨迹数据 数据地址:Citi Bike System Data | Citi Bike NYC | Citi Bike NYC 性别(0未知;1男&…

详解分布式搜索技术之elasticsearch

目录 一、初识elasticsearch 1.1什么是elasticsearch 1.2elasticsearch的发展 1.3为什么学习elasticsearch? 1.4正向索引和倒排索引 1.4.1传统数据库采用正向索引 1.4.2elasticsearch采用倒排索引 1.4.3posting list ​1.4.4总结 1.5 es的一些概念 1.5.1文档和字段 …

unity打包工具

接手了一个项目,打包存在重大问题,故此在unity addressables 基础上弄了一个简单的打包工具,代码也都做好了注释,操作非常简单以下为操作方法: 首先设置导入Addressables插件,并设置好详细参见&#xff1a…

GitHub 基本操作

最近要发展一下自己的 github 账号了,把以前的项目代码规整规整上传上去,这里总结了一些经验,经过数次实践之后,已解决几乎所有基本操作中的bug,根据下面的操作步骤来,绝对没错了。(若有其他问题…

pytho实例--pandas读取表格内容

前言:由于运维反馈帮忙计算云主机的费用,特编写此脚本进行运算 如图,有如下excel数据 计算过程中需用到数据库中的数据,故封装了一个读取数据库的类 import MySQLdb from sshtunnel import SSHTunnelForwarderclass SSHMySQL(ob…

win11+wsl+git+cmake+x86gcc+armgcc+clangformat+vscode环境安装

一、安装wsl (1)打开power shell 并运行: Enable-WindowsOptionalFeature -Online -FeatureName Microsoft-Windows-Subsystem-Linux Enable-WindowsOptionalFeature -Online -FeatureName VirtualMachinePlatform (2&#xff0…

【图像分割】图像检测(分割、特征提取)、各种特征(面积等)的测量和过滤(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

node版本问题:Error: error:0308010C:digital envelope routines::unsupported

前言 出现这个错误是因为 node.js V17及以后版本中最近发布的OpenSSL3.0, 而OpenSSL3.0对允许算法和密钥大小增加了严格的限制,可能会对生态系统造成一些影响. 在node.js V17以前一些可以正常运行的的应用程序,但是在 V17 及以后版本可能会抛出以下异常: 我重装系统前,用的…

三个要点,掌握Spring Boot单元测试

单元测试是软件开发中不可或缺的重要环节,它用于验证软件中最小可测试单元的准确性。结合运用Spring Boot、JUnit、Mockito和分层架构,开发人员可以更便捷地编写可靠、可测试且高质量的单元测试代码,确保软件的正确性和质量。 一、介绍 本文…

Lua学习笔记:require非.lua拓展名的文件

前言 本篇在讲什么 Lua的require相关的内容 本篇需要什么 对Lua语法有简单认知 对C语法有简单认知 依赖Visual Studio工具 本篇的特色 具有全流程的图文教学 重实践,轻理论,快速上手 提供全流程的源码内容 ★提高阅读体验★ 👉 ♠…

uni-app 之 短信验证码登录

uni-app 之 短信验证码登录 image.png image.png <template><view style"width: 100%; display: flex; flex-direction:column; align-items:center;"><view style"width: 300px; margin-top: 100px;"><!-- // --><!-- 1&#…

uni-app:实现密码框内容展示与隐藏

效果 代码 <template><view class"container"><view class"item_left"><view>密码</view><view class"eye_position" taptoggleShowPassword><image :srceye v-ifisShowPassword /><image :srcey…

DBRichEdit关联ClientDataSet不能保存的Bug

ClientDataSet的最大好处&#xff0c;就是建立能内存表&#xff0c;特别DataSnap三层运用中&#xff0c;主要使用ClientDataSet与运程的服务器中的数据表&#xff0c;建立读取存贮关系。 在软件的使用中&#xff0c;总有客户反映&#xff0c;一些数据不能保存。 发现都是使用DB…

Springboot中使用拦截器、过滤器、监听器

一、Servlet、Filter&#xff08;过滤器&#xff09;、 Listener&#xff08;监听器&#xff09;、Interceptor&#xff08;拦截器&#xff09; Javaweb三大组件&#xff1a;servlet、Filter&#xff08;过滤器&#xff09;、 Listener&#xff08;监听器&#xff09; Spring…

nodejs在pdf中绘制表格

需求 之前我已经了解过如何在pdf模板中填写字段了 nodejs根据pdf模板填入中文数据并生成新的pdf文件https://blog.csdn.net/ArmadaDK/article/details/132456324 但是当我具体使用的时候&#xff0c;我发现我的模板里面有表格&#xff0c;表格的长度是不固定的&#xff0c;所…