双向SSM: Vision Mamba Encoder

文章目录

  • Vision Mamba Encoder
    • 初始化
      • 输入映射
      • 序列变换
      • 参数映射
        • BC参数映射
        • delta参数映射
      • SSM参数初始化
        • A , D矩阵初始化
        • delta参数初始化
      • 双向SSM初始化
        • 参数初始化
    • 前向
      • 输入映射
      • fast_path
        • use_fast_path
        • no use_fast_path
    • 双向SSM
      • v1
        • 前向
        • 后向
      • v2
        • 前向
        • 后向

Vision Mamba Encoder

Vision Mamba的编码器部分,也位于Vim模型的中间和主要部分。由多个Mamba块堆叠而成,VisionMamba的Mamba块是在原始论文MambaBlock上修改,特别的地方在于其双向SSM机制。双向与数据流动方向无关,并不是指网络中存在反馈回路,而是等价的扫描方向有两种。

在这里插入图片描述

初始化

输入映射

首先,还是一个标准的输入映射,这一点没有更改,输入映射用来得到门控变量z和主干变量x,其中x的维度d_model扩充到2 * d_inner。

self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

序列变换

通过一个1D卷积进行序列变换。

        self.conv1d = nn.Conv1d(in_channels=self.d_inner,out_channels=self.d_inner,bias=conv_bias,kernel_size=d_conv,groups=self.d_inner,padding=d_conv - 1,**factory_kwargs,)self.activation = "silu"self.act = nn.SiLU()

参数映射

参数映射是一个简单的线性映射,为了得到输入依赖的矩阵参数B,C还有 Δ \Delta Δ参数

BC参数映射

d_state*2属于B,C参数,dt_rank属于delta参数的原始维度

 self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs)
delta参数映射

delta参数的给出:x->x_proj ->split -> dt_proj ->delta

输入x经过x_proj映射得到数据依赖的三个参数 B , C , Δ B, C,\Delta B,C,Δ,其中 Δ \Delta Δ 得到的维度是dt_rank,还需要进行一个(dt_rank, d_inner)的线性映射

self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

SSM参数初始化

在这里初始化非输入依赖的SSM参数包括A矩阵和D矩阵,还包括步长delta参数dt的初始化

A , D矩阵初始化
参数维度
A[d_state] -> [d_inner, d_state]
D[d_inner]
		A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),"n -> d n",d=self.d_inner,).contiguous()A_log = torch.log(A)  # Keep A_log in fp32self.A_log = nn.Parameter(A_log)self.A_log._no_weight_decay = True# D "skip" parameterself.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32self.D._no_weight_decay = True
delta参数初始化

d t = e α ( ⋅ l o g ( d t _ m a x ) − l o g ( d t _ m i n ) ) + l o g ( d t _ m i n ) dt = e^{\alpha (\cdot log(dt\_max) - log(dt\_min)) + log(dt\_min)} dt=eα(log(dt_max)log(dt_min))+log(dt_min)

其中 α \alpha α属于0到1的均匀分布,因此 d t dt dt的取值为 e l o g d t _ m i n e^{log{dt\_min}} elogdt_min e l o g d t _ m a x e^{log{dt\_max}} elogdt_max。即 d t _ m i n dt\_min dt_min d t _ m a x dt\_max dt_max

softplus函数为 S o f t p l u s ( x ) = 1 β ∗ l o g ( 1 + e x p ( β ∗ x ) ) Softplus(x) = \frac{1}{\beta} \ast log(1+exp(\beta \ast x)) Softplus(x)=β1log(1+exp(βx))

	 # Initialize special dt projection to preserve variance at initializationdt_init_std = self.dt_rank**-0.5 * dt_scaleif dt_init == "constant":nn.init.constant_(self.dt_proj.weight, dt_init_std)elif dt_init == "random":nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)else:raise NotImplementedError# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_maxdt = torch.exp(torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))+ math.log(dt_min)).clamp(min=dt_init_floor)# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759inv_dt = dt + torch.log(-torch.expm1(-dt))with torch.no_grad():self.dt_proj.bias.copy_(inv_dt)# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinitself.dt_proj.bias._no_reinit = True

双向SSM初始化

