VSSM VMamba实现

文章目录

    • VSSM
      • 维度变换
      • 初始化
      • 模型参数初始化
      • 模型搭建
        • def_make_layer
        • def _make_downsample
      • patch embed
      • 第一至四阶段
      • 分类器
    • VSSBlock
      • def __ init__
        • ssm分支
        • mlp分支
      • def forward

VSSM

Mamba实现可以参照之前的
mamba_minimal系列
论文地址:
VMamba
论文阅读:
VMamba:视觉状态空间模型
代码地址:
https://github.com/MzeroMiko/VMamba.git
SS2D实现

以分类任务用到的VMamba为例。

维度变换

操作的具体参数定义见初始化

阶段维度
输入x [ B , C , H , W ] [B, C, H, W] [B,C,H,W]
embed [ B , H / 4 , W / 4 , C 1 ] [B, H/4, W/4, C_1 ] [B,H/4,W/4,C1]
阶段1 [ B , H / 4 , W / 4 , C 1 ] [B, H/4, W/4, C_1 ] [B,H/4,W/4,C1]
阶段2 [ B , H / 8 , W / 8 , C 2 ] [B, H/8, W/8, C_2 ] [B,H/8,W/8,C2]
阶段3 [ B , H / 16 , W / 16 , C 3 ] [B, H/16, W/16, C_3 ] [B,H/16,W/16,C3]
阶段4 [ B , H / 32 , W / 32 , C 4 ] [B, H/32, W/32, C_4 ] [B,H/32,W/32,C4]
分类器 [ B , 1000 ] [B, 1000 ] [B,1000]

在这里插入图片描述

初始化

参数定义说明
in_chans3输入图像的通道数
depths[2, 2, 9, 2]定义每层的VSS Block数
dims[96, 192, 384, 768]定义每层的输出通道数
downsample_versionv2下采样操作的版本
patchembed_versionv1图像嵌入
mlp_ratio4.0定义mlp隐藏维度缩放
ssm_d_state16ssm隐状态的维度
ssm_ratio2.0d_inner = d_state * ssm_ratio
ssm_initv0ssm初始化版本
forward_typev2ssm前向版本

模型参数初始化

大部分参数即SS2D,VSS块中的参数由定义的ssm初始化版本初始化,剩下的线性层和归一化层参数由下面的函数初始化。

    def _init_weights(self, m: nn.Module):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)

模型搭建

def_make_layer

构建VSSM的4个阶段,即4层,VSSBlock本身并不改变输入的尺寸,因此需要下采样模块将输出维度变换为下一阶段的输入维度

def _make_layer(dim=96, drop_path=[0.1, 0.1], use_checkpoint=False, norm_layer=nn.LayerNorm,downsample=nn.Identity(),# ===========================ssm_d_state=16,ssm_ratio=2.0,ssm_dt_rank="auto",       ssm_act_layer=nn.SiLU,ssm_conv=3,ssm_conv_bias=True,ssm_drop_rate=0.0, ssm_init="v0",forward_type="v2",# ===========================mlp_ratio=4.0,mlp_act_layer=nn.GELU,mlp_drop_rate=0.0,**kwargs,):depth = len(drop_path)blocks = []for d in range(depth):blocks.append(VSSBlock(hidden_dim=dim, drop_path=drop_path[d],norm_layer=norm_layer,ssm_d_state=ssm_d_state,ssm_ratio=ssm_ratio,ssm_dt_rank=ssm_dt_rank,ssm_act_layer=ssm_act_layer,ssm_conv=ssm_conv,ssm_conv_bias=ssm_conv_bias,ssm_drop_rate=ssm_drop_rate,ssm_init=ssm_init,forward_type=forward_type,mlp_ratio=mlp_ratio,mlp_act_layer=mlp_act_layer,mlp_drop_rate=mlp_drop_rate,use_checkpoint=use_checkpoint,))return nn.Sequential(OrderedDict(blocks=nn.Sequential(*blocks,),downsample=downsample,))
def _make_downsample

默认下采样版本v2

下采样模块,通过2D卷积之后,长宽变为原来的一半,通道数不变

    def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm):return nn.Sequential(Permute(0, 3, 1, 2),nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),Permute(0, 2, 3, 1),norm_layer(out_dim),)

patch embed

默认嵌入版本v1,对输入图像进行embed

输入x维度 [ B , 3 , H , W ] [B, 3, H, W] [B,3,H,W],嵌入后通道维变为96, H = H p a t c h _ s i z e H = \frac{H}{patch\_size} H=patch_sizeH W = W p a t c h _ s i z e W = \frac{W}{patch\_size} W=patch_sizeW [ B , 96 , H 4 , W 4 ] [B, 96, \frac{H}{4}, \frac{W}{4}] [B,96,4H,4W]

 def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm):return nn.Sequential(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),Permute(0, 2, 3, 1),(norm_layer(embed_dim) if patch_norm else nn.Identity()), )

第一至四阶段

