yolov8添加注意力机制模块-CBAM

修改

  1. 在tasks.py(路径:ultralytics-main/ultralytics-main - attention/ultralytics/nn/tasks.py)文件中,引入CBAM模块。因为yolov8源码中已经包含CBAM模块,在conv.py文件中(路径:ultralytics-main/ultralytics-main - attention/ultralytics/nn/modules/conv.py),这里就就用自己写了。
  2. 修改tasks.py文件,搜索parse_model。在指定位置添加代码。
            elif m is CBAM:  # todo 源码修改 (增加了elif)"""ch[f]:上一层的args[0]:第0个参数c1:输入通道数c2:输出通道数"""c1, c2 = ch[f], args[0]# print("ch[f]:",ch[f])# print("args[0]:",args[0])# print("args:",args)# print("c1:",c1)# print("c2:",c2)if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)c2 = make_divisible(c2 * width, 8)args = [c1, *args[1:]]

    3.修改yolov8.yaml文件位置(ultralytics-main/ultralytics-main - attention/ultralytics/cfg/models/v8/yolov8.yaml)。修改head模块,修改的内容如下图。

        4.测试打印网络。已经添加成功。

分析

一般来说,注意力机制通常被分为以下基本四大类:

通道注意力 Channel Attention

空间注意力机制 Spatial Attention

时间注意力机制 Temporal Attention

分支注意力机制 Branch Attention

CBAM:通道注意力和空间注意力的集成者

源码解读

这段代码是对通道的注意力。首先经过自适应平均池化层,它会对每个输入通道的空间维度进行全局平均池化,并输出一个具有空间大小为 1x1 的特征图。然后是一个卷积操作,这相当于是对每个通道进行独立的全连接层变换,因为卷积核大小为1。

最后经过Sigmoid函数,将卷积层的输出转换为权重因子,范围在(0, 1)最后,这些权重因子与原始输入x逐元素相乘,以得到加权后的特征图,这一操作实现了注意力机制,允许模型专注于更有信息量的通道。

class ChannelAttention(nn.Module):"""Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""def __init__(self, channels: int) -> None:"""Initializes the class and sets the basic configurations and instance variables required."""super().__init__()self.pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)self.act = nn.Sigmoid()def forward(self, x: torch.Tensor) -> torch.Tensor:"""Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""return x * self.act(self.fc(self.pool(x)))

下面是一个空间注意力模块,旨在通过对输入特征图加权来强调或抑制某些空间区域。空间注意力通常用于强调图像的重要部分并抑制不重要的部分。

self.cv1 是一个卷积层,有两个输入通道,一个输出通道,和可选的 kernel_size 与 padding。由于 bias=False,这个卷积层不会有偏置参数。两个输入通道对应于输入特征图的均值和最大值。

forward中

  1. torch.mean(x, 1, keepdim=True) 计算输入张量 x 每个样本的通道维度的均值,keepdim=True 表示保持输出张量的维度不变。

  2. torch.max(x, 1, keepdim=True)[0] 计算输入张量 x 每个样本的通道维度的最大值,[0] 是因为 torch.max 返回一个元组,包含最大值和相应的索引。

  3. torch.cat([avg_out, max_out], 1) 将均值和最大值沿通道维度拼接起来,这样每个空间位置都有两个通道:其均值和最大值。

  4. self.cv1(x_cat) 对拼接的结果应用 1x2 卷积,生成一个单通道的特征图,该特征图对应于每个空间位置的注意力权重。

  5. self.act(...) 应用 Sigmoid 激活函数将注意力权重映射到 (0, 1) 范围内。

  6. x * scale 将原始输入 x 与计算得到的空间注意力权重相乘,这样每个空间位置的特征值都会根据其重要性加权,实现了特征重标定。

最终,forward 方法返回的是加权后的输入特征图(对特征图的每个元素值×权值),它突出了输入中更重要的空间区域。

