SwinTransformer 改进:添加SelfAttention自注意力层

目录

1. SelfAttention自注意力层

2. SwinTransformer + SelfAttention

3. 代码


1. SelfAttention自注意力层

Self-Attention自注意力层是一种在神经网络中用于处理序列数据的注意力机制。它通过对输入序列中的不同位置进行关注,来计算每个位置与其他位置的关联程度,并根据这些关联程度对输入序列进行加权。

自注意力层的计算过程如下:

  1. 首先,通过对输入序列中的每一对位置计算一个相关性得分。这可以通过计算输入序列中两个位置之间的点积来实现。得分越高表示两个位置之间的相关性越强。
  2. 然后,对得分进行归一化处理,以确保它们的总和为1。这可以通过将得分除以一个较大的数值来实现,以避免过大的得分。
  3. 接下来,将归一化后的得分与输入序列进行加权求和,得到自注意力层的输出。加权求和时,得分越高的位置对应的向量将会被分配更大的权重。

自注意力层的优势在于它能够利用序列中的局部和全局信息,从而更好地捕捉序列中不同位置之间的依赖关系。在自然语言处理领域中,自注意力层被广泛应用于机器翻译、文本分类和阅读理解等任务中。

实现代码如下:

# 定义自注意力层
class SelfAttention(nn.Module):def __init__(self, in_channels):super(SelfAttention, self).__init__()self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):batch_size, channels, height, width = x.size()query = self.query_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1)key = self.key_conv(x).view(batch_size, -1, height * width)energy = torch.bmm(query, key)attention = torch.softmax(energy, dim=-1)value = self.value_conv(x).view(batch_size, -1, height * width)out = torch.bmm(value, attention.permute(0, 2, 1))out = out.view(batch_size, channels, height, width)out = self.gamma * out + xreturn out

想要完整的分类代码,请参考 本章 ,将下文的model替换即可

2. SwinTransformer + SelfAttention

SwinTransformer 网络结构如下:

本文在 SwinTransformer 最后一个SwinTransformerBlock 添加SelfAttention模块

添加如下:其中sa部分就是添加的模块

  (0): SwinTransformerBlock((norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): ShiftedWindowAttention((qkv): Linear(in_features=768, out_features=2304, bias=True)(proj): Linear(in_features=768, out_features=768, bias=True))(stochastic_depth): StochasticDepth(p=0.18181818181818182, mode=row)(norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): MLP((0): Linear(in_features=768, out_features=3072, bias=True(sa): SelfAttention((query_conv): Conv2d(3072, 384, kernel_size=(1, 1), stride=(1, 1))(key_conv): Conv2d(3072, 384, kernel_size=(1, 1), stride=(1, 1))(value_conv): Conv2d(3072, 3072, kernel_size=(1, 1), stride=(1, 1))))(1): GELU(approximate='none')(2): Dropout(p=0.0, inplace=False)(3): Linear(in_features=3072, out_features=768, bias=True(sa): SelfAttention((query_conv): Conv2d(768, 96, kernel_size=(1, 1), stride=(1, 1))(key_conv): Conv2d(768, 96, kernel_size=(1, 1), stride=(1, 1))(value_conv): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1))))(4): Dropout(p=0.0, inplace=False)

3. 代码

完整代码:

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'import torch
import torch.nn as nn
import torchvision.models as m# 定义自注意力层
class SelfAttention(nn.Module):def __init__(self, in_channels):super(SelfAttention, self).__init__()self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):batch_size, channels, height, width = x.size()query = self.query_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1)key = self.key_conv(x).view(batch_size, -1, height * width)energy = torch.bmm(query, key)attention = torch.softmax(energy, dim=-1)value = self.value_conv(x).view(batch_size, -1, height * width)out = torch.bmm(value, attention.permute(0, 2, 1))out = out.view(batch_size, channels, height, width)out = self.gamma * out + xreturn out# 获取网络模型
def create_model(model,num,weights):if model == 't':net = m.swin_t(weights=m.Swin_T_Weights.DEFAULT if weights else False,progress=True)elif model == 's':net = m.swin_s(weights=m.Swin_S_Weights.DEFAULT if weights else False,progress=True)elif model == 'b':net = m.swin_b(weights=m.Swin_B_Weights.DEFAULT if weights else False,progress=True)else:print('模型选择错误!!')return Nonetmp = net.head.in_featuresnet.head = torch.nn.Linear(tmp,num,bias=True)# 添加模块net.features[7][0].mlp[0].add_module('sa',SelfAttention(list(net.features)[7][0].mlp[0].out_features))net.features[7][0].mlp[3].add_module('sa',SelfAttention(list(net.features)[7][0].mlp[3].out_features))print(net)return netif __name__ == '__main__':model = create_model(model='t',num=10,weights=False)i = torch.randn(1,3,224,224)o = model(i)print(o.size())