参数初始化

对于标准Mamba块来说,仅限于前向分支,而后向分支是不存在的,可以看到后向分支是前向分支的复制。在初始化阶段,双向SSM只是额外定义并初始化了一个A矩阵名为A_b。对于v1版本仅仅是多初始化一个矩阵A,而v2版本除此之外,还初始化了标准Mamba所需的全部参数,如D矩阵,参数映射。简单来说,v1版本的双向SSM除A矩阵以外,其他参数是公用的。

参数维度
A_b[d_state] -> [d_inner, d_state]
        # bidirectionalif bimamba_type == "v1":A_b = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),"n -> d n",d=self.d_inner,).contiguous()A_b_log = torch.log(A_b)  # Keep A_b_log in fp32self.A_b_log = nn.Parameter(A_b_log)self.A_b_log._no_weight_decay = Trueelif bimamba_type == "v2":A_b = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),"n -> d n",d=self.d_inner,).contiguous()A_b_log = torch.log(A_b)  # Keep A_b_log in fp32self.A_b_log = nn.Parameter(A_b_log)self.A_b_log._no_weight_decay = True self.conv1d_b = nn.Conv1d(in_channels=self.d_inner,out_channels=self.d_inner,bias=conv_bias,kernel_size=d_conv,groups=self.d_inner,padding=d_conv - 1,**factory_kwargs,)self.x_proj_b = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs)self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32self.D_b._no_weight_decay = True

前向

参数维度
输入x[b, l, d]
xz[b, 2 * d, l]
x_dbl[b,dt_rank + d_state * 2 ]
SSM参数shape来源
状态矩阵A(d_in, n)在初始化中定义,非数据依赖
输入矩阵B(b, l, n)由x_db1切分而来,因此数据依赖
输出矩阵C(b, l, n)由x_db1切分而来,因此数据依赖
直接传递矩阵D(d_in)在初始化中定义,非数据依赖
数据依赖步长 Δ \Delta Δ(b, l, d_in)由x_db1切分而来,因此数据依赖
维度约定说明
B / bbatch size
L / llength
D / dd_inner

输入映射

输入映射把输入x映射为两个分支xz,主分支x和门控分支z。

def forward(self, hidden_states, inference_params=None):"""hidden_states: (B, L, D)Returns: same shape as hidden_states"""batch, seqlen, dim = hidden_states.shapeconv_state, ssm_state = None, Noneif inference_params is not None:conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)if inference_params.seqlen_offset > 0:# The states are updated inplaceout, _, _ = self.step(hidden_states, conv_state, ssm_state)return out# We do matmul and transpose BLH -> HBL at the same timexz = rearrange(self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),"d (b l) -> b d l",l=seqlen,)if self.in_proj.bias is not None:xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

fast_path

在之后通过use_fast_path分为两种

use_fast_path

在这里通过bimamba类别又可分为三类,v1,v2和其它

bimamba_type == v1

在v1版本中,调用的函数是bimamba_inner_fn 在后面专门介绍。

       if self.use_fast_path and inference_params is None:  # Doesn't support outputting the statesif self.bimamba_type == "v1":A_b = -torch.exp(self.A_b_log.float())out = bimamba_inner_fn(xz,self.conv1d.weight,self.conv1d.bias,self.x_proj.weight,self.dt_proj.weight,self.out_proj.weight,self.out_proj.bias,A,A_b,None,  # input-dependent BNone,  # input-dependent Cself.D.float(),delta_bias=self.dt_proj.bias.float(),delta_softplus=True,)    

bimamba_type == v2

