YOLOv5改进 | 注意力机制 | 添加三重注意力机制 TripletAttention【原理 + 完整代码】

💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡

得益于在通道或空间位置之间建立相互依赖关系的能力,近年来,注意力机制在计算机视觉任务中得到了广泛的研究和应用。一种轻量级但有效的注意力机制——三重注意力,这是一种通过使用三分支结构捕获跨维度交互来计算注意力权重的创新方法。对于一个输入张量,三重注意力通过旋转变换建立跨维度依赖关系,并通过残差变换编码跨通道和空间信息,几乎不增加计算开销。在本文中,给大家带来的教程是将原来的网络添加TripletAttention。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址 YOLOv5改进+入门——持续更新各种有效涨点方法 点击即可跳转

目录

1.原理

2. TripletAttention代码实现

2.1 将TripletAttention添加到YOLOv5中

2.2 新增yaml文件

2.3 注册模块

2.4 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6.总结


1.原理

官方论文:Rotate to Attend: Convolutional Triplet Attention Module——点击即可跳转

官方代码:官方代码仓库地址——点击即可跳转

三重注意力机制(Triplet Attention)是一种深度学习中的注意力机制,旨在提高模型对输入数据的理解和表示能力。它在自然语言处理(NLP)和计算机视觉(CV)等领域都有应用。

这个机制的核心思想是将注意力机制引入到不同级别的特征表示中,以更全面地捕捉输入数据的信息。通常来说,传统的注意力机制会在同一级别的特征表示中计算注意力权重,而三重注意力机制则引入了三个不同级别的特征表示,并在每个级别上计算注意力权重,从而实现了“三重”的概念。

具体来说,三重注意力机制通常包含以下三个层次的注意力计算:

  1. 全局注意力(Global Attention): 全局注意力通常是在输入数据的最底层或最原始的表示上计算的,例如,在NLP中可能是词级别的表示,或者在CV中可能是原始图像的表示。在这一层次上,模型尝试理解整个输入的上下文信息,并计算每个部分的重要性。

  2. 组间注意力(Inter-group Attention): 组间注意力是在全局注意力得到的表示的基础上计算的。它将全局表示分成不同的组(可能是空间上的不同区域,或者是语义上的不同部分),然后在这些组之间计算注意力权重。这一层级的注意力有助于模型更好地理解输入数据中不同部分之间的关系和交互。

  3. 组内注意力(Intra-group Attention): 组内注意力是在组间注意力得到的表示的基础上计算的。它在每个组内部计算注意力权重,以捕捉组内部分的重要性和关联性。这一层级的注意力有助于模型更好地理解每个组内部分的内在结构和语义信息。

通过这三个层次的注意力计算,三重注意力机制可以在不同级别上捕捉输入数据的全局信息、组间关系和组内结构,从而更有效地理解和表示输入数据。

总的来说,三重注意力机制通过在不同级别上引入注意力机制,能够更全面地捕捉输入数据的信息,从而提高了深度学习模型的表现能力。

2. TripletAttention代码实现

2.1 将TripletAttention添加到YOLOv5中

关键步骤一将下面代码粘贴到/projects/yolov5-6.1/models/common.py文件中

import torch
import math
import torch.nn as nn
import torch.nn.functional as Fclass BasicConv(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):super(BasicConv, self).__init__()self.out_channels = out_planesself.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else Noneself.relu = nn.ReLU() if relu else Nonedef forward(self, x):x = self.conv(x)if self.bn is not None:x = self.bn(x)if self.relu is not None:x = self.relu(x)return xclass ChannelPool(nn.Module):def forward(self, x):return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )class SpatialGate(nn.Module):def __init__(self):super(SpatialGate, self).__init__()kernel_size = 7self.compress = ChannelPool()self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)def forward(self, x):x_compress = self.compress(x)x_out = self.spatial(x_compress)scale = torch.sigmoid_(x_out) return x * scaleclass TripletAttention(nn.Module):def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):super(TripletAttention, self).__init__()self.ChannelGateH = SpatialGate()self.ChannelGateW = SpatialGate()self.no_spatial=no_spatialif not no_spatial:self.SpatialGate = SpatialGate()def forward(self, x):x_perm1 = x.permute(0,2,1,3).contiguous()x_out1 = self.ChannelGateH(x_perm1)x_out11 = x_out1.permute(0,2,1,3).contiguous()x_perm2 = x.permute(0,3,2,1).contiguous()x_out2 = self.ChannelGateW(x_perm2)x_out21 = x_out2.permute(0,3,2,1).contiguous()if not self.no_spatial:x_out = self.SpatialGate(x)x_out = (1/3)*(x_out + x_out11 + x_out21)else:x_out = (1/2)*(x_out11 + x_out21)return x_out

