原论文地址:原论文下载网址
论文相关内容介绍
将自注意力过程分解为区域和局部特征提取过程,每个过程产生的计算复杂度要小得多。然而,区域信息通常仅以由于下采样而丢失的不希望的信息为代价。在本文中,作者提出了一种旨在缓解成本问题的新型Transformer架构,称为双视觉Transformer(Dual ViT)。新架构结合了一个关键的语义路径,可以更有效地将token向量压缩为全局语义,并降低复杂性。这种压缩的全局语义通过另一个构建的像素路径,作为学习内部像素级细节的有用先验信息。然后将语义路径和像素路径整合在一起,并进行联合训练,通过这两条路径并行传播增强的自注意力信息。因此,双ViT能够在不影响精度的情况下降低计算复杂度。实证证明,双ViT比SOTA Transformer架构提供了更高的精度,同时降低了训练复杂度。
Transformer结构在革新深度学习应用方面取得了巨大成功,包括自然语言处理和计算机视觉任务。不幸的是,由于Transformer通常依赖密集的自注意力计算,因此对于高分辨率输入,此类架构的训练通常很慢。由于transformer技术通常可以提供比同类技术更高的性能,因此这种复杂性问题逐渐成为制约这种强大体系结构发展的瓶颈。
自注意力过程是此类复杂性问题的主要负担,因为每个token的每个表示都是通过关注所有token来更新的。最近的工作将重点放在研究复杂性问题上,通过提供不同于标准自注意力的替代解决方案。许多人考虑将自注意力与下采样相结合,以有效地取代原来的标准注意力。这种方式自然能够探索区域语义信息,从而进一步促进局部特征的学习/提取。例如,PVT提出了线性空间减少注意(SRA),该注意通过下采样操作(例如,平均池或跨步卷积)减少键和值的空间比例,如图1(a)所示。Twins(上图(b))在SRA之前添加了额外的局部分组自注意力层,以通过区域内相互作用进一步增强表示。RegionViT(上图(c))通过区域和局部自注意力分解原始注意力。然而,由于上述方法严重依赖于特征映射到区域的下采样,在有效节省总计算成本的同时,观察到了明显的性能下降。
在这些不同的组合策略中,很少有人试图研究全局语义和内部像素级特征之间在降低复杂性方面的依赖关系。在本文中,作者考虑通过提出的双重ViT将训练分解为全局语义和内部特征注意。其动机是提取全局语义信息(即参数语义查询),这些信息可以作为丰富的先验信息,在新的双通道设计中帮助用户进行局部特征提取。本文对全局语义和局部特征的独特分解和集成允许有效减少多头注意力中涉及的token数量,从而与标准注意对应项相比节省了计算复杂性。特别是,如上图(d)所示,双ViT由两个特殊路径组成,分别称为“语义路径”和“像素路径”。通过构造的“像素路径”进行局部像素级特征提取是强烈依赖于“语义路径”之外的压缩全局先验。由于梯度同时通过语义路径和像素路径,因此双ViT训练过程可以有效地补偿全局特征压缩的信息损失,同时减少局部特征提取的困难。前者和后者都可以并行显著降低计算成本,因为注意力大小较小,并且两条路径之间存在强制依赖关系。
综上所述,本文做出了以下贡献:
1) 提出了一种新的Transformer架构,称为双视觉Transformer(双ViT)。顾名思义,双ViT网络包括两条路径,分别用于提取输入语义特征的更全面全局视图,以及另一条专注于学习内部局部特征的像素路径。
2)双ViT考虑了两条路径上全局语义和局部特征之间的依赖关系,目的是通过减少token大小和注意力来简化训练。
3) 与VOLO相比,双ViT在ImageNet上实现了85.7%的top-1精度,只有41.1%的浮点运算和37.8%的参数。在目标检测和实例分割方面,双ViT在映射方面也提高了PVT,在COCO上分别提高了1.2%和0.9%,参数减少了48.0%。
具体的方法如何实现大家可以参考原论文
2.yolov8加入Dual Vision Transformer的步骤:
2.1 新建加入ultralytics/nn/attention/dualvit.py
注意在nn文件夹下新建attention文件夹,创建dualvit.py文件,导入下面的代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partialfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_
import mathclass DWConv(nn.Module):def __init__(self, dim=768):super(DWConv, self).__init__()self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)def forward(self, x, H, W):B, N, C = x.shapex = x.transpose(1, 2).view(B, C, H, W)x = self.dwconv(x)x = x.flatten(2).transpose(1, 2)return xclass PVT2FFN(nn.Module):def __init__(self, in_features, hidden_features):super().__init__()self.fc1 = nn.Linear(in_features, hidden_features)self.dwconv = DWConv(hidden_features)self.act = nn.GELU()self.fc2 = nn.Linear(hidden_features, in_features)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):if m.bias is not None:nn.init.constant_(m.bias, 0)if m.weight is not None:nn.init.constant_(m.weight, 1.0)elif isinstance(m, nn.Conv2d):fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsfan_out //= m.groupsm.weight.data.normal_(0, math.sqrt(2.0 / fan_out))if m.bias is not None:m.bias.data.zero_()def forward(self, x, H, W):x = self.fc1(x)x = self.dwconv(x, H, W)x = self.act(x)x = self.fc2(x)return xclass MergeFFN(nn.Module):def __init__(self, in_features, hidden_features):super().__init__()self.fc1 = nn.Linear(in_features, hidden_features)self.dwconv = DWConv(hidden_features)self.act = nn.GELU()self.fc2 = nn.Linear(hidden_features, in_features)self.fc_proxy = nn.Sequential(nn.Linear(in_features, 2 * in_features),nn.GELU(),nn.Linear(2 * in_features, in_features),)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):if m.bias is not None:nn.init.constant_(m.bias, 0)if m.weight is not None:nn.init.constant_(m.weight, 1.0)elif isinstance(m, nn.Conv2d):fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsfan_out //= m.groupsm.weight.data.normal_(0, math.sqrt(2.0 / fan_out))if m.bias is not None:m.bias.data.zero_()def forward(self, x, H, W):x, semantics = torch.split(x, [H * W, x.shape[1] - H * W], dim=1)semantics = self.fc_proxy(semantics)x = self.fc1(x)x = self.dwconv(x, H, W)x = self.act(x)x = self.fc2(x)x = torch.cat([x, semantics], dim=1)return xclass Attention(nn.Module):def __init__(self, dim, num_heads):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.q = nn.Linear(dim, dim)self.kv = nn.Linear(dim, dim * 2)self.proj = nn.Linear(dim, dim)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):if m.bias is not None:nn.init.constant_(m.bias, 0)if m.weight is not None:nn.init.constant_(m.weight, 1.0)elif isinstance(m, nn.Conv2d):fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsfan_out //= m.groupsm.weight.data.normal_(0, math.sqrt(2.0 / fan_out))if m.bias is not None:m.bias.data.zero_()def forward(self, x):# x =x.permute(3, 0, 1, 2)B, H, W, C = x.shapeN = H * Wq = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)k, v = kv[0], kv[1]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)x = (attn @ v).transpose(1, 2).reshape(B, H, W , C)x = self.proj(x)return xclass MergeBlockattention(nn.Module):def __init__(self,input, dim, num_heads=2, mlp_ratio=8, drop_path=0., norm_layer=nn.LayerNorm, is_last=False):super().__init__()self.norm1 = norm_layer(dim)self.norm2 = norm_layer(dim)self.attn = Attention(dim, num_heads)if is_last:self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))else:self.mlp = MergeFFN(in_features=dim, hidden_features=int(dim * mlp_ratio))self.is_last = is_lastself.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()layer_scale_init_value = 1e-6self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)),requires_grad=True) if layer_scale_init_value > 0 else Noneself.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)),requires_grad=True) if layer_scale_init_value > 0 else Noneself.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):if m.bias is not None:nn.init.constant_(m.bias, 0)if m.weight is not None:nn.init.constant_(m.weight, 1.0)elif isinstance(m, nn.Conv2d):fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsfan_out //= m.groupsm.weight.data.normal_(0, math.sqrt(2.0 / fan_out))if m.bias is not None:m.bias.data.zero_()def forward(self, x):B, C, H, W = x.shapex = x.permute(0, 2, 3, 1)#x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))x =self.attn(self.norm1(x))x = x.permute(0, 3, 2, 1)return x
注册ultralytics/nn/tasks.py
在tasks.py文件的上面导入部分粘贴下面的代码
from ultralytics.nn.attention.dualvit import MergeBlockattention
修改def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
只需要加入MergeBlockattention,加入以下代码:
if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, MergeBlockattention):
yolov8_DualAttention.yaml
# Ultralytics YOLO 🚀, GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPss: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPsm: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPsl: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 3, C2f, [512]] # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 3, C2f, [256]] # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]] # cat head P4- [-1, 3, C2f, [512]] # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]] # cat head P5- [-1, 3, C2f, [1024]] # 21 (P5/32-large)- [-1, 1, MergeBlockattention, [1024]] # 21 (P5/32-large)- [[15, 18, 22], 1, Detect, [nc]] # Detect(P3, P4, P5