将自注意力过程分解为区域和局部特征提取过程,每个过程产生的计算复杂度要小得多。然而,区域信息的获取通常要以下采样造成的、不希望出现的信息丢失为代价。在本文中,作者提出了一种旨在缓解成本问题的新型Transformer架构,称为双视觉Transformer(Dual ViT)。新架构结合了一条关键的语义路径,可以更高效地将token向量压缩为全局语义,并降低复杂度。这种压缩后的全局语义通过另一条构建的像素路径,作为学习内部像素级细节的有用先验信息。然后将语义路径和像素路径整合在一起并进行联合训练,通过这两条路径并行传播增强的自注意力信息。因此,双ViT能够在不影响精度的情况下降低计算复杂度。实验证明,双ViT比SOTA Transformer架构提供了更高的精度,同时降低了训练复杂度。
1) 提出了一种新的Transformer架构,称为双视觉Transformer(双ViT)。顾名思义,双ViT网络包括两条路径:一条是语义路径,用于提取输入特征更全面的全局语义视图;另一条是像素路径,专注于学习内部局部特征。
3) 与VOLO相比,双ViT在ImageNet上实现了85.7%的top-1精度,而浮点运算量(FLOPs)和参数量分别只有VOLO的41.1%和37.8%。在目标检测和实例分割方面,双ViT的mAP也优于PVT,在COCO上分别提高了1.2%和0.9%,同时参数减少了48.0%。
2.yolov8加入Dual Vision Transformer的步骤:
2.1 新建文件ultralytics/nn/attention/dualvit.py,并加入以下代码:
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import math


def _init_module_weights(m):
    """Initialise one sub-module in place; meant to be used with ``Module.apply``.

    * ``nn.Linear``: truncated-normal weights (std 0.02), zero bias.
    * ``nn.LayerNorm``: unit weight, zero bias.
    * ``nn.Conv2d``: normal weights with std ``sqrt(2 / fan_out)`` where
      ``fan_out`` is computed per group, zero bias.

    Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        # torch's built-in truncated normal is used so that weight init does
        # not depend on the timm helper of the same name.
        nn.init.trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()


class DWConv(nn.Module):
    """3x3 depth-wise convolution applied to a flattened token sequence."""

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        # groups=dim makes the convolution depth-wise (one filter per channel).
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Convolve tokens on their ``H x W`` grid.

        Args:
            x: tensor of shape ``(B, N, C)``; the reshape below requires
                ``N == H * W``.
            H: grid height.
            W: grid width.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class PVT2FFN(nn.Module):
    """Feed-forward network: Linear -> depth-wise conv -> GELU -> Linear."""

    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        _init_module_weights(m)

    def forward(self, x, H, W):
        """Map ``(B, H * W, in_features)`` tokens to the same shape."""
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.fc2(x)
        return x


class MergeFFN(nn.Module):
    """Feed-forward network for a sequence of pixel tokens followed by extra tokens.

    The first ``H * W`` tokens go through the same path as ``PVT2FFN``; the
    remaining tokens (named ``semantics`` in the code) go through a separate
    two-layer MLP (``fc_proxy``). Both results are concatenated back in the
    original order.
    """

    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.fc_proxy = nn.Sequential(
            nn.Linear(in_features, 2 * in_features),
            nn.GELU(),
            nn.Linear(2 * in_features, in_features),
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        _init_module_weights(m)

    def forward(self, x, H, W):
        """Process ``x`` of shape ``(B, H * W + S, C)`` and return the same shape.

        ``S`` is the number of trailing tokens handled by ``fc_proxy``.
        """
        x, semantics = torch.split(x, [H * W, x.shape[1] - H * W], dim=1)
        semantics = self.fc_proxy(semantics)
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.fc2(x)
        x = torch.cat([x, semantics], dim=1)
        return x


class Attention(nn.Module):
    """Multi-head self-attention over all positions of a channels-last map."""

    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        _init_module_weights(m)

    def forward(self, x):
        """Attend over the ``H * W`` positions of ``x``.

        Args:
            x: tensor of shape ``(B, H, W, C)``.

        Returns:
            Tensor of shape ``(B, H, W, C)``.
        """
        B, H, W, C = x.shape
        N = H * W
        # q: (B, heads, N, C // heads)
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        # kv: (2, B, heads, N, C // heads) -> k and v
        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, H, W, C)
        x = self.proj(x)
        return x


class MergeBlockattention(nn.Module):
    """LayerNorm + self-attention block that works on ``(B, C, H, W)`` feature maps.

    Args:
        input: not used by this class. Presumably the input channel count that
            the model parser passes as the first positional argument -- verify
            against the caller.
        dim: channel count of the feature map; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
        mlp_ratio: hidden/inner width ratio of the feed-forward sub-module.
        drop_path: stochastic-depth rate; ``DropPath`` is only instantiated
            when this is greater than 0.
        norm_layer: normalisation layer factory, called as ``norm_layer(dim)``.
        is_last: selects ``PVT2FFN`` (True) or ``MergeFFN`` (False) for
            ``self.mlp``.

    NOTE(review): ``norm2``, ``mlp``, ``drop_path``, ``gamma1`` and ``gamma2``
    are created but ``forward`` does not use them -- only ``norm1`` and
    ``attn`` take part in the computation, without a residual connection.
    They are kept so the module's parameter set stays unchanged.
    """

    def __init__(self, input, dim, num_heads=2, mlp_ratio=8, drop_path=0.,
                 norm_layer=nn.LayerNorm, is_last=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = Attention(dim, num_heads)
        if is_last:
            self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
        else:
            self.mlp = MergeFFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
        self.is_last = is_last
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        layer_scale_init_value = 1e-6
        self.gamma1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.gamma2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        _init_module_weights(m)

    def forward(self, x):
        """Apply ``attn(norm1(x))`` and return a tensor shaped like the input.

        Args:
            x: tensor of shape ``(B, C, H, W)`` with ``C == dim``.

        Returns:
            Tensor of shape ``(B, C, H, W)``.
        """
        x = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
        x = self.attn(self.norm1(x))
        # Inverse of the permutation above. The previous (0, 3, 2, 1) returned
        # (B, C, W, H), i.e. a spatially transposed feature map.
        x = x.permute(0, 3, 1, 2)
        return x
from ultralytics.nn.attention.dualvit import MergeBlockattention
修改parse_model函数(def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)),在下面的模块元组中加入MergeBlockattention:
if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, MergeBlockattention):
# Ultralytics YOLO 🚀, GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1  # number of classes
scales:  # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
  m: [0.67, 0.75, 768]  # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
  l: [1.00, 1.00, 512]  # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]  # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
  - [-1, 1, MergeBlockattention, [1024]]  # 22 (P5/32-large, attention output fed to Detect)

  - [[15, 18, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)