三重注意力机制的主要流程可以分为以下步骤:

  1. 输入数据表示: 首先,将输入数据(例如文本序列、图像等)进行表示。这可能包括将文本序列转换为词嵌入向量、将图像转换为特征图等。这一步骤的目的是将输入数据转换为模型可以处理的表示形式。

  2. 全局注意力计算: 在第一级别,对输入数据的全局表示进行计算。这可以通过应用传统的注意力机制来实现,例如使用自注意力机制(Self-Attention)或注意力机制的变体。在这一步骤中,模型尝试理解整个输入的上下文信息,并计算每个部分的重要性。

  3. 组间表示生成: 在第二级别,根据全局注意力得到的权重,将全局表示分成不同的组。这些组可以根据具体的任务和数据特点来确定,例如在图像中可能是空间上的不同区域,在文本中可能是不同的句子或段落。然后,对每个组进行表示生成,得到组间表示。

  4. 组间注意力计算: 在第二级别,对组间表示进行注意力计算。这一步骤可以类似地使用注意力机制,但是针对的是组间的关系和交互。通过计算组间的注意力权重,模型可以更好地理解不同组之间的关系和重要性。

  5. 组内表示生成: 在第三级别,根据组间注意力得到的权重,将每个组内的表示进行生成。这一步骤可以帮助模型更好地理解每个组内部分的内在结构和语义信息。

  6. 组内注意力计算: 在第三级别,对每个组内的表示进行注意力计算。这类似于组间注意力计算,但是针对的是组内部分的关系和重要性。通过计算组内的注意力权重,模型可以更好地理解组内部分之间的关系和重要性。

  7. 输出: 最后,根据经过三级注意力机制处理后的表示,进行任务相关的后续处理,如分类、回归等,得到最终的输出结果。

总的来说,三重注意力机制通过在不同级别上引入注意力机制,实现了对输入数据的全局信息、组间关系和组内结构的捕捉和理解,从而提高了深度学习模型的表现能力。

2.2 新增yaml文件

关键步骤二在下/projects/yolov5-6.1/models下新建文件 yolov5_TripletAttention.yaml并将下面代码复制进去

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, TripletAttention, [128,3]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 7], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 5], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 15], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 11], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

温馨提示:本文只是对yolov5l基础上添加swin模块,如果要对yolov8n/l/m/x进行添加则只需要指定对应的depth_multiple 和 width_multiple。


# YOLOv5n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple# YOLOv5s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple# YOLOv5l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple# YOLOv5m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple# YOLOv5x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple

2.3 注册模块

关键步骤三在yolo.py中注册, 大概在260行左右添加 ‘TripletAttention’

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_TripletAttention.yaml的路径

建议大家写绝对路径,确保一定能找到

 🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/1RST9hL8La0GZ8n-kk9bXiw?pwd=45cq

提取码: 45cq 

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的GFLOPs

改进后的GFLOPs,可以看出这个计算量几乎没有变化,也印证了文章开头说的“计算量几乎b”

5. 进阶

你能在不同的位置添加三重注意力机制吗?这非常有趣,快去试试吧

6.总结

三重注意力机制(Triplet Attention)是一种深度学习中的注意力机制,通过在不同层次上引入注意力机制,增强模型对输入数据的理解和表示能力。其流程包括首先将输入数据转换为模型可处理的表示形式,然后在全局表示上计算注意力权重以捕捉整体上下文信息,接着根据全局注意力权重将全局表示分成不同的组并生成组间表示,再在组间表示上计算注意力权重以理解组间关系,随后生成组内表示并在组内计算注意力权重以捕捉组内结构和关系,最后基于处理后的表示进行任务相关的处理得到最终输出。通过全局、组间和组内三个层次的注意力计算,三重注意力机制能够更全面地捕捉输入数据的信息,从而提升模型的表现能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/19490.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式Linux命令基础

一、命令概述 1. 命令本质 命令的特性:一般就是对应shell命令,每一个命令代表一个可执行程序,运行一个命令就相当于 运行一个可执行代码。 2. 打开终端方法 第一种方法:通过鼠标右键选择打开终端 第二种方法:利用…

Django——Admin站点(Python)

#前言: 该博客为小编Django基础知识操作博客的最后一篇,主要讲解了关于Admin站点的一些基本操作,小编会继续尽力更新一些优质文章,同时欢迎大家点赞和收藏,也欢迎大家关注等待后续文章。 一、简介: Djan…

认识Oracle v$mystat视图

v$mystat就是当前用户的各种统计信息, sid就是session的id(也就是当前用户),STATISTIC#就是统计量的编号(用来唯一确定统计量的名称),value是统计量的值; desc命令在Oracle中通常用于查看表结构; v$mystat视图中只会有当前用户…

【NVM】nvm常用命令,切换node版本命令

nvm常用的命令,切换node版本命令 nvm 查看支持安装的node版本 nvm list available nvm安装指定版本node nvm install 版本号 例如:nvm install 10.24.1 nvm查看本机安装所有node版本 nvm list nvm切换node版本 nvm use 10.24.1 检测当前node版本 node -…

大数据中的电商数仓项目:探秘业务的核心

我学习完一个电商数仓的项目和电影实时推荐项目,便兴冲冲的去面试大数据开发岗,在面试的时候,面试官总是喜欢问,聊聊你为什么要做这个项目以及你这个项目有哪些业务? 我心想,为什么要做这个业务&#xff1f…

