YOLOV8注意力改进方法: Dual-ViT(Dual Vision Transformer) (附改进代码)

原论文地址:原论文下载网址

论文相关内容介绍

将自注意力过程分解为区域和局部特征提取过程,每个过程产生的计算复杂度要小得多。然而,区域信息通常仅以由于下采样而丢失的不希望的信息为代价。在本文中,作者提出了一种旨在缓解成本问题的新型Transformer架构,称为双视觉Transformer(Dual ViT)。新架构结合了一个关键的语义路径,可以更有效地将token向量压缩为全局语义,并降低复杂性。这种压缩的全局语义通过另一个构建的像素路径,作为学习内部像素级细节的有用先验信息。然后将语义路径像素路径整合在一起,并进行联合训练,通过这两条路径并行传播增强的自注意力信息。因此,双ViT能够在不影响精度的情况下降低计算复杂度。实证证明,双ViT比SOTA Transformer架构提供了更高的精度,同时降低了训练复杂度。

Transformer结构在革新深度学习应用方面取得了巨大成功,包括自然语言处理和计算机视觉任务。不幸的是,由于Transformer通常依赖密集的自注意力计算,因此对于高分辨率输入,此类架构的训练通常很慢。由于transformer技术通常可以提供比同类技术更高的性能,因此这种复杂性问题逐渐成为制约这种强大体系结构发展的瓶颈。

自注意力过程是此类复杂性问题的主要负担,因为每个token的每个表示都是通过关注所有token来更新的。最近的工作将重点放在研究复杂性问题上,通过提供不同于标准自注意力的替代解决方案。许多人考虑将自注意力与下采样相结合,以有效地取代原来的标准注意力。这种方式自然能够探索区域语义信息,从而进一步促进局部特征的学习/提取。例如,PVT提出了线性空间减少注意(SRA),该注意通过下采样操作(例如,平均池或跨步卷积)减少键和值的空间比例,如图1(a)所示。Twins(上图(b))在SRA之前添加了额外的局部分组自注意力层,以通过区域内相互作用进一步增强表示。RegionViT(上图(c))通过区域和局部自注意力分解原始注意力。然而,由于上述方法严重依赖于特征映射到区域的下采样,在有效节省总计算成本的同时,观察到了明显的性能下降。

在这些不同的组合策略中,很少有人试图研究全局语义和内部像素级特征之间在降低复杂性方面的依赖关系。在本文中,作者考虑通过提出的双重ViT将训练分解为全局语义和内部特征注意。其动机是提取全局语义信息(即参数语义查询),这些信息可以作为丰富的先验信息,在新的双通道设计中帮助用户进行局部特征提取。本文对全局语义和局部特征的独特分解和集成允许有效减少多头注意力中涉及的token数量,从而与标准注意对应项相比节省了计算复杂性。特别是,如上图(d)所示,双ViT由两个特殊路径组成,分别称为“语义路径”和“像素路径”。通过构造的“像素路径”进行局部像素级特征提取是强烈依赖于“语义路径”之外的压缩全局先验。由于梯度同时通过语义路径和像素路径,因此双ViT训练过程可以有效地补偿全局特征压缩的信息损失,同时减少局部特征提取的困难。前者和后者都可以并行显著降低计算成本,因为注意力大小较小,并且两条路径之间存在强制依赖关系。

综上所述,本文做出了以下贡献:

1) 提出了一种新的Transformer架构,称为双视觉Transformer(双ViT)。顾名思义,双ViT网络包括两条路径,分别用于提取输入语义特征的更全面全局视图,以及另一条专注于学习内部局部特征的像素路径。

2)双ViT考虑了两条路径上全局语义和局部特征之间的依赖关系,目的是通过减少token大小和注意力来简化训练。

3) 与VOLO相比,双ViT在ImageNet上实现了85.7%的top-1精度,只有41.1%的浮点运算和37.8%的参数。在目标检测和实例分割方面,双ViT在映射方面也提高了PVT,在COCO上分别提高了1.2%和0.9%,参数减少了48.0%。