这几个阶段的差别在于每一层的VSSBlock数不同,由depths定义分别为 [2, 2, 9, 2],输出维度由dims定义分别为[96, 192, 384, 768]。其组成元素除一阶段外,均在VSSBlock前包含下采样模块以变换维度。

具体介绍见VSSBlock

分类器

池化后长宽变为1,则变量尺寸变为 [ B , C , 1 , 1 ] [B, C, 1, 1] [B,C,1,1],展平后变为 [ B , C ] [B, C] [B,C]最后线性投影到类别维度1000

[ B , 1000 ] [B, 1000] [B,1000]

self.classifier = nn.Sequential(OrderedDict(norm=norm_layer(self.num_features), # B,H,W,Cpermute=Permute(0, 3, 1, 2),avgpool=nn.AdaptiveAvgPool2d(1),flatten=nn.Flatten(1),head=nn.Linear(self.num_features, num_classes),))

VSSBlock

对于ssm分支来说,其输入输出维度不变为(B, H, W, d_model) ,对于mlp分支来说中间的隐藏维度根据mlp_ratio参数定义会有所增加,但是最后又会映射为原来的维度,因此整体上并不改变输入的维度。

def __ init__

主要分为两个分支ssm分支和mlp分支

ssm分支

主要组成部分是SS2D块
SS2D实现

if self.ssm_branch:self.norm = norm_layer(hidden_dim)self.op = _SS2D(d_model=hidden_dim, d_state=ssm_d_state, ssm_ratio=ssm_ratio,dt_rank=ssm_dt_rank,act_layer=ssm_act_layer,# ==========================d_conv=ssm_conv,conv_bias=ssm_conv_bias,# ==========================dropout=ssm_drop_rate,# =========================initialize=ssm_init,forward_type=forward_type,)      

图中的SS2D和SS2D类的定义有偏差,简单来说是是包含SS2D块加一个残差连接,图中所示SS2D应表示状态空间模型SSM部分,即VSS块相比SS2D块只增加了残差连接和入口的归一化。如果定义了MLP分支,VSS块的输出还会经过一个残差连接的两层MLP

在这里插入图片描述

mlp分支
 if self.mlp_branch:self.norm2 = norm_layer(hidden_dim)mlp_hidden_dim = int(hidden_dim * mlp_ratio)self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=False)class Mlp(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresLinear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linearself.fc1 = Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return x

def forward

    def _forward(self, input: torch.Tensor):if self.ssm_branch:if self.post_norm:x = input + self.drop_path(self.norm(self.op(input)))else:x = input + self.drop_path(self.op(self.norm(input)))if self.mlp_branch:if self.post_norm:x = x + self.drop_path(self.norm2(self.mlp(x))) # FFNelse:x = x + self.drop_path(self.mlp(self.norm2(x))) # FFNreturn xdef forward(self, input: torch.Tensor):if self.use_checkpoint:return checkpoint.checkpoint(self._forward, input)else:return self._forward(input)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/745334.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

css3常见选择器

使用工具 Visual Studio Code 1.CSS3基础选择器 1.1 标签选择器 1.2.1 标签选择器的语法 一个完整的HTML5页面是由很多不同的标签组成的,而标签选择器则决定标签应采用的CSS样式,语法如下:标签名{ 属性1:属性值1; 属性2&…

Vscode 修改C++版本

1. 首先要检查GCC版本,有的gcc版本过低会导致C版本升级不成功 可以用cmd,用gcc --version命令查看gcc版本 我这里就是gcc版本较低,不支持c17 需要先升级gcc版本 gcc与c对应的版本,大家可以在这位大佬的博客中看,写…

python 读取pdf 将每页转成jpg

需要安装fitz pip install PyMuPDF 这里我发现了问题,默认安装最新版本1.21.x 但是不支持大部分网上的api 所以分开两部分 1.21.x的 import fitz # PyMuPDF from PIL import Imagedef extract_images_from_tiff(tiff_path, output_folder):# 打开 TIFF 文件pdf_document f…

经典排序算法之计数排序|c++代码实现

引言 排序算法c实现系列第8弹——计数排序。 计数排序是理解起来相对简单的一个排序算法, 计数排序 计数排序(Counting Sort)是一种非比较型的排序算法,它的基本思想是统计待排序数组中每个元素的出现次数,然后根据…

基于grafana+elk等开源组件的 云服务监控大屏架构

本套大屏,在某云服务大规模测试环境,良好运行3年. 本文主要展示这套监控大屏的逻辑架构.不做具体操作与配置的解释. 监控主要分为三部分: 数据展示部分数据存储数据采集 1. 数据展示 数据展示方面主要使用grafana 2. 数据存储 根据数据种类和特性和用途的不同,本套监控采…

Intelli idea 自带maven路径和配置

自带maven位于:plugins/maven/lib/maven3 Mac配置maven环境变量: #maven export MAVEN_HOME/maven根路径 export PATH$MAVEN_HOME/bin:$PATH#刷新环境变量 source ~/.bash_profile#查看maven版本 mvn -version#查看依赖树 mvn dependency:tree 配置ma…

django-q轻量级定时任务制定

django-q ,celery,apschedule都可以作为python的选型,但是django-q更轻量级,可以定制想要的任务,通过消息中间件,来实现不太高并发的实现 官网介绍地址 django-q官网地址 本次测试的是python3.12版本 首先需要安装dja…

几何相互作用GNN预测3D-PLA

预测PLA是药物发现中的核心问题。最近的进展显示了将ML应用于PLA预测的巨大潜力。然而,它们大多忽略了复合物的3D结构和蛋白质与配体之间的物理相互作用,而这对于理解结合机制至关重要。作者提出了一种结合3D结构和物理相互作用的几何相互作用图神经网络GIGN,用于预测蛋白质…

架构实战--以海量存储系统讲解热门话题:分布式概念

关注我,持续分享逻辑思维&管理思维; 可提供大厂面试辅导、及定制化求职/在职/管理/架构辅导; 有意找工作的同学,请参考博主的原创:《面试官心得--面试前应该如何准备》,《面试官心得--面试时如何进行自…

Nodejs 第五十七章(addon)

Nodejs在IO方面拥有极强的能力,但是对CPU密集型任务,会有不足,为了填补这方面的缺点,Nodejs支持c/c为其编写原生nodejs插件,补充这方面的能力。 Nodejs c扩展 c编写的代码能够被编译成一个动态链接库(dll),可以被nod…

VMware workstation的安装

VMware workstation安装: 1.双击VMware-workstation-full-9.0.0-812388.exe 2.点击next进行安装 选择安装方式 Typical:典型安装 Custom:自定义安装 选择程序安装位置 点击change选择程序安装位置,然后点击next 选择是否自动…

vue 如何实现手机横屏功能

功能背景 有些需求需要手动实现横屏功能,比如点击一个横屏的图标将整个页面90度翻转,再点击退出横屏回到正常页面。 实现思路 一拿到这个需求很可能就被吓到了,但简单来说,就是点击横屏按钮跳转一个新页面,通过 cs…

MySQL模块---安装并配置

1. 在项目中操作数据库的步骤 ① 安装操作 MySQL 数据库的第三方模块(mysql) ② 通过 mysql 模块链接到 MySQL 数据库 ③ 通过 mysql 模块执行 SQL 语句 2. 安装 mysql 模块 这里要安装的是 mysql2 也就是 mysql 8.0后面的版本 npm init -y npm…

手动创建线程池各个参数的意义?

今天我们学习线程池各个参数的含义,并重点掌握线程池中线程是在什么时机被创建和销毁的。 线程池的参数 首先,我们来看下线程池中各个参数的含义,如表所示线程池主要有 6 个参数,其中第 3 个参数由 keepAliveTime 时间单位组成。…

【Linux】linuxCNC+Qt+Opencascade+kdl+hal 实时6轴机器人控制器

CNC机器人 程序框架 机器人模型 笔记: debian重启后 无法打开共享目录 最新版搜狗输入法安装后不支持中文,需要安装旧版本的 sogoupinyin_4.0.1.2800_x86_64.deb可用 数控机器人在哪些领域应用有优势 数控机器人在多个领域都展现出了显著的优势&#xff…

介绍一下redis中底层磁盘及IO模型,数据持久化机制,哨兵机制

底层磁盘及IO模型: Redis中的数据存储在内存中,但为了保证数据的持久化,Redis还提供了两种数据持久化方式:RDB(Redis DataBase)和AOF(Append-Only File)。 RDB:RDB是一种…

PyQt4应用程序的PDF查看器

最近因为项目需要创建一个基于PyQt4的PDF查看器应用程序,正常来说,我们可以使用PyQt4的QtWebKit模块来显示PDF文件。那么具体怎么实现呢 ?以下就是我写的一个简单的示例代码,演示如何创建一个PyQt4应用程序的PDF查看器&#xff1a…

SQL笔记 -- 黑马程序员

SQL目录 文章目录 SQL目录一、SQL分类1、DDL2、数据类型3、DML4、DQL1)基本查询2)条件查询3)聚合函数查询4)分组查询5)排序查询6)分页查询 5、DCL 一、SQL分类 分类说明DDL数据定义语言,用来定…

MySQL order by 语句执行流程

全字段排序 假设这个表的部分定义是这样的: CREATE TABLE t (id int(11) NOT NULL,city varchar(16) NOT NULL,name varchar(16) NOT NULL,age int(11) NOT NULL,addr varchar(128) DEFAULT NULL,PRIMARY KEY (id),KEY city (city) ) ENGINEInnoDB; 有如下 SQL 语…

蓝桥杯2023年-三国游戏(贪心)

题目描述 小蓝正在玩一款游戏。游戏中魏蜀吴三个国家各自拥有一定数量的士兵X, Y, Z (一开始可以认为都为 0 )。游戏有 n 个可能会发生的事件,每个事件之间相互独立且最多只会发生一次,当第 i 个事件发生时会分别让 X, Y, Z 增加Ai , Bi ,Ci 。 当游戏…