在v2版本中,调用的函数是mamba_inner_fn_no_out_proj在后面专门介绍。可以看到,在这里不同于v1,v2版本因为新增了一套SSM参数,因此也得到了额外的输出out_b,最后的输出也有两种模式,一是两者的简单平均,注意到因为反向SSM方向与正向方向相反,因此反向的输出要先翻转后再相加。而是直接翻转后相加。

			elif self.bimamba_type == "v2":A_b = -torch.exp(self.A_b_log.float())out = mamba_inner_fn_no_out_proj(xz,self.conv1d.weight,self.conv1d.bias,self.x_proj.weight,self.dt_proj.weight,A,None,  # input-dependent BNone,  # input-dependent Cself.D.float(),delta_bias=self.dt_proj.bias.float(),delta_softplus=True,)out_b = mamba_inner_fn_no_out_proj(xz.flip([-1]),self.conv1d_b.weight,self.conv1d_b.bias,self.x_proj_b.weight,self.dt_proj_b.weight,A_b,None,None,self.D_b.float(),delta_bias=self.dt_proj_b.bias.float(),delta_softplus=True,)if not self.if_devide_out:out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight,self.out_proj.bias)else:out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d") / 2, self.out_proj.weight, self.out_proj.bias)

其他

如果选择了双向模式,却没有定义模式,则使用Mamba默认的mamba_inner_fn

		else:out = mamba_inner_fn(xz,self.conv1d.weight,self.conv1d.bias,self.x_proj.weight,self.dt_proj.weight,self.out_proj.weight,self.out_proj.bias,A,None,  # input-dependent BNone,  # input-dependent Cself.D.float(),delta_bias=self.dt_proj.bias.float(),delta_softplus=True,)
no use_fast_path

和原始论文一致,如果不选择use_fast_path,则会在这里计算完整个流程,而不是定位到selective_scan_interface中定义的函数,而是计算出SSM参数后再调用selective_scan_interface中定义的selective_scan_fn(),SSM数据依赖的参数有参数映射x_proj得到x_db1,然后切分得到B, C,delta参数。

 else:x, z = xz.chunk(2, dim=1)if conv_state is not None:conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)if causal_conv1d_fn is None:x = self.act(self.conv1d(x)[..., :seqlen])else:assert self.activation in ["silu", "swish"]x = causal_conv1d_fn(x=x,weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),bias=self.conv1d.bias,activation=self.activation,)x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)dt = self.dt_proj.weight @ dt.t()dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()assert self.activation in ["silu", "swish"]y = selective_scan_fn(x,dt,A,B,C,self.D.float(),z=z,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=ssm_state is not None,)if ssm_state is not None:y, last_state = yssm_state.copy_(last_state)y = rearrange(y, "b d l -> b l d")out = self.out_proj(y)if self.init_layer_scale is not None:out = out * self.gamma    return out

双向SSM

v1

对于v1版本双向SSM在前向时首先定义到bimamba_inner_fn,然后调用BiMambaInnerFn

def bimamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,C_proj_bias=None, delta_softplus=True
):return BiMambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
前向

去掉和原始论文中的MambaInnerFn相同的部分,在forward前向过程中,不同在于定义了两个输出,分别为out_zf和out_zb,out_zf对应于原来的前向输出,out_zb则是新增的反向输出,最终的out_z是两者翻转相加。

        out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus)assert not A_b.is_complex(), "A should not be complex!!"out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus,)out_z = out_z_f + out_z_b.flip([-1])
后向

去掉和原始论文中的MambaInnerFn相同的部分,在backward后向过程中。对应的,定义复制了一套新参数,参数对应如下

原参数新增后向参数
dzdz_b
dconv1d_outdconv1d_out_f_b
ddeltaddelta_f_b
dAdA_b
dBdB_f_b
dCdC_f_b
dDdD_b
        dz_b = torch.empty_like(dz)dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd(conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b,ctx.delta_softplus,True  # option to recompute out_z)

根据这些新定义的参数,我们和前向参数相加来重定义原始的参数。我们得到新的dconv1d_out,ddelta等参数,最终保持与原始SSM一致

		dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1])ddelta = ddelta + ddelta_f_b.flip([-1])dB = dB + dB_f_b.flip([-1])dC = dC + dC_f_b.flip([-1])dD = dD + dD_bddelta_bias = ddelta_bias + ddelta_bias_bdz = dz + dz_b.flip([-1])out_z = out_z_f + out_z_b.flip([-1])

v2

对于v2版本双向Mamba在前向时首先定义到mamba_inner_fn_no_out_proj,然后调用MambaInnerFnNoOutProj。在v2版本,因为定义了两套SSM参数,因此双向的修改相比于v1要简单。