具体的方法如何实现大家可以参考原论文

2.yolov8加入Dual Vision Transformer的步骤:

2.1 新建加入ultralytics/nn/attention/dualvit.py

注意在nn文件夹下新建attention文件夹,创建dualvit.py文件,导入下面的代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partialfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_
import mathclass DWConv(nn.Module):def __init__(self, dim=768):super(DWConv, self).__init__()self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)def forward(self, x, H, W):B, N, C = x.shapex = x.transpose(1, 2).view(B, C, H, W)x = self.dwconv(x)x = x.flatten(2).transpose(1, 2)return xclass PVT2FFN(nn.Module):def __init__(self, in_features, hidden_features):super().__init__()self.fc1 = nn.Linear(in_features, hidden_features)self.dwconv = DWConv(hidden_features)self.act = nn.GELU()self.fc2 = nn.Linear(hidden_features, in_features)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):if m.bias is not None:nn.init.constant_(m.bias, 0)if m.weight is not None:nn.init.constant_(m.weight, 1.0)elif isinstance(m, nn.Conv2d):fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsfan_out //= m.groupsm.weight.data.normal_(0, math.sqrt(2.0 / fan_out))if m.bias is not None:m.bias.data.zero_()def forward(self, x, H, W):x = self.fc1(x)x = self.dwconv(x, H, W)x = self.act(x)x = self.fc2(x)return xclass MergeFFN(nn.Module):def __init__(self, in_features, hidden_features):super().__init__()self.fc1 = nn.Linear(in_features, hidden_features)self.dwconv = DWConv(hidden_features)self.act = nn.GELU()self.fc2 = nn.Linear(hidden_features, in_features)self.fc_proxy = nn.Sequential(nn.Linear(in_features, 2 * in_features),nn.GELU(),nn.Linear(2 * in_features, in_features),)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):if m.bias is not None:nn.init.constant_(m.bias, 0)if m.weight is not None:nn.init.constant_(m.weight, 1.0)elif isinstance(m, nn.Conv2d):fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsfan_out //= m.groupsm.weight.data.normal_(0, math.sqrt(2.0 / fan_out))if m.bias is not None:m.bias.data.zero_()def forward(self, x, H, W):x, semantics = torch.split(x, [H * W, x.shape[1] - H * W], dim=1)semantics = self.fc_proxy(semantics)x = self.fc1(x)x = self.dwconv(x, H, W)x = self.act(x)x = self.fc2(x)x = torch.cat([x, semantics], dim=1)return xclass Attention(nn.Module):def __init__(self, dim, num_heads):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.q = nn.Linear(dim, dim)self.kv = nn.Linear(dim, dim * 2)self.proj = nn.Linear(dim, dim)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):if m.bias is not None:nn.init.constant_(m.bias, 0)if m.weight is not None:nn.init.constant_(m.weight, 1.0)elif isinstance(m, nn.Conv2d):fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsfan_out //= m.groupsm.weight.data.normal_(0, math.sqrt(2.0 / fan_out))if m.bias is not None:m.bias.data.zero_()def forward(self, x):# x =x.permute(3, 0, 1, 2)B, H, W, C = x.shapeN = H * Wq = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)k, v = kv[0], kv[1]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)x = (attn @ v).transpose(1, 2).reshape(B, H, W , C)x = self.proj(x)return xclass MergeBlockattention(nn.Module):def __init__(self,input, dim, num_heads=2, mlp_ratio=8, drop_path=0., norm_layer=nn.LayerNorm, is_last=False):super().__init__()self.norm1 = norm_layer(dim)self.norm2 = norm_layer(dim)self.attn = Attention(dim, num_heads)if is_last:self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))else:self.mlp = MergeFFN(in_features=dim, hidden_features=int(dim * mlp_ratio))self.is_last = is_lastself.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()layer_scale_init_value = 1e-6self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)),requires_grad=True) if layer_scale_init_value > 0 else Noneself.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)),requires_grad=True) if layer_scale_init_value > 0 else Noneself.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):if m.bias is not None:nn.init.constant_(m.bias, 0)if m.weight is not None:nn.init.constant_(m.weight, 1.0)elif isinstance(m, nn.Conv2d):fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsfan_out //= m.groupsm.weight.data.normal_(0, math.sqrt(2.0 / fan_out))if m.bias is not None:m.bias.data.zero_()def forward(self, x):B, C, H, W = x.shapex = x.permute(0, 2, 3, 1)#x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))x =self.attn(self.norm1(x))x = x.permute(0, 3, 2, 1)return x