class SpatialAttention(nn.Module):"""Spatial-attention module."""def __init__(self, kernel_size=7):"""Initialize Spatial-attention module with kernel size argument."""super().__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.act = nn.Sigmoid()def forward(self, x):"""Apply channel and spatial attention on input for feature recalibration."""return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))

下面就是CBAM,是上面两个模块的组合,通道注意力和空间注意力。通道注意力专注于哪些通道更重要,而空间注意力则集中在输入特征图中的哪些空间位置更重要。

  • 输入 x 首先通过 self.channel_attention,这个步骤会重新调整每个通道的重要性。
  • 然后,调整通道重要性后的特征图 x 通过 self.spatial_attention,这个步骤会重新调整特征图中每个位置的重要性。
  • 最终,这两个注意力机制的结果被串联起来,形成了最终的输出。

这种结构可以提高网络对于输入特征的逐通道和逐空间位置的重要性评估能力,进而可能提高模型的性能。

class CBAM(nn.Module):"""Convolutional Block Attention Module."""def __init__(self, c1, kernel_size=7):"""Initialize CBAM with given input channel (c1) and kernel size."""super().__init__()self.channel_attention = ChannelAttention(c1)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):"""Applies the forward pass through C1 module."""return self.spatial_attention(self.channel_attention(x))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/701915.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

业务流程管理系统(BPMS):一文掌握,组织业务流程优化必备。