def mamba_inner_fn_no_out_proj(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,C_proj_bias=None, delta_softplus=True
):return MambaInnerFnNoOutProj.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
前向

mamba_inner_fn_no_out_proj 即相比于原始的mamba_inner_fn缺少了输出映射。

  return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
后向

相应的,在其中修改掉和out_proj_weight相关的部分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/750210.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构的概念大合集04(队列)

概念大合集04 1、队列1.1 队列的定义1.2队列的顺序存储1.2.1 顺序队1.2.2 顺序队的基本运算的基本思想1.2.3 顺序队的4要素的基本思想 1.3 环形队列1.3.1 环形队列的定义1.3.1 环形队列的实现 1.4 队列的链式存储1.4.1 链队1.4.2 链队的实现方式1.4.3 链队的4要素的基本思想 1.…

C语言之快速排序

目录 一 简介 二 代码实现 快速排序基本原理: C语言实现快速排序的核心函数: 三 时空复杂度 A.时间复杂度 B.空间复杂度 C.总结: 一 简介 快速排序是一种高效的、基于分治策略的比较排序算法,由英国计算机科学家C.A.R. H…

Arthas使用案例(二)

说明:记录一次使用Arthas排查测试环境正在运行的项目BUG; 场景 有一个定时任务,该定时任务是定时去拉取某FTP服务器上的文件,进行备份、读取、解析等一系列操作。 而现在,因为开发环境是Windows, 线上项…

FFmpeg 常用命令汇总