注册ultralytics/nn/tasks.py

在tasks.py文件的上面导入部分粘贴下面的代码

from ultralytics.nn.attention.dualvit import MergeBlockattention

修改def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)

只需要加入MergeBlockattention,加入以下代码:

 if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, MergeBlockattention):

 yolov8_DualAttention.yaml


# Ultralytics YOLO 🚀, GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [-1, 1, MergeBlockattention, [1024]]  # 21 (P5/32-large)- [[15, 18, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/800972.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Leaflet.js的Marker闪烁特效的实现-模拟预警

目录 前言 一、闪烁组件 1、关于leaflet-icon-pulse 2、 使用leaflet-icon-pulse 3、方法及参数简介 二、闪烁实例开发 1、创建网页 2、Marker闪烁设置 3、实际效果 三、总结 前言 在一些地质灾害或者应急情况当中,或者热门预测当中。我们需要基于时空位置来…

C++练级之路——类和对象(上)

1、类的定义 class 类名{//成员函数 //成员变量}; class为定义的关键字,{ }内是类的主体,注意后面的 ; 不要忘了 类体中的内容成为类的成员,类中的变量为成员变量或类的属性,类中的函数为成员函数或类的方法, 类的两种…

通过Golang获取公网IP地址

在Go语言中,获取当前的外网(公网)IP地址可以通过多种方法实现。其中一种常见的方法是通过访问外部服务来获取。这些服务可以返回访问者的公网IP地址,例如 httpbin.org/ip 或 ipify.org。下面是一个简单的例子,展示了如…

免费云服务器汇总,最长永久免费使用

随着云计算技术的快速发展,越来越多的企业和个人开始将业务迁移到云端。云服务器作为云计算的重要组成部分,以其灵活、高效、可扩展等特点受到广泛关注。然而,许多人在初次接触云服务器时,可能会对高昂的价格望而却步。为了帮助大…

GEE:绘制和对比不同地物的光谱曲线

作者:CSDN @ _养乐多_ 光谱曲线是指在不同波长范围内物体或地表特征对电磁辐射的反射、吸收或发射的表现。这些曲线展示了物体或地表在可见光、红外线、微波等电磁波段上的光谱特征。光谱曲线的形状和特征能够提供关于物体或地表的信息,可以利用光谱曲线来识别和分类不同的地…

Java设计模式—策略模式(商场打折)

策略这个词应该怎么理解,打个比方说,我们出门的时候会选择不同的出行方式,比如骑自行车、坐公交、坐火车、坐飞机、坐火箭等等,这些出行方式,每一种都是一个策略。 再比如我们去逛商场,商场现在正在搞活动&…

Python技能树学习-函数

题目一:递归调用 函数的参数: def dump(index, default0, *args, **kw): print(打印函数参数) print(---) print(index:, index) print(default:, default) for i, arg in enumerate(args): print(farg[{i}]:, arg) for…

Vue 样式技巧总结与整理[中级局]

SFC(单文件组件)由 3 个不同的实体组成:模板、脚本和样式。三者都很重要,但后者往往被忽视,即使它可能变得复杂,且经常导致挫折和 bug。 更好的理解可以改善代码审查并减少调试时间。 这里有 7 个奇技淫巧…

[StartingPoint][Tier2]Archetype

Task 1 Which TCP port is hosting a database server? (哪个端口开放了数据库服务) $ nmap 10.129.95.187 -sC --min-rate 1000 1433 Task 2 What is the name of the non-Administrative share available over SMB? (哪个非管理共享提供了SMB?) $ smbclient -N -L 1…

Rsync——远程同步命令

目录 一、关于Rsync 1.定义 2.Rsync同步方式 3.备份的方式 4.Rsync命令 5.配置源的两种表达方法 二、配置服务端与客户端的实验——下载 1.准备工作 2.服务端配置 3.客户端配置同步 4.免交互数据同步 5.源服务器删除数据是否会同步 6.可以定期执行数据同步 三、关…

JVM的简单介绍

目录 一、JVM的简单介绍 JVM的执行流程 二、JVM中的内存区域划分 1、堆(只有一份) 2、栈(可能有N份) 3、程序计数器(可能有N份) 4、元数据区(只有一份) 经典笔试题 三、JVM…

如何恢复被.locked勒索病毒加密的服务器和数据库?

.locked勒索病毒有什么特点? .locked勒索病毒的特点主要包括以下几个方面: 文件加密:.locked勒索病毒会对受感染设备上的所有文件进行加密,包括图片、文档、视频和其他各种类型的重要文件。一旦文件被加密,文件的扩展…

淘宝商品描述API接口:轻松获取商品信息的新途径

淘宝商品描述API接口是淘宝开放平台提供的一种高效、便捷的新途径,旨在帮助开发者轻松获取淘宝商品的详细描述信息。通过这一接口,商家、开发者和用户都能获得商品标题、描述、属性、价格、图片等关键信息,从而满足各种业务需求。 在使用淘宝…

指针的深入理解(六)

指针的深入理解(六) 个人主页:大白的编程日记 感谢遇见,我们一起学习进步! 文章目录 指针的深入理解(六)前言一. sizeof和strlen1.1sizeof1.2strlen1.3sizeof和strlen对比 二.数组名和指针加减…

前端html+css+js常用总结快速入门

🔥博客主页: A_SHOWY🎥系列专栏:力扣刷题总结录 数据结构 云计算 数字图像处理 力扣每日一题_ 学习前端全套所有技术性价比低下且容易忘记,先入门学会所有基础的语法(cssjsheml)&#xff…

深度挖掘商品信息,jd.item_get API助您呈现商品全面规格参数

深度挖掘商品信息,特别是在电商平台上,对于商家、开发者和用户来说都至关重要。jd.item_get API作为京东开放平台提供的一个强大工具,能够帮助用户轻松获取商品的全面规格参数,进而为商品分析、推荐、比较等提供有力的数据支撑。 …

两相欠压继电器 WY-35A3 额定输入电压100V 导轨安装 JOSEF约瑟

系列型号: WY-35A4电压继电器;WY-35B4电压继电器; WY-35C4电压继电器;WY-35D4电压继电器; WY-35A4D电压继电器;WY-35A4T电压继电器; WY-35B4D电压继电器;WY-35B4T电压继电器&#xf…

【VMware】虚拟机及镜像Ubuntu安装

Vmware 一.VM是什么?有什么用?二.下载VMware Wworkstation Pro三.安装虚拟机四.安装镜像 一.VM是什么?有什么用? vmware是一款运行在windows系统上的虚拟机软件,可以虚拟出一台计算机硬件,方便安装各类操作…

K8s学习七(服务发现_2)

Ingress Service 主要用于集群内部的通信和负载均衡,而 Ingress 则是用于将服务暴露到集群外部,并提供灵活的 HTTP 路由规则。在实际应用中,它们通常结合使用,Service 提供内部通信和负载均衡,Ingress 提供外部访问和…

网络工程师笔记18(关于网络的一些基本知识)

网络的分类 介绍计算机网络的基本概念,这一章最主要的内容是计算机网络的体系结构-ISO 开放系统互连参考模型,其中的基本概念,例如协议实体、协议数据单元,服务数据单元、面向连接的服务和无连接的服务、服务原语、服务访问点、相…