论文名称:《On the Integration of Self-Attention and Convolution》
论文地址:2111.14556 (arxiv.org)
卷积和自注意力是两种强大的表示学习技术,通常被认为是两种截然不同的并列方法。在本文中,我们展示了它们之间存在一种强烈的潜在关系,从计算角度来看,这两个范式的大部分计算实际上是使用相同的操作完成的。具体而言,我们首先表明,传统的卷积操作,具有 k × k k×k k×k 的核大小,可以被分解为 k²
个独立的 1 × 1 1×1 1×1 卷积,然后是移位和求和操作。然后,我们将自注意力模块中查询、键和值的投影解释为多个 1 × 1 1×1 1×1 卷积,然后计算注意力权重并聚合这些值。因此,这两个模块的第一阶段包含了相似的操作。更重要的是,相较于第二阶段,第一阶段在计算复杂性上占据主要地位(频道大小的平方)。这一发现自然引出了这两个看似截然不同的范式的巧妙结合,即一个混合模型,它既享受自注意力和卷积的好处,同时比纯卷积或纯自注意力模型具有最小的计算开销。大量实验表明,我们的模型在图像识别和下游任务上持续取得比竞争基准更好的结果。
问题背景
论文讨论了深度学习中的两大主要方法:卷积和自注意力。在计算机视觉领域,卷积神经网络(CNN)通常是默认的选择,而自注意力则在自然语言处理(NLP)中更常见。该论文的背景在于这些两种技术的不同设计范式,以及它们在计算机视觉任务中的应用。这引发了一个问题:这两者是否可以通过某种方式结合,以实现更好的性能和效率。
核心概念
该论文提出了一种名为ACmix
的新方法,将卷积和自注意力结合在一起。核心概念在于发现这两种技术在计算上有共同点。传统的卷积可以分解为一系列的 1 × 1 1×1 1×1 卷积,而自注意力的投影过程也类似。这种相似性为将这两种方法结合在一起提供了机会,从而形成一种混合模型,既具有卷积的局部特性,又具备自注意力的灵活性。
模块的操作步骤
展示了 ACmix
的概念草图。我们探索了卷积和自注意力之间的密切关系,这种关系在于共享相同的计算负载(1×1
卷积),并结合其余的轻量级聚合操作。我们展示了每个模块相对于特征通道的计算复杂性。
在ACmix
中,操作步骤分为两个阶段:
- 阶段一:将输入特征映射通过 1 × 1 1×1 1×1 卷积进行投影,形成中间特征。
- 阶段二:根据不同的范式,分别应用卷积和自注意力的操作。在卷积路径上,利用
Shift
和Summation
操作实现卷积过程。在自注意力路径上,计算查询、键和值,然后应用传统的自注意力方法。
文章贡献
- 揭示了卷积和自注意力之间的强关联性,提供了一种理解这两者之间关系的新方式。
- 提出了
ACmix
,一种将卷积和自注意力优雅地结合在一起的模型。它能够在没有额外计算负担的情况下,享受这两种方法的优势。
实验结果与应用:
ACmix
在ImageNet
分类、语义分割和目标检测等任务中进行了验证。实验结果表明,与传统的卷积或自注意力模型相比,ACmix
在准确性、计算开销和参数数量上具有优势。在ImageNet
分类任务中,ACmix
的模型在相同的FLOPs
或参数数量下表现出色,并且在与竞争对手的基准比较中取得了持续的改进。此外,ACmix
在ADE20K
语义分割任务和COCO
目标检测任务中也显示出明显的改进,进一步验证了该模型的有效性。
对未来工作的启示:
该论文提出了ACmix
这一混合模型,为计算机视觉领域提供了一种新的方向。它揭示了卷积和自注意力之间的关系,为设计新的学习范式提供了灵感。未来的工作可以在以下几个方面继续探索:
- 在其他自注意力模块中应用
ACmix
,并进一步验证其有效性。 - 研究如何在更大的模型中更好地整合卷积和自注意力,从而在更复杂的任务中实现性能提升。
- 考虑在实际应用中进一步优化
ACmix
的计算效率,以确保其在生产环境中的可用性。
代码
import torch
import torch.nn as nn
import torch.nn.functional as Fdef position(H, W, is_cuda=True):if is_cuda:loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)else:loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)return locdef stride(x, stride):b, c, h, w = x.shapereturn x[:, :, ::stride, ::stride]def init_rate_half(tensor):if tensor is not None:tensor.data.fill_(0.5)def init_rate_0(tensor):if tensor is not None:tensor.data.fill_(0.0)class ACmix(nn.Module):def __init__(self,in_planes,out_planes,kernel_att=7,head=4,kernel_conv=3,stride=1,dilation=1,):super(ACmix, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.head = headself.kernel_att = kernel_attself.kernel_conv = kernel_convself.stride = strideself.dilation = dilationself.rate1 = torch.nn.Parameter(torch.Tensor(1))self.rate2 = torch.nn.Parameter(torch.Tensor(1))self.head_dim = self.out_planes // self.headself.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)self.softmax = torch.nn.Softmax(dim=1)self.fc = nn.Conv2d(3 * self.head,self.kernel_conv * self.kernel_conv,kernel_size=1,bias=False,)self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim,out_planes,kernel_size=self.kernel_conv,bias=True,groups=self.head_dim,padding=1,stride=stride,)self.reset_parameters()def reset_parameters(self):init_rate_half(self.rate1)init_rate_half(self.rate2)kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)for i in range(self.kernel_conv * self.kernel_conv):kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.0kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)self.dep_conv.bias = init_rate_0(self.dep_conv.bias)def forward(self, x):q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)scaling = float(self.head_dim) ** -0.5b, c, h, w = q.shapeh_out, w_out = h // self.stride, w // self.stridepe = self.conv_p(position(h, w, x.is_cuda))q_att = q.view(b * self.head, self.head_dim, h, w) * scalingk_att = k.view(b * self.head, self.head_dim, h, w)v_att = v.view(b * self.head, self.head_dim, h, w)if self.stride > 1:q_att = stride(q_att, self.stride)q_pe = stride(pe, self.stride)else:q_pe = peunfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head,self.head_dim,self.kernel_att * self.kernel_att,h_out,w_out,) # b*head, head_dim, k_att^2, h_out, w_outunfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out, w_out) # 1, head_dim, k_att^2, h_out, w_outatt = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)att = self.softmax(att)out_att = self.unfold(self.pad_att(v_att)).view(b * self.head,self.head_dim,self.kernel_att * self.kernel_att,h_out,w_out,)out_att = ((att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out))f_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h * w),k.view(b, self.head, self.head_dim, h * w),v.view(b, self.head, self.head_dim, h * w),],1,))f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])out_conv = self.dep_conv(f_conv)return self.rate1 * out_att + self.rate2 * out_convif __name__ == "__main__":input = torch.randn(64, 256, 8, 8)model = ACmix(in_planes=256, out_planes=256)output = model(input)print(output.shape)