网络结构:

SwinTransformer(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): Permute()
      (2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    )
    (1): Sequential(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=96, out_features=288, bias=True)
          (proj): Linear(in_features=96, out_features=96, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=96, out_features=384, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=384, out_features=96, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
        (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=96, out_features=288, bias=True)
          (proj): Linear(in_features=96, out_features=96, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.018181818181818184, mode=row)
        (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=96, out_features=384, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=384, out_features=96, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (2): PatchMerging(
      (reduction): Linear(in_features=384, out_features=192, bias=False)
      (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    )
    (3): Sequential(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=192, out_features=576, bias=True)
          (proj): Linear(in_features=192, out_features=192, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.03636363636363637, mode=row)
        (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=192, out_features=768, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=768, out_features=192, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
        (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=192, out_features=576, bias=True)
          (proj): Linear(in_features=192, out_features=192, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.05454545454545456, mode=row)
        (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=192, out_features=768, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=768, out_features=192, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (4): PatchMerging(
      (reduction): Linear(in_features=768, out_features=384, bias=False)
      (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (5): Sequential(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.07272727272727274, mode=row)
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=1536, out_features=384, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.09090909090909091, mode=row)
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=1536, out_features=384, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (2): SwinTransformerBlock(
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.10909090909090911, mode=row)
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=1536, out_features=384, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (3): SwinTransformerBlock(
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.1272727272727273, mode=row)
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=1536, out_features=384, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (4): SwinTransformerBlock(
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.14545454545454548, mode=row)
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=1536, out_features=384, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (5): SwinTransformerBlock(
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.16363636363636364, mode=row)
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=1536, out_features=384, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (6): PatchMerging(
      (reduction): Linear(in_features=1536, out_features=768, bias=False)
      (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
    )
    (7): Sequential(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.18181818181818182, mode=row)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(
            in_features=768, out_features=3072, bias=True
            (sa): SelfAttention(
              (query_conv): Conv2d(3072, 384, kernel_size=(1, 1), stride=(1, 1))
              (key_conv): Conv2d(3072, 384, kernel_size=(1, 1), stride=(1, 1))
              (value_conv): Conv2d(3072, 3072, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(
            in_features=3072, out_features=768, bias=True
            (sa): SelfAttention(
              (query_conv): Conv2d(768, 96, kernel_size=(1, 1), stride=(1, 1))
              (key_conv): Conv2d(768, 96, kernel_size=(1, 1), stride=(1, 1))
              (value_conv): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.2, mode=row)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (permute): Permute()
  (avgpool): AdaptiveAvgPool2d(output_size=1)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (head): Linear(in_features=768, out_features=10, bias=True)
)

输出size:torch.Size([1, 10])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/64824.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++ ------语句

一、简单语句 简单语句是C中最基本的语句单元,通常以分号(;)结尾,用于执行一个单一的操作。常见的简单语句类型有: 表达式语句:由一个表达式后面加上分号构成,用于计算表达式的值或者执行具有…

【他山之石】The SVG path Syntax: An Illustrated Guide:SVG 中的 path 语法图解指南

写在前面 本文为我的自学精译专栏《CSS in Depth 2》第 086 篇文章、在介绍 CSS 的 clip-path 属性的用法时作者提到的一篇延伸阅读材料,以图文并茂的形式系统梳理了 SVG path 属性的方方面面。其中最为精彩的是文中列举的大量使用案例。为了方便查找,特…

小型 Vue 项目,该不该用 Pinia 、Vuex呢?

说到 Vue3 的状态管理,我们会第一时间想到 Pinia、Vuex,但是经过很长一段时间的 Vue3 项目开发,我逐渐发现,我们真的有必要用 Pinia、Vuex 这类的状态管理工具吗? 带着这样的疑惑,我首先是想知道一下 Pini…

c4d动画怎么导出mp4视频,c4d动画视频格式设置

宝子们,今天来给大家讲讲 C4D 咋导出mp4视频的方法。通过用图文教程的形式给大家展示得明明白白的,让你能轻松理解和掌握,不管是理论基础,还是实际操作和技能技巧,都能学到,快速入门然后提升自己哦。 c4d动…

EfficienetAD异常值检测之瓷砖表面缺陷检测(免费下载测试数据集和模型)

背景 当今制造业蓬勃发展,产品质量把控至关重要。从精密电子元件到大型工业板材,表面缺陷哪怕细微,都可能引发性能故障或外观瑕疵。人工目视检测耗时费力且易漏检,已无法适应高速生产线节奏。在此背景下,表面缺陷异常…

将Minio设置为Django的默认Storage(django-storages)

这里写自定义目录标题 前置说明静态文件收集静态文件 使用django-storages来使Django集成Minio安装依赖settings.py测试收集静态文件测试媒体文件 前置说明 静态文件 Django默认的Storage是本地,项目中的CSS、图片、JS都是静态文件。一般会将静态文件放到一个单独…

Redis生产实践中相关疑问记录

1. Redis相关疑问 1.1. redis内存使用率100% 就等同于redis不可用吗? 正常使用情况下,不是。 redis有【缓存淘汰机制】,Redis 在内存使用率达到 100% 时不会直接崩溃。相反,它依赖内存淘汰策略来释放内存,确保系统的…

量化交易——RSI策略(vectorbt实现)

本文为通过vectorbt(以下简称vbt)实现量化交易系列第一篇文章,通过使用vbt实现RSI策略从而熟悉其代码框架。 关于本文所使用数据的说明 由于vbt官方文档提供的入门案例使用的数据是通过其内置的yfinance包获取,在国内无法直接访…

本地摄像头视频流在html中打开

1.准备ffmpeg 和(rtsp-simple-server srs搭建流媒体服务器)视频服务器. 2.解压视频流服务器修改配置文件mediamtx.yml ,hlsAlwaysRemux: yes 3.双击运行服务器。 4,安装ffmpeg ,添加到环境变量。 5.查询本机设备列表 ffmpeg -list_devices true -f dshow -i d…

unipp中使用阿里图标,以及闭坑指南

-----------------------------------------------------点赞收藏才是更新的动力------------------------------------------------- unipp中使用阿里图标 官网下载图标在项目中引入使用注意事项 官网下载图标 进入阿里图标网站 将需要下载的图标添加到购物车中 2. 直接下载…

设计模式の享元模板代理模式

文章目录 前言一、享元模式二、模板方法模式三、代理模式3.1、静态代理3.2、JDK动态代理3.3、Cglib动态代理3.4、小结 前言 本篇是关于设计模式中享元模式、模板模式、以及代理模式的学习笔记。 一、享元模式 享元模式是一种结构型设计模式,目的是为了相似对象的复用…

flink实现复杂kafka数据读取

接上文:一文说清flink从编码到部署上线 环境说明:MySQL:5.7;flink:1.14.0;hadoop:3.0.0;操作系统:CentOS 7.6;JDK:1.8.0_401。 常见的文章中&…

越疆科技营收增速放缓:毛利率未恢复,持续亏损下销售费用偏高

《港湾商业观察》施子夫 12月13日,深圳市越疆科技股份有限公司(以下简称,越疆科技,02432.HK)发布全球发售公告,公司计划全球发售4000万股股份,其中3800万股国际发售,200万股香港公开…

datasets 笔记:加载数据集(基本操作)

参考了huggingface的教程 1 了解数据集基本信息( load_dataset_builder) 在下载数据集之前,通常先快速了解数据集的基本信息会很有帮助。数据集的信息存储在 DatasetInfo 中,可能包括数据集描述、特征和数据集大小等信息。&…

Java图片拼接

最近遇到一个挺离谱的功能,某个表单只让上传一张图,多图上传会使导出失败。跟开发沟通后表示,这个问题处理不了。我... 遂自己思考,能否以曲线救国的方式拯救一下,即不伤及代码之根本,又能解决燃眉之急。灵…

.NET重点

B/S C/S什么语言 B/S: 浏览器端:JavaScript,HTML,CSS 服务器端:ASP(.NET)PHP/JSP 优势:维护方便,易于升级和扩展 劣势:服务器负担沉重 C/S java/.NET/…

Linux——卷

Linux——卷 介绍 最近做的项目,涉及到对系统的一些维护,有些盘没有使用,需要创建逻辑盘并挂载到指定目录下。有些软件需要依赖空的逻辑盘(LVM)。 先简单介绍一下卷的一些概念,有分区、物理存储介质、物…

M3D: 基于多模态大模型的新型3D医学影像分析框架,将3D医学图像分析从“看图片“提升到“理解空间“的层次,支持检索、报告生成、问答、定位和分割等8类任务

M3D: 基于多模态大模型的新型3D医学影像分析框架,将3D医学图像分析从“看图片“提升到“理解空间“的层次,支持检索、报告生成、问答、定位和分割等8类任务 论文大纲理解1. 确认目标2. 分析过程(目标-手段分析)核心问题拆解 3. 实…

clickhouse-副本和分片

1、副本 1.1、概述 集群是副本和分片的基础,它将ClickHouse的服务拓扑由单节点延伸到多个节点,但它并不像Hadoop生态的某些系统那样,要求所有节点组成一个单一的大集群。ClickHouse的集群配置非常灵活,用户既可以将所有节点组成…

Redis 集群实操:强大的数据“分身术”

目录 Redis Cluster集群模式 1、介绍 2、架构设计 3、集群模式实操 4、故障转移 5、常用命令 Redis Cluster集群模式 1、介绍 redis3.0版本推出的Redis Cluster 集群模式,每个节点都可以保存数据和整个集群状态,每个节点都和其他所有节点连接。Cl…