【码银送书第二十期】《游戏运营与出海实战:策略、方法与技巧》

市面上的游戏品种繁杂,琳琅满目,它们是如何在历史的长河中逐步演变成今天的模式的呢?接下来,我们先回顾游戏的发展史,然后按照时间轴来叙述游戏运营的兴起。 作者:艾小米 本文经机械工业出版社授权转载&a…

用Idea 解决Git冲突

https://intellijidea.com.cn/help/idea/resolving-conflicts.html https://www.jetbrains.com/help/idea/resolve-conflicts.html idea 官方文档 当您在团队中工作时,您可能会遇到这样的情况:有人对您当前正在处理的文件进行更改。如果这些更改没有重叠(也就是说…

Ps系统教程03

选区工具的组合使用 先用魔棒将大致区域点击圈主 会发现一些零散的小区域 使用套索工具进行区域的加减(按住shift/alt键进行相关区域加减) 可以放大查看 基本处理完细节之后 如果把不用的填充背景直接按delete删除,那么原版图案就会…

Hadoop3:MapReduce的序列化和反序列化

一、概念 1、序列化 就是把内存中的对象,转换成字节序列 (或其他数据传输协议)以便于存储到磁 盘(持久化)和网络传输。 2、反序列化 就是将收到字节序列(或其他数据传输协议)或者是磁盘的持…

LeetCode-47 全排列Ⅱ

LeetCode-47 全排列Ⅱ 题目描述解题思路代码说明 题目描述 给定一个可包含重复数字的序列 nums ,按任意顺序 返回所有不重复的全排列。 示例 : 输入:nums [1,1,2]输出: [[1,1,2], [1,2,1], [2,1,1]] b站题目解读讲的不好&…

部署k8s的DashBoard

1. 部署 Dashboard UI [rootk8s-master ~]# kubectl apply -f https://raw.githubusercontent.com/kubernetes/dashboard/v2.7.0/aio/deploy/recomme nded.yaml一般上面的网站访问不了 可以下载我上传的资源DashBoard的recommended.yaml vim recommended.yaml 复制粘贴我上…

做外贸,怎么选国外服务器?

不管是新手还是外贸老司机,大家都知道要用海外服务器来做外贸网站,无论外贸独立站的客户是欧美、东南亚、还是非洲,都不能选择国内机房的服务器,必须选择海外服务器,这是共识。 但是今天,我要告诉大家一个…

Java Apache Jaccard文本相似度匹配初体验

文章目录 前言一、文本相似度算法的选择二、常见的文本相似度算法介绍三、使用示例1、引入jar包2、方法示例3、Jaccard源码剖析4、Jaccard源码解释 写在最后 前言 产品今天提了个需求,大概是这样的,来,请看大屏幕。。。额。。。搞错了&#…

Spring Boot 2 入门基础

学习要求 ● 熟悉Spring基础 ● 熟悉Maven使用 环境要求 ● Java8及以上 ● Maven 3.3及以上:https://docs.spring.io/spring-boot/docs/current/reference/html/getting-started.html#getting-started-system-requirements 学习资料 ● 文档地址: htt…

前端从零到一开发vscode插件并发布到插件市场

前端从零到一开发vscode插件并发布到插件市场 背景目标成果展示一条龙实现过程安装插件脚手架和工具创建项目运行调试打包第一次打包前的必要操作 发布第一次发布前账号准备注册Azure DevOps发布账号-获取token注册vscode开发者账号终端登录vsce 发布方式2-手动上传插件 进阶开…

深入分析 Android Service (三)

文章目录 深入分析 Android Service (三)1. Service 与 Activity 之间的通信2. 详细示例:通过绑定服务进行通信2.1 创建一个绑定服务2.2 绑定和通信 3. 优化建议4. 使用场景5. 总结 深入分析 Android Service (三) 1. Service 与 Activity 之间的通信 在 Android …

115道MySQL面试题(含答案),从简单到深入!

1. 什么是数据库事务? 数据库事务是一个作为单个逻辑工作单元执行的一系列操作。事务具有ACID属性,即原子性(Atomicity)、一致性(Consistency)、隔离性(Isolation)和持久性&#xf…

手机站怎么推广

随着手机的普及和移动互联网的快速发展,越来越多的人开始使用手机进行在线购物、社交娱乐、阅读资讯等,同时也催生了越来越多的手机站的出现。但是,在海量的手机站中,要让自己的手机站脱颖而出,吸引更多用户访问和使用…

CSS 【实战】 “四合院”布局

效果预览 页面要求: 上下固定高度左右固定宽度中间区域自适应宽高整个页面内容撑满全屏,没有滚动条 技术要点 使用 html5 语义化标签 header 网页内的标题区域nav 导航区域aside 侧边栏footer 页脚区域section 内容分区article 文章区域 清除浏览器默…

微信小程序区分运行环境

wx.getAccountInfoSync() 是微信小程序的一个 API,它可以同步获取当前账号信息。返回对象中包含小程序 AppID、插件的 AppID、小程序/插件版本等信息。 返回的对象结构如下: 小程序运行环境,可选值有:develop(开发版&…