​​​​​​经常用到ffmpeg做一些视频数据的处理转换等,用来做测试,今天总结了一下,参考了网上部分朋友的经验,一起在这里汇总了一下。 1、ffmpeg使用语法 命令格式: ffmpeg -i [输入文件名] [参数选项] -f [格…

Spring整合RabbitMQ

需求&#xff1a;使用Spring整合RabbitMQ 步骤&#xff1a; 生产者 1.创建生产者工程 2.添加依赖 3.配置整合 4.编写代码发送消息 消费者步骤相同 生产者 导入依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://…

linux源配置:ubuntu、centos

1、ubuntu源配置 1&#xff09;先查电脑版本型号: lsb_release -c2&#xff09;再编辑源更新&#xff0c;源要与上面型号对应 参考&#xff1a;https://midoq.github.io/2022/05/30/Ubuntu20-04%E6%9B%B4%E6%8D%A2%E5%9B%BD%E5%86%85%E9%95%9C%E5%83%8F%E6%BA%90/ /etc/apt/…

大衍数列-蓝桥杯?-Lua 中文代码解题第2题

大衍数列-蓝桥杯&#xff1f;-Lua 中文代码解题第2题 中国古代文献中&#xff0c;曾记载过“大衍数列”, 主要用于解释中国传统文化中的太极衍生原理。 它的前几项是&#xff1a;0、2、4、8、12、18、24、32、40、50 … 其规律是&#xff1a;对偶数项&#xff0c;是序号平方再除…

HttpServer整合模块设计与实现(http模块五)

目录 类功能 类定义 类实现 编译测试 源码路标 类功能 类定义 // HttpServer模块功能设计 class HttpServer { private:using Handler std::function<void(const HttpRequest &, HttpResponse &)>;std::unordered_map<std::string, Handler> _get_r…

ISIS接口认证实验简述

默认情况下&#xff0c;ISIS接口认证通过在ISIS协议数据单元&#xff08;PDU&#xff09;中添加认证字段&#xff0c;例如&#xff1a;一个密钥或密码&#xff0c;用于验证发送方的身份。 ISIS接口认证防止未经授权的设备加入到网络中&#xff0c;并确保邻居之间的通信是可信的…

实战:django项目环境搭建(pycharm,virtualBox)

django项目环境搭建 一.创建虚拟环境二.创建PyCharm远程连接 一.创建虚拟环境 需要用到的软件&#xff1a;PyCharm&#xff0c;VirtualBox虚拟机。 1.打开虚拟机终端&#xff0c;创建新的虚拟环境 Book。 2.在虚拟环境中创建新的文件夹 library&#xff0c;cd命令进入该文件…

【四 (6)数据可视化之 Grafana安装、页面介绍、图表配置】

目录 文章导航一、Grafana介绍[✨ 特性]二、安装和配置1、安装2、权限配置&#xff08;账户/团队/用户&#xff09;①用户管理②团队管理③账户管理④看板权限 3、首选项配置4、插件管理①数据源插件②图表插件③应用插件④插件安装方式一⑤安装方式二 三、数据源管理1、添加数…

【STM32定时器 TIM小总结】

STM32 TIM详解 TIM介绍定时器类型基本定时器通用定时器高级定时器常用名词时序图预分频时序计数器时序图 定时器中断配置图定时器定时 代码调试 TIM介绍 定时器&#xff08;Timer&#xff09;是微控制器中的一个重要模块&#xff0c;用于生成定时和延时信号&#xff0c;以及处…

Vue3+TypeScript 学习回顾,温故而知新

文章简介&#xff1a; &#xff08;1&#xff09;简介&#xff1a; 在 Vue3 中编码规范如下&#xff1a; 编码语言: JavaScript代码风格: 组合式API选项式、API简写形式: setup语法糖 &#xff08;2&#xff09;复习内容&#xff1a; 1.核心: ref、reactive、computed、w…

路由器端口转发远程桌面控制:一电脑连接不同局域网的另一电脑

一、引言 路由器端口转发&#xff1a;指在路由器上设置一定的规则&#xff0c;将外部的数据包转发到内部指定的设备或应用程序。这通常需要对路由器进行一些配置&#xff0c;以允许外部网络访问内部网络中的特定服务和设备。端口转发功能可以实现多种应用场景&#xff0c;例如远…

游戏引擎中的动画基础

一、动画技术简介 视觉残留理论 - 影像在我们的视网膜上残留1/24s。 游戏中动画面临的挑战&#xff1a; 交互&#xff1a;游戏中的玩家动画需要和场景中的物体进行交互。实时&#xff1a;最慢需要在1/30秒内算完所有的场景渲染和动画数据。&#xff08;可以用动画压缩解决&am…

用SeaTunnel从SQL Server向Elasticsearch同步数据

文章目录 引言I 步骤1.1 环境准备1.2 配置JDBC插件1.3 编写SeaTunnel任务配置II Enable Sql Server CDC引言 SeaTunnel 的官网 https://seatunnel.apache.org/ Support SQL Server Version: server:2008 (Or later version for information only)Supported DataSource Info: …

抖去推无人直播+矩阵托管+AI文案撰写一体化工具如何开发搭建

一、 开发和搭建抖去推无人直播矩阵托管AI文案撰写一体化工具需要以下步骤&#xff1a; 确定功能需求&#xff1a;确定抖去推无人直播、矩阵托管和AI文案撰写的具体功能需求&#xff0c;如直播推流、直播管理、托管服务、AI文案生成等。 技术选型&#xff1a;选择适合开发该工…

CSS3技巧38:3D 翻转数字效果

博主其它CSS3 3D的文章&#xff1a; CSS3干货4&#xff1a;CSS中3D运用_css 3d-CSDN博客 CSS3干货5&#xff1a;CSS中3D运用-2_中3d-2-CSDN博客 CSS3干货6&#xff1a;CSS中3D运用-3_css3d 使用-CSDN博客 最近工作上烦心的事情太多&#xff0c;只有周末才能让我冷静一下 cod…

HTTPS(超文本传输安全协议)工作过程

一、简述HTTPS HTTPS超文本传输协议&#xff08;全称&#xff1a;Hypertext Transfer Protocol Secure &#xff09;&#xff0c;是以安全为目标的 HTTP 通道&#xff0c;在HTTP的基础上通过传输加密和身份认证保证了传输过程的安全性 。HTTPS 在HTTP 的基础下加入SSL&#x…

Linux第78步_使用原子整型操作来实现“互斥访问”共享资源

使用原子操作来实现“互斥访问”LED灯设备&#xff0c;目的是每次只允许一个应用程序使用LED灯。 1、创建MyAtomicLED目录 输入“cd /home/zgq/linux/Linux_Drivers/回车” 切换到“/home/zgq/linux/Linux_Drivers/”目录 输入“mkdir MyAtomicLED回车”&#xff0c;创建MyA…