大家好,我是大美B端工场,本期继续分享商业智能信息系统的设计,欢迎大家关注,如有B端写系统界面的设计和前端需求,可以联络我们。 一、什么是BPMS系统 BPMS是Business Process Management System(业务流程管…

学习Python分支结构不走弯路

1.单分支语句 """ 语法: if 表达式:执行语句 执行流程:当表达式成立的时候,执行语句,否则不执行 """age int(input(请输入你的年龄:)) if age > 18:print(欢迎光临!) …

二进制部署k8s集群之cni网络插件

目录 k8s的三种网络模式 pod内容器之间的通信 同一个node节点中pod之间通信 不同的node节点的pod之间通信 flannel网络插件 flannel的三种工作方式 VxLAN host-GW UDP Flannel udp 模式 Flannel VXLAN 模式 flannel插件的三大模式的总结 calico网络插件 k8s 组网…

ABC342 A-G

HUAWEI Programming Contest 2024(AtCoder Beginner Contest 342) - AtCoder 被薄纱的一场 A - Yay! 题意: 给出一串仅由两种小写字母构成的字符串,其中一种小写字母仅出现一次,输出那个仅出现一次的小写字母的位置…

PyTorch概述(五)---LINEAR

torch.nn.Linear torch.nn.Linear(in_features,out_features,biasTrue,deviceNone,dtypeNone) 对输入的数据应用一个线性变换: 该模块支持TensorFLoat32类型的数据;在某些ROCm设备上,使用float16类型的数据输入时,该模块在反向传…

文本左右对齐

题目链接 文本左右对齐 题目描述 注意点 words[i] 由小写英文字母和符号组成每个单词的长度大于 0,小于等于 maxWidth输入单词数组 words 至少包含一个单词要求尽可能均匀分配单词间的空格数量。如果某一行单词间的空格不能均匀分配,则左侧放置的空格…

Unity中URP实现水体(水下的扭曲)

文章目录 前言一、使用一张法线纹理,作为水下扭曲的纹理1、在属性面板定义一个纹理,用于传入法线贴图2、在Pass中,定义对应的纹理和采样器3、在常量缓冲区,申明修改 Tilling 和 Offset 的ST4、在顶点着色器,计算得到 应…

目标检测开源数据集——太阳能板缺陷

简介 太阳能板,也称为太阳能电池板,是一种将太阳能转化为电能的设备。它的主要作用包括: 提供电力:太阳能板通过吸收阳光,将其转化为直流电,这种电能可以被各种设备使用。例如,它可以为家庭、…

重生奇迹MU职业排行

1、魔法师:魔法师是奇迹MU中最具实力的职业之一,他们拥有顶级的范围输出能力,同时还具备不错的控制技能。此外,魔法师还具有位移和护盾保命技能,技能伤害非常高,使其在游戏中具有很高的生存和攻击能力。 2…

第十四章 Linux面试题

第十四章 Linux面试题 日志t.log(访问量), 将各个ip地址截取,并统计出现次数,并按从大到小排序(腾 讯) http://192. 168200.10/index1.html http://192. 168.200. 10/index2.html http:/192. 168 200.20/index1 html http://192. 168 200.30/…

【Redis】搞懂过期删除策略和内存淘汰策略

1、过期删除策略 1.1、介绍 Redis 是可以对 key 设置过期时间的,因此需要有相应的机制将已过期的键值对删除,而做这个工作的就是过期键值删除策略。 每当我们对一个 key 设置了过期时间时,Redis 会把该 key 带上过期时间存储到一个过期字典…

【Web】CTFSHOW 常用姿势刷题记录(全)

目录 web801 web802 web803 web804 web805 web806 web807 法一:反弹shell 法二:vps外带 web808 web809 web810 web811 web812 web813 web814 web815 web816 web817 web818 web819 web820 web821 web822 web823 web824 web825…

软考45-上午题-【数据库】-数据操纵语言DML

一、INSERT插入语句 向SQL的基本表中插入数据有两种方式: ①直接插入元组值 ②插入一个查询的结果值 1-1、直接插入元组值 【注意】: 列名序列是可选的,若是所有列都要插入数值,则可以不写列名序列。 示例: 1-2、插…

yolov8学习笔记(一)网络结构

一、yolov8.yaml YOLOv8详解 【网络结构代码实操】: YOLOv8详解 【网络结构代码实操】-CSDN博客文章浏览阅读10w次,点赞559次,收藏2.9k次。YOLOv8 算法的核心特性和改动可以归结为如下:提供了一个全新的 SOTA 模型,包…

#LLM入门|Prompt#1.8_聊天机器人_Chatbot

聊天机器人设计 以会话形式进行交互,接受一系列消息作为输入,并返回模型生成的消息作为输出。原本设计用于简便多轮对话,但同样适用于单轮任务。 设计思路 个性化特性:通过定制模型的训练数据和参数,使机器人拥有特…

蛇形矩阵3

题目描述 把数1,2,3,4,5,…,N*N按照“蛇形3”放入N*N矩阵的中,输出结果。 下面是N6的蛇形3的图示 输入格式 第一行1个正整数:N,范围在[1,100]。 输出格式 N行&#x…

Git笔记——3

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 一、合并模式和分支策略 二、bug分支 三、强制删除分支 四、创建远程仓库 五、克隆远程仓库_HTTPS和_SSH 克隆远程仓库_HTTPS 克隆远程仓库_SSH 六、向远程仓库…

【黑马程序员】3、TypeScript常用类型_黑马程序员前端TypeScript教程,TypeScript零基础入门到实战全套教程

课程地址:【黑马程序员前端TypeScript教程,TypeScript零基础入门到实战全套教程】 https://www.bilibili.com/video/BV14Z4y1u7pi/?share_sourcecopy_web&vd_sourceb1cb921b73fe3808550eaf2224d1c155 目录 3、TypeScript常用类型 3.1 类型注解 …

【统计分析数学模型】聚类分析: 系统聚类法

【统计分析数学模型】聚类分析: 系统聚类法 一、聚类分析1. 基本原理2. 距离的度量(1)变量的测量尺度(2)距离(3)R语言计算距离 三、聚类方法1. 系统聚类法2. K均值法 三、示例1. Q型聚类&#x…

四六级成绩爬取代码原创

在六级成绩刚发布时,只需要通过学生姓名和身份证号便可以查询到成绩 据此,我们可以利用selenium框架对学生的成绩进行爬取 首先我们要建立一个excel表格,里面放三列(多几列也无所谓),第一列列名取为学生姓…