Contents
1. The SelfAttention layer
2. SwinTransformer + SelfAttention
3. Code
1. The SelfAttention layer
A self-attention layer is an attention mechanism used in neural networks to process sequence data. It relates each position of the input sequence to every other position, computes how strongly each pair of positions is associated, and uses these association strengths to re-weight the input sequence.
A self-attention layer is computed as follows:
- First, a relevance score is computed for every pair of positions in the input sequence, typically as the dot product between the representations of the two positions. The higher the score, the stronger the relation between the two positions.
- Next, the scores are normalized so that, for each position, the weights over all positions sum to 1. In practice the raw scores are usually scaled down first (for example, divided by the square root of the feature dimension) to keep them from growing too large, and then passed through a softmax.
- Finally, the normalized scores are used to form a weighted sum over the input sequence, which is the output of the self-attention layer. Positions with higher scores receive larger weights, as sketched in the code below.
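A minimal sketch of these three steps (purely illustrative, not the layer implemented later in this post), where the same input tensor serves as query, key, and value and all sizes are arbitrary examples:
import torch

# Minimal sketch of the three steps above; the same tensor plays the role of
# query, key and value. Shapes and sizes here are arbitrary examples.
x = torch.randn(1, 5, 16)                 # (batch, seq_len, dim)

scores = torch.bmm(x, x.transpose(1, 2))  # step 1: pairwise dot-product scores, (1, 5, 5)
scores = scores / (x.size(-1) ** 0.5)     # scale down so the scores do not grow too large
weights = torch.softmax(scores, dim=-1)   # step 2: normalize; each row sums to 1
out = torch.bmm(weights, x)               # step 3: weighted sum over the sequence, (1, 5, 16)

print(weights.sum(dim=-1))                # all ones
print(out.shape)                          # torch.Size([1, 5, 16])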
The strength of the self-attention layer is that it can exploit both local and global information in the sequence, and therefore captures dependencies between different positions well. In natural language processing, self-attention is widely used in tasks such as machine translation, text classification, and reading comprehension.
The implementation is as follows:
import torch
import torch.nn as nn


# Define the self-attention layer
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        # 1x1 convolutions project the input into query, key and value maps
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Learnable residual weight, initialized to 0 so the layer starts as an identity
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        # Flatten the spatial dimensions: query (B, H*W, C//8), key (B, C//8, H*W)
        query = self.query_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, height * width)
        energy = torch.bmm(query, key)             # (B, H*W, H*W) pairwise scores
        attention = torch.softmax(energy, dim=-1)  # normalize over the key positions
        value = self.value_conv(x).view(batch_size, -1, height * width)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)
        out = self.gamma * out + x                 # residual connection to the input
        return out
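A quick shape check of this layer (the channel count and spatial size below are arbitrary examples): the layer maps a (B, C, H, W) feature map to a tensor of the same shape.
# The layer preserves the input shape: (B, C, H, W) in, (B, C, H, W) out
sa = SelfAttention(64)
feat = torch.randn(1, 64, 32, 32)
print(sa(feat).shape)  # torch.Size([1, 64, 32, 32])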
For the complete classification code, please refer to this chapter; just replace its model with the one defined below.
2. SwinTransformer + SelfAttention
The overall SwinTransformer network structure is printed in full in Section 3.
In this post, a SelfAttention module is added inside the last SwinTransformerBlock of SwinTransformer.
The modified block is shown below; the (sa) entries are the added modules:
(0): SwinTransformerBlock(
  (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): ShiftedWindowAttention(
    (qkv): Linear(in_features=768, out_features=2304, bias=True)
    (proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (stochastic_depth): StochasticDepth(p=0.18181818181818182, mode=row)
  (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): MLP(
    (0): Linear(
      in_features=768, out_features=3072, bias=True
      (sa): SelfAttention(
        (query_conv): Conv2d(3072, 384, kernel_size=(1, 1), stride=(1, 1))
        (key_conv): Conv2d(3072, 384, kernel_size=(1, 1), stride=(1, 1))
        (value_conv): Conv2d(3072, 3072, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(
      in_features=3072, out_features=768, bias=True
      (sa): SelfAttention(
        (query_conv): Conv2d(768, 96, kernel_size=(1, 1), stride=(1, 1))
        (key_conv): Conv2d(768, 96, kernel_size=(1, 1), stride=(1, 1))
        (value_conv): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (4): Dropout(p=0.0, inplace=False)
  )
)
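These (sa) entries are produced by the two add_module calls from the full script in Section 3, which attach a SelfAttention layer to the two Linear layers of the last block's MLP:
# Attach SelfAttention to the two Linear layers of the last block's MLP
# (net is the torchvision swin_t model built in Section 3)
net.features[7][0].mlp[0].add_module('sa', SelfAttention(net.features[7][0].mlp[0].out_features))
net.features[7][0].mlp[3].add_module('sa', SelfAttention(net.features[7][0].mlp[3].out_features))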
3. Code
Full code:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torch
import torch.nn as nn
import torchvision.models as m


# Define the self-attention layer
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        query = self.query_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, height * width)
        energy = torch.bmm(query, key)
        attention = torch.softmax(energy, dim=-1)
        value = self.value_conv(x).view(batch_size, -1, height * width)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)
        out = self.gamma * out + x
        return out


# Build the network
def create_model(model, num, weights):
    if model == 't':
        net = m.swin_t(weights=m.Swin_T_Weights.DEFAULT if weights else None, progress=True)
    elif model == 's':
        net = m.swin_s(weights=m.Swin_S_Weights.DEFAULT if weights else None, progress=True)
    elif model == 'b':
        net = m.swin_b(weights=m.Swin_B_Weights.DEFAULT if weights else None, progress=True)
    else:
        print('Invalid model choice!')
        return None

    # Replace the classification head with a num-way linear layer
    tmp = net.head.in_features
    net.head = torch.nn.Linear(tmp, num, bias=True)

    # Register the SelfAttention modules on the two Linear layers of the last block's MLP
    net.features[7][0].mlp[0].add_module('sa', SelfAttention(net.features[7][0].mlp[0].out_features))
    net.features[7][0].mlp[3].add_module('sa', SelfAttention(net.features[7][0].mlp[3].out_features))

    print(net)
    return net


if __name__ == '__main__':
    model = create_model(model='t', num=10, weights=False)
    i = torch.randn(1, 3, 224, 224)
    o = model(i)
    print(o.size())
Network structure:
SwinTransformer(
(features): Sequential(
(0): Sequential(
(0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
(1): Permute()
(2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
)
(1): Sequential(
(0): SwinTransformerBlock(
(norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=96, out_features=288, bias=True)
(proj): Linear(in_features=96, out_features=96, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.0, mode=row)
(norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=96, out_features=384, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=384, out_features=96, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(1): SwinTransformerBlock(
(norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=96, out_features=288, bias=True)
(proj): Linear(in_features=96, out_features=96, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.018181818181818184, mode=row)
(norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=96, out_features=384, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=384, out_features=96, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
(2): PatchMerging(
(reduction): Linear(in_features=384, out_features=192, bias=False)
(norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
)
(3): Sequential(
(0): SwinTransformerBlock(
(norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(proj): Linear(in_features=192, out_features=192, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.03636363636363637, mode=row)
(norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=192, out_features=768, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=768, out_features=192, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(1): SwinTransformerBlock(
(norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(proj): Linear(in_features=192, out_features=192, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.05454545454545456, mode=row)
(norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=192, out_features=768, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=768, out_features=192, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
(4): PatchMerging(
(reduction): Linear(in_features=768, out_features=384, bias=False)
(norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(5): Sequential(
(0): SwinTransformerBlock(
(norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=384, out_features=1152, bias=True)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.07272727272727274, mode=row)
(norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(1): SwinTransformerBlock(
(norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=384, out_features=1152, bias=True)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.09090909090909091, mode=row)
(norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(2): SwinTransformerBlock(
(norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=384, out_features=1152, bias=True)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.10909090909090911, mode=row)
(norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(3): SwinTransformerBlock(
(norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=384, out_features=1152, bias=True)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.1272727272727273, mode=row)
(norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(4): SwinTransformerBlock(
(norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=384, out_features=1152, bias=True)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.14545454545454548, mode=row)
(norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(5): SwinTransformerBlock(
(norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=384, out_features=1152, bias=True)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.16363636363636364, mode=row)
(norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
(6): PatchMerging(
(reduction): Linear(in_features=1536, out_features=768, bias=False)
(norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
)
(7): Sequential(
(0): SwinTransformerBlock(
(norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.18181818181818182, mode=row)
(norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(
in_features=768, out_features=3072, bias=True
(sa): SelfAttention(
(query_conv): Conv2d(3072, 384, kernel_size=(1, 1), stride=(1, 1))
(key_conv): Conv2d(3072, 384, kernel_size=(1, 1), stride=(1, 1))
(value_conv): Conv2d(3072, 3072, kernel_size=(1, 1), stride=(1, 1))
)
)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(
in_features=3072, out_features=768, bias=True
(sa): SelfAttention(
(query_conv): Conv2d(768, 96, kernel_size=(1, 1), stride=(1, 1))
(key_conv): Conv2d(768, 96, kernel_size=(1, 1), stride=(1, 1))
(value_conv): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1))
)
)
(4): Dropout(p=0.0, inplace=False)
)
)
(1): SwinTransformerBlock(
(norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.2, mode=row)
(norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(permute): Permute()
(avgpool): AdaptiveAvgPool2d(output_size=1)
(flatten): Flatten(start_dim=1, end_dim=-1)
(head): Linear(in_features=768, out_features=10, bias=True)
)
Output size: torch.Size([1, 10])
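As a quick sanity check (this snippet is not part of the original script), the registered SelfAttention submodules can be listed by walking named_modules(); the names shown assume the swin_t model built in the script above:
# Assuming `model` is the network built by create_model() in the script above
for name, module in model.named_modules():
    if isinstance(module, SelfAttention):
        print(name)
# Expected output:
# features.7.0.mlp.0.sa
# features.7.0.mlp.3.sa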