源码解析:从零解读SAM(Segment Anything Model)大模型!

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学。

针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。

合集:

《大模型面试宝典》(2024版) 正式发布!

持续火爆!!!《AIGC 面试宝典》已圈粉无数!


SAM(Segment Anything Model),顾名思义,即为分割一切!该模型由Facebook的Meta AI实验室,能够根据文本指令或图像识别,实现对任意物体的识别与分割。它的诞生,无疑是CV领域的一次重要里程碑。

图片

论文地址:https://arxiv.org/abs/2304.02643
项目地址:https://github.com/facebookresearch/segment-anything

SAM Task

SAM借鉴了NLP领域的Prompt策略,通过给图像分割任务提供Prompt提示来完成任意目标的快速分割。Prompt类型可以是**「前景/背景点集、粗略的框或遮罩、任意形式的文本或者任何指示图像中需要进行分割」**的信息。如下图(a)所示,模型的输入是原始的图像和一些prompt,目标是输出"valid"的分割,所谓valid,就是当prompt的指向是模糊时,模型能够输出至少其中一个mask。

这样,可以是的SAM能够适配各种下游任务。例如,给定一个猫的边界框,SAM能够输出其mask,从而和实例分割任务搭配起来。

图片

SAM Model

如下图所示,SAM模型包含三个核心组件,Image Encoder、Prompt Encoder和Mask Decoder。图像经过Image Encoder编码,Prompt提示经过Prompt Encoder编码,两部分Embedding再经过一个轻量化的Mask Decoder得到融合后的特征。其中,Encoder部分使用的是已有模型,Decoder部分使用Transformer。

图片

Image Encoder

Image Encoder的作用是把图像映射到特征空间,整体过程如下图所示。

图片

正如论文中所讲,本质上这个Encoder可以是任何网络结构,在这里使用的是微调的Detectron的ViT,当然它也可以被改成传统的卷积结构,非常合理。

输入图像经过ViT结构的过程如下:

1. Patch Embedding

输入图像通过一个卷积base,将图像划分为16x16的patches,步长也为16,这样feature map的尺寸就缩小了16倍,同时channel从3映射到768。Patch Embedding示意图如下所示。

图片

代码实现:

'''
将输入的图像转换为序列化的特征向量
'''
class PatchEmbed(nn.Module):def __init__(self,# 卷积核大小# 这里是 (16, 16),意味着图像将被划分为16x16的patcheskernel_size: Tuple[int, int] = (16, 16),# 卷积的步长,与kernel_size相同,即(16, 16),# 意味着每一步移动16个像素,这样图像的尺寸就会减少到原来的1/16stride: Tuple[int, int] = (16, 16),# 控制边缘填充,这里设置为 (0, 0),意味着没有额外的填充padding: Tuple[int, int] = (0, 0),# 输入图像的通道数,通常为3(RGB图像)in_chans: int = 3,# 输出的特征维度,也就是每个patch被编码为的向量的长度,这里设置为768embed_dim: int = 768,) -> None:'''初始化这个子类实例的属性'''# PatchEmbed的子类,继承自nn.Module,用于构建神经网络模块super().__init__()self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)'''前向传播:接收输入张量 x,形状 (B, C, H, W),其中,- B表示批次大小- C 是输入通道数- H 和 W 是图像的高度和宽度'''def forward(self, x: torch.Tensor) -> torch.Tensor:# 卷积,将输入的通道数从 in_chans 转换为 embed_dimx = self.proj(x)# 将张量的维度顺序从 (B, C, H, W) 调整为 (B, H, W, C)x = x.permute(0, 2, 3, 1)return x

Patch Embedding过程在Vision Transformer结构图中对应下图所示。

图片

2. Positiona Embedding

经过Patch Embedding后输出tokens需要加入位置编码,以保留图像的空间信息。位置编码可以理解为一张map,map的行数与输入序列个数相同,每一行代表一个向量,向量的维度和输入序列tokens的维度相同,位置编码的操作是sum,所以维度依旧保持不变。

图片

图像尺寸是1024,因此patch的数量是1024/16=64。

代码实现:

# 在ImageEncoderViT的__init__定义
if use_abs_pos:# 使用预训练图像大小初始化绝对位置嵌入self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
# 在ImageEncoderViT的forward添加位置编码
if self.pos_embed is not None:x = x + self.pos_embed

Positiona Embedding过程在结构图中对应的部分:

图片

3. Transformer Encoder

feature map通过16个Transformer Block,其中12个Block使用了基于Window Partition(就是把特征图分成14*14的windows做局部的Attention)的注意力机制,以处理局部信息。另外4个Block是全局注意力模块,它们穿插在Window Partition模块之间,以捕捉图像的全局上下文。

# 在ImageEncoderViT的__init__定义
# -----Transformer Encoder-----
# 初始化一个ModuleList,用于存储Block实例
self.blocks = nn.ModuleList()
# 循环创建Block,depth是Transformer Encoder层数
for i in range(depth):# 创建单个Blockblock = Block(# 输入的通道数,即每个patch编码后的向量维度dim=embed_dim,# 自注意力机制中的注意力头数num_heads=num_heads,# MLP层的通道数相对于输入通道数的比例mlp_ratio=mlp_ratio,# 是否在QKV全连接层中使用偏置qkv_bias=qkv_bias,# 归一化层norm_layer=norm_layer,# 激活函数act_layer=act_layer,# 是否使用相对位置编码use_rel_pos=use_rel_pos,# 相对位置编码的初始化设置rel_pos_zero_init=rel_pos_zero_init,# 如果当前Block不是全局注意力层,则使用窗口大小,否则使用0window_size=window_size if i not in global_attn_indexes else 0,# 输入特征的尺寸,基于原始图像大小和patch大小计算得出input_size=(img_size // patch_size, img_size // patch_size),)# 将创建的Block对象添加到self.blocks列表中self.blocks.append(block)
# -----Transformer Encoder-----

Transformer Encoder过程在结构图中对应的部分:

图片

Encoder Block

如上图右所示,Encoder Block从低到高主要由LayerNorm 、Multi-Head Attention和MLP构成。

class Block(nn.Module):def __init__(self,dim: int,                           # 输入通道数num_heads: int,                     # attention中head的个数mlp_ratio: float = 4.0,             # MLP层的通道数相对于输入通道数的比例。qkv_bias: bool = True,              # 如果为True,QKV全连接层包含偏置。norm_layer: Type[nn.Module] = nn.LayerNorm,     # 归一化层act_layer: Type[nn.Module] = nn.GELU,           # 激活层use_rel_pos: bool = False,                      # 是否使用相对位置编码rel_pos_zero_init: bool = True,                 # 相对位置编码的初始化设置window_size: int = 0,                           # 注意力层的窗口大小input_size: Optional[Tuple[int, int]] = None,   # 输入特征的尺寸) -> None:super().__init__()self.norm1 = norm_layer(dim)         # 第一个归一化层,用于注意力层self.attn = Attention(               # Multi-Head Attentiondim,num_heads=num_heads,qkv_bias=qkv_bias,use_rel_pos=use_rel_pos,rel_pos_zero_init=rel_pos_zero_init,input_size=input_size if window_size == 0 else (window_size, window_size),)self.norm2 = norm_layer(dim)      #第二个归一化层,用于MLP之前# MLPself.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)self.window_size = window_size# 前向传播def forward(self, x: torch.Tensor) -> torch.Tensor:# 保存输入张量的副本shortcut = x# 对输入张量应用第一个归一化层x = self.norm1(x)# Window partition 对X进行paddingif self.window_size > 0:H, W = x.shape[1], x.shape[2]x, pad_hw = window_partition(x, self.window_size)# Multi-Head Attentionx = self.attn(x)# 如果 window_size > 0,使用window_unpartition去除窗口分区的padding,恢复原始尺寸if self.window_size > 0:x = window_unpartition(x, self.window_size, pad_hw, (H, W))# 将注意力层的输出与输入张量相加,实现残差连接x = shortcut + x# 对经过第二个归一化层的张量应用MLP层,再次使用残差连接x = x + self.mlp(self.norm2(x))# 返回最终的张量 xreturn x

Partition操作

在非全局注意力的Block中,为了适应14x14的窗口大小,输入特征图需要进行补边(padding)和拆分操作。具体流程如下:

  1. 输入特征图:输入特征图的初始尺寸为 1x64x64x768。

  2. 确定最小可整除尺寸:窗口大小为14*14,要找到能够被14整除的最小特征图尺寸。对于宽度和高度,我们需要找到大于等于64且能被14整除的最小数。这两个数分别是70(64+6)和70(64+6),所以最小可整除特征图的尺寸是 1x70x70x768。

  3. padding:为了将特征图尺寸从 64x64 扩展到 70x70,我们需要在右下角填充 6x6 的区域,因为70-64=6。这种padding方式确保了窗口可以在特征图的边缘正确地划分。

  4. 拆分特征图:将padding后的特征图1x70x70x768按照窗口大小14x14进行拆分。因为70/14=5,所以特征图可以被拆分为 5x5个14x14的窗口,总共5x5=25个窗口。每个窗口的尺寸为14x14x768。

如下图所示。

图片

# 将输入张量x分割成指定大小的窗口
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:# 获取输入张量形状# B表示批次大小,H和W表示高和宽,C表示通道数B, H, W, C = x.shape# 计算填充高度和宽度 pad_h 和 pad_w,以使得输入尺寸能被window_size整除# 避免在分割时产生非完整的窗口pad_h = (window_size - H % window_size) % window_sizepad_w = (window_size - W % window_size) % window_size# 如果需要填充,使用F.pad函数在宽度和高度方向上进行填充if pad_h > 0 or pad_w > 0:x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))# 更新填充后张量的高度和宽度 Hp 和 WpHp, Wp = H + pad_h, W + pad_w# 张量重塑为:B,Hp/S,S,Wp/S,S,C,这样可以将输入张量分割成多个窗口x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)# 调整张量的形状,使其由B,Hp/S,Wp/S,S,S,C-->B*Hp*Wp/(S*S),S,S,C# 这样每个窗口都在张量的连续部分windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)# 返回一个包含所有窗口的张量和原始张量的填充后尺寸 (Hp, Wp)return windows, (Hp, Wp)

「Unpartition操作」

在非全局注意力的Block中,将attention层输出的特征图1x70x70x768转化为1x64x64x768的特征图,实际上是通过切片操作x = x[:1, :64, :64, :],从1x70x70x768的特征图中取出左上角的1x64x64x768部分。

图片

# 用于将window_partition函数分割的窗口重新组合回原始尺寸的张量
def window_unpartition(# 获取输入张量 windows 的形状,以及窗口大小 window_sizewindows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:# 原始尺寸的填充高度和宽度Hp, Wp = pad_hw# 原始尺寸的无填充高度和宽度H, W = hw# 从窗口张量的总大小中计算出原始批量大小 BB = windows.shape[0] // (Hp * Wp // window_size // window_size)# 重塑窗口张量:B*Hp*Wp/(S*S),S,S,C-->B,Hp/S,Wp/S,S,S,Cx = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)# 再次重塑张量:B,Hp/S,Wp/S,S,S,C-->B,Hp,Wp,Cx = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)# 如果原始尺寸小于填充后的尺寸if Hp > H or Wp > W:# 通过切片 x[:, :H, :W, :] 去除填充部分,只保留原始大小的区域x = x[:, :H, :W, :].contiguous()# B,H,W,C# 返回合并后的张量,其形状为 (B,H,W,C),即原始的批量大小、高度、宽度和通道数return x

Encoder Block过程如下图所示:

图片

window_partition将输入特征的尺寸从(H, W)调整为(S, S)的窗口,其中S是窗口大小。这种调整是为了在多头注意力(Multi-Head Attention)中将相对位置嵌入添加到注意力图(attn)。然而,并非所有Transformer Block都需要在注意力图中嵌入相对位置信息。 window_unpartition 函数的作用是将经过注意力计算的窗口特征重新组合回原始尺寸(S×S–>H×W)。 Hp和Wp是S的整数倍

Multi-Head Attention

先来看Attention,结构如下图所示。

图片

Attention中q、k和v的作用:

图片

代码实现如下:

class Attention(nn.Module):"""Multi-head Attention block with relative position embeddings."""def __init__(self,dim: int,               # 输入通道数num_heads: int = 8,     # head数目qkv_bias: bool = True,  # 是否在QKV线性变换中使用偏置项,默认为Trueuse_rel_pos: bool = False, #是否使用相对位置编码,默认为Falserel_pos_zero_init: bool = True, #如果使用相对位置编码,是否以零初始化,默认为Trueinput_size: Optional[Tuple[int, int]] = None,       # 可选参数,用于指定相对位置编码的尺寸,只有在使用相对位置编码时才需要) -> None:super().__init__()self.num_heads = num_heads #输入head数目head_dim = dim // num_heads #每个head维度self.scale = head_dim**-0.5 #用于缩放注意力得分的因子,以避免数值溢出,取值为head_dim的平方根的倒数#一个全连接层(nn.Linear),将输入映射到Q、K、V的组合self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)#  一个全连接层,用于将注意力机制的输出投影回原始维度self.proj = nn.Linear(dim, dim)self.use_rel_pos = use_rel_posif self.use_rel_pos:        # 使用相对位置编码assert (input_size is not None), "Input size must be provided if using relative positional encoding."# 初始化水平方向(rel_pos_h)和垂直方向(rel_pos_w)的相对位置嵌入# 2S-1,Epos# 输入尺寸为(H, W),则水平方向的位置嵌入长度为2*H-1,垂直方向的位置嵌入长度为2*W-1# 每个位置嵌入的维度为head_dim# 这些位置嵌入以模型参数的形式定义(nn.Parameter),意味着它们会在训练过程中被学习和更新self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))def forward(self, x: torch.Tensor) -> torch.Tensor:# 输入张量x的形状为(B, H, W, C),其中B是批次大小,H和W是高度和宽度,C是通道数(即dim)B, H, W, _ = x.shape# 使用qkv层将x转换为Q、K、V的组合,然后通过重塑和重新排列来准备多头注意力计算# qkv with shape (3, B, nHead, H * W, C)qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)# q, k, v with shape (B * nHead, H * W, C)q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)# attn with shape (B * nHead, H * W,  H * W)# 计算注意力分数# q * self.scale: q是查询向量(query vectors),形状为(B * nHead, H * W, C),其中B是批次大小,nHead是注意力头的数量,H * W是序列的长度,C是每个位置的特征维度# self.scale是用于缩放注意力分数的因子,通常取head_dim的平方根的倒数,以防止数值过大# 乘以self.scale是为了稳定计算并防止梯度消失# k.transpose(-2, -1): k是键向量(key vectors),形状与q相同。transpose(-2, -1)是对k进行转置操作,即将最后一个和倒数第二个维度互换,目的是让q和k在计算点积时的维度匹配。转置后的k形状变为(B * nHead, C, H * W)# 将q和转置后的k进行矩阵乘法。计算每个查询位置q与所有键位置k的点积,生成一个形状为(B * nHead, H * W, H * W)的注意力分数矩阵attn。每个位置i和j的注意力分数表示q_i与k_j的相似度attn = (q * self.scale) @ k.transpose(-2, -1)# 如果启用了相对位置编码if self.use_rel_pos:# (H, W)代表输入序列的尺寸,这里假设H和W是相等的(S×S),即输入是一个正方形网格(例如,图像的像素网格)# attn: 上述计算得到的注意力分数矩阵,形状为(B * nHead, H * W, H * W)# q: 查询向量,形状为(B * nHead, H * W, C)# self.rel_pos_h和self.rel_pos_w: 分别表示水平和垂直方向上的相对位置嵌入,形状分别为(2 * S - 1, head_dim)# (H, W): 输入序列的尺寸,用于指导相对位置嵌入的计算attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))# 生成的注意力分数矩阵attn随后会经过Softmax函数,将每个位置的分数归一化到[0, 1]区间,形成一个概率分布attn = attn.softmax(dim=-1)# 加权求和: # 使用attn @ v计算加权和,其中@表示矩阵乘法,v是值向量(value vectors),形状为(B * nHead, H * W, C)# 注意力权重矩阵attn(形状为(B * nHead, H * W, H * W))与v按元素相乘后,再进行矩阵乘法,得到加权后的值向量,形状为(B * nHead, H * W, C)# 使用.view()将加权后的值向量重塑为(B, self.num_heads, H, W, -1),然后使用.permute(0, 2, 3, 1, 4)进行重排,将self.num_heads移动到第四个维度。最后,使用.reshape(B, H, W, -1)将结果进一步重塑为(B, H, W, -1),与输入张量的形状一致,但保留了多头注意力的输出x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)# 使用self.proj(一个全连接层,形状为(dim, dim))对上述处理后的张量进行线性投影,以将其投影回原始的特征维度x = self.proj(x)# 最终,返回经过线性投影的张量x作为注意力模块的输出return x

在多头注意力(Multi-Head Attention)模块中,输入特征F(N×E)表示一个序列,其中N是序列中的元素数量,E是每个元素的特征维度。具体流程如下。

  1. 首先将每个token的qkv特征维度embed_dim均拆分到每个head上。

图片

  1. 每个head分别通过q和k计算得到权重w,权重w和v得到输出output,合并所有head的output得到最终的output。

图片

get_rel_pos用于计算查询(query)和键(key)之间在二维空间中的相对位置编码,如下图所示。

图片

实现代码:

def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:# 表示查询(query)和键(key)在二维空间中的最大相对距离# max(q_size, k_size):取查询的宽度q_size和键的宽度k_size中的较大值# 如果q_size和k_size都为S,则最大的正向距离是S-1,最大的负向距离也是S-1,所以总的最大距离是2 * S# - 1:减去1是因为在计算相对位置时,0被包含在内,所以最大距离是2 * S - 1max_rel_dist = int(2 * max(q_size, k_size) - 1)# 如果rel_pos的形状的第0个维度(即长度)不等于max_rel_dist,说明需要进行插值if rel_pos.shape[0] != max_rel_dist:# 使用F.interpolate进行线性插值rel_pos_resized = F.interpolate(# 1,N,Ep --> 1,Ep,N --> 1,Ep,2S-1# 将rel_pos重塑为(1, N, Ep),其中N是原始的长度,Ep是每个位置编码的特征维度# 通过permute(0, 2, 1)进行转置,使其形状变为(1, Ep, N)rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),# 设置插值的目标长度为max_rel_distsize=max_rel_dist,# 指定插值方法为线性插值mode="linear",)# Ep,2S-1 --> 2S-1,Ep# 插值后的rel_pos形状为(1, Ep, max_rel_dist),通过reshape(-1, max_rel_dist)将其重塑为(Ep, max_rel_dist)# 再通过permute(1, 0)转置为(max_rel_dist, Ep)rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)else:# 如果rel_pos的长度与max_rel_dist相等,说明已经足够覆盖所有可能的相对位置,因此直接使用rel_pos,不进行任何处理rel_pos_resized = rel_pos# 如果q和k长度值不同,则用短边长度缩放坐标# 创建查询坐标q_coords# torch.arange(q_size)生成一个从0到q_size - 1的整数序列,表示q_size个位置# [:, None]在序列末尾添加一个维度,使其形状为(q_size, 1),这样可以方便与一个标量进行逐元素乘法# max(k_size / q_size, 1.0)计算比例因子,如果k_size大于q_size,则使用k_size / q_size,否则使用1.0# 这确保了在q_size小于k_size的情况下,q_coords的坐标会被适当放大,以匹配k_coords的尺度q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)# 创建键坐标k_coordsk_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)# S,S# 计算了查询(query)和键(key)在二维空间中的相对坐标relative_coords# (q_coords - k_coords):每个查询位置相对于每个键位置的水平距离# (k_size - 1) * max(q_size / k_size, 1.0):计算了一个偏移量,用于确保相对坐标在正确的范围内# (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0):将计算出的差值和偏移量相加,得到最终的相对坐标relative_coordsrelative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)# tensor索引是tensor时,即tensor1[tensor2]# 假设tensor2某个具体位置值是2,则tensor1[2]位置的tensor1切片替换tensor2中的2# tensor1->shape 5,5,3 tensor2->shape 2,2,3 tensor1切片->shape 5,3 tensor1[tensor2]->shape 2,2,3,5,3# tensor1->shape 5,5 tensor2->shape 3,2,3 tensor1切片->shape 5 tensor1[tensor2]->shape 3,2,3,5# 2S-1,Ep-->S,S,Epreturn rel_pos_resized[relative_coords.long()]

add_decomposed_rel_pos为atten注意力特征添加相对位置的嵌入特征,如下图所示。

图片

def add_decomposed_rel_pos(# 注意力分数矩阵attn: torch.Tensor,q: torch.Tensor,rel_pos_h: torch.Tensor,rel_pos_w: torch.Tensor,q_size: Tuple[int, int],k_size: Tuple[int, int],
) -> torch.Tensor:# S,Sq_h, q_w = q_sizek_h, k_w = k_size# rel_pos_h -> 2S-1×Epos# 查询(query)和键(key)在高度方向上的相对位置编码Rh = get_rel_pos(q_h, k_h, rel_pos_h)# 查询(query)和键(key)在宽度方向上的相对位置编码Rw = get_rel_pos(q_w, k_w, rel_pos_w)# 重塑q为(B, q_h, q_w, dim)B, _, dim = q.shaper_q = q.reshape(B, q_h, q_w, dim)# 计算相对位置加权# 计算rel_h和rel_w,这两个张量表示在每个位置上,查询与相对位置编码的加权和# B,q_h,q_w,k_hrel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)# B,q_h, q_w, k_wrel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)# 合并注意力分数和相对位置编码# 将attn重塑为(B, q_h, q_w, k_h, k_w),然后与rel_h和rel_w按元素相加# 将attn重塑为(B, q_h, q_w, k_h, k_w),然后与rel_h和rel_w按元素相加attn = (# B,q_h, q_w, k_h, k_wattn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)return attn

Multi-Head Attention模块为注意力特征嵌入了相对位置特征(add_decomposed_rel_pos):

图片

Neck Convolution

最后,通过两层卷积(Neck)将通道数降低至256,生成最终的Image Embedding。其结构图如下所示。

图片

代码实现如下:

# neck: nn.Sequential,它包含两个卷积层和两个LayerNorm2d)
self.neck = nn.Sequential(# 1x1的卷积层,用于将输入通道数从embed_dim减小到out_chans# 1x1卷积主要用于通道间的信息融合,而不改变特征图的空间尺寸nn.Conv2d(embed_dim,out_chans,kernel_size=1,# 不使用偏置项bias=False,),# 归一化层,用于规范化输出通道的均值和方差,提高模型的稳定性和收敛速度# out_chans:归一化层的通道数LayerNorm2d(out_chans),# 3x3的卷积层nn.Conv2d(# 使用out_chans作为输入和输出通道数out_chans,out_chans,kernel_size=3,# 输入和输出的特征图尺寸保持不变,避免尺寸收缩padding=1,# 不使用偏置bias=False,),# 第二个归一化层,再次对输出进行规范化LayerNorm2d(out_chans),
)
# 归一化
class LayerNorm2d(nn.Module):def __init__(self, num_channels: int, eps: float = 1e-6) -> None:super().__init__()# 创建了两个可学习的参数:weight和bias# weight初始化为全1,bias初始化为全0self.weight = nn.Parameter(torch.ones(num_channels))self.bias = nn.Parameter(torch.zeros(num_channels))self.eps = epsdef forward(self, x: torch.Tensor) -> torch.Tensor:# 沿着通道维度求均值,keepdim=True保留维度,使得u的形状与x相同,除了通道维度的大小为1u = x.mean(1, keepdim=True)                 # dim=1维度求均值并保留通道# 计算标准化因子 s,即减去均值后的平方差的平均值,也保留通道维度s = (x - u).pow(2).mean(1, keepdim=True)# 归一化,将每个像素的值减去均值 u,然后除以标准差的平方根加上一个小的常数 eps 以保证数值稳定性x = (x - u) / torch.sqrt(s + self.eps)# 应用可学习的权重和偏置x = self.weight[:, None, None] * x + self.bias[:, None, None]return x

Prompt Encoder

SAM模型中Prompt Encoder网络结构如下图所示。主要包括三步骤:

  • Embed_Points:标记点编码(标记点由点转变为向量)

  • Embed_Boxes:标记框编码(标记框由点转变为向量)

  • Embed_Masks:mask编码(mask下采样保证与Image Encoder输出一致)

图片

Embed_Points

Embed_Points结构如下图所示。

图片

标记点预处理,将channel由2变为embed_dim(MatMul:forward_with_coords),然后再加上位置编码权重。其中,

  • 2:坐标(h,w)

  • embed_dim:提示编码的channel

「代码实现:」

# 将输入的点坐标和对应的标签转化为高维的嵌入表示,以便于后续的模型处理
def _embed_points(self,points: torch.Tensor,labels: torch.Tensor,pad: bool,
) -> torch.Tensor:# 将输入的点坐标points的每个坐标值增加0.5,以将坐标从像素的左上角移动到像素中心points = points + 0.5# points和boxes联合则不需要padif pad:# 在点坐标 points 和标签 labels 中添加一个填充项# 以保持批次处理的一致性,即使某些样本的点数量少于最大数量。# 填充的点坐标为(0,0),标签为-1padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)  # B,1,2padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)     # B,1points = torch.cat([points, padding_point], dim=1)                          # B,N+1,2labels = torch.cat([labels, padding_label], dim=1)                          # B,N+1# 根据调整后的点坐标和输入图像的尺寸生成位置编码# 生成的嵌入维度:B,N+1,2f# 2f 表示每个点位置编码的维度,是通过某种函数(如正弦或余弦函数)从原始的2D坐标扩展而来point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)  # 根据标签 labels 的值,对每个点的嵌入进行调整。# labels为-1是非标记点,设为非标记点权重point_embedding[labels == -1] = 0.0point_embedding[labels == -1] += self.not_a_point_embed.weight# labels为0是背景点,加上背景点权重point_embedding[labels == 0] += self.point_embeddings[0].weight# labels为1是目标点,加上目标点权重point_embedding[labels == 1] += self.point_embeddings[1].weightreturn point_embedding
Embed_Boxes

Embed_Boxes结构如下图所示。

在这里插入图片描述

标记框(Bounding Box)一般有两个点,编码步骤如下:

  1. 将输入的边界框坐标张量boxes从BxNx4转换为BxNx2x2;

  2. 再使用point embedding编码的方式,得到corner_embedding;

  3. 加上之前生成的可学习的embeding向量。

最后输出的corner_embedding大小为Nx2x256。

「代码实现:」

# 将输入的边界框(boxes)转换为高维的嵌入表示
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:# 将坐标从像素的左上角移动到像素中心boxes = boxes + 0.5# 将输入的边界框坐标张量boxes从BxN*4转换为B*Nx2x2# 其中B是批次大小,N是每个样本中的边界框数量coords = boxes.reshape(-1, 2, 2)# 对每个边界框的角点坐标进行位置编码corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)    ## 分别对每个边界框的起始点和末尾点的嵌入向量加上特定的权重corner_embedding[:, 0, :] += self.point_embeddings[2].weightcorner_embedding[:, 1, :] += self.point_embeddings[3].weight# 返回加权后嵌入向量,形状为 B*Nx2xembed_dim,其中 embed_dim 是位置编码的维度return corner_embedding
Embed_Mask

mask提示允许我们直接在原图上指示感兴趣区域来引导模型。这些mask通过卷积操作被转换为与图像嵌入空间相匹配的特征,然后与图像嵌入相加结合,为模型提供分割的精确位置信息。

如果没有使用mask提示,则将一组可学习向量(no_mask_embed,1*256)expand为1x256×64×64后替代,使得在处理序列数据时,即使没有具体的mask信息,也能有一个统一的处理方式。

图片

# 在PromptEncoder的forward定义
'''
首先获取no_mask_embed权重矩阵,并将其重塑成一个形状为(1, num_embeddings, 1, 1)的四维张量。再利用.expand方法将这个张量扩展到与图像编码相同的尺寸。bs是batch大小,-1是一个占位符,它会自动计算出
num_embeddings的值以保持张量的元素总数不变。self.image_embedding_size[0]和self.image_embedding_size[1]分别表示图像编码的宽度和高度。
'''
self.no_mask_embed = nn.Embedding(1, embed_dim)      # embed_dim=256
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]))

如果有配置mask,Embed_Masks结构如下图所示。

在这里插入图片描述

已知输入mask是Nx1x256x256,经过3层卷积,最后得到与Image Embedding一样的size:

首先,mask进入一个1x2x2x4的卷积,stride=2;LN;再进入一个4x2x2x16的卷积,stride=2;LN;最后再进入一个16x1x1x256的卷积;得到最后的mask_embedding的size为Nx256x64x64,最终mask_embedding作为dense_embedding输出,大小为Nx256x64x64。

mask的输出尺寸是Image Encoder模块输出的图像编码尺寸的4倍,因此为了保持一致,需要4倍下采样。

「代码实现」

# 将输入的掩模(mask)张量转换为一个低分辨率的嵌入表示
# 掩模 masks 是一个形状为 BxCxHxW 的张量
# 其中 B 是批次大小,C 是通道数(通常为1,因为掩模通常只有一通道),H 和 W 分别是高度和宽度。
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:# mask下采样4倍mask_embedding = self.mask_downscaling(masks)# 返回下采样并转换后的掩模嵌入,其形状为 B*embed_dim*H'*W',其中 H' 和 W' 是下采样后的高度和宽度return mask_embedding# mask_downscaling包括多个卷积层、层归一化(LayerNorm2d)和激活函数,目的是减少掩模的空间维度,同时增加通道维度
self.mask_downscaling = nn.Sequential(# 将通道数从1减少到mask_in_chans//4,同时使用2x2的卷积核和步长2进行下采样,降低了空间分辨率nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),# 规范化通道维度上的特征LayerNorm2d(mask_in_chans // 4),# 激活函数,引入非线性activation(),# 将通道数恢复到 mask_in_chans,再次使用2x2的卷积核和步长2进行下采样,进一步降低空间分辨率nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),# LayerNorm2d 层和激活函数LayerNorm2d(mask_in_chans),activation(),# 将通道数增加到 embed_dim,通常是为了与模型的其他部分保持一致nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),)

「PositionEmbeddingRandom」

用于将标记点和标记框的坐标进行提示编码预处理。就是将64x64个坐标点归一化后,与随机高斯矩阵相乘(2x128),再将结果分别进行sin和cos,最后再拼到一起,输出的大小为256x64x64,与image_embedding大小基本一致了。

class PositionEmbeddingRandom(nn.Module):"""Positional encoding using random spatial frequencies."""def init(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:super().init()if scale is None or scale <= 0.0:scale = 1.0# 构建一个2x128的随机矩阵作为位置编码高斯矩阵self.register_buffer("positional_encoding_gaussian_matrix",scale * torch.randn((2, num_pos_feats)),)def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:"""Positionally encode points that are normalized to [0,1]."""# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shapecoords = 2 * coords - 1# 矩阵乘法:64x64xx2 @ 2x128 ---> 64x64x128coords = coords @ self.positional_encoding_gaussian_matrixcoords = 2 * np.pi * coords# outputs d_1 x ... x d_n x C shape# cat, 最后一个维度上拼接:64x64x256return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)def forward(self, size: Tuple[int, int]) -> torch.Tensor:"""Generate positional encoding for a grid of the specified size."""h, w = sizedevice: Any = self.positional_encoding_gaussian_matrix.device# 构造一个64x64的全1矩阵grid = torch.ones((h, w), device=device, dtype=torch.float32)# 行、列累加y_embed = grid.cumsum(dim=0) - 0.5x_embed = grid.cumsum(dim=1) - 0.5# 行列累加结果归一化y_embed = y_embed / hx_embed = x_embed / w# 行列拼接:64x64x2,编码后的结果是64x64x256pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))# 最后输出256x64x64return pe.permute(2, 0, 1)  # C x H x W

Mask Decoder

Mask Decoder网络结构参数配置如下。

def __init__(self,*,# transformer通道数transformer_dim: int,# 用于预测mask的Transformer网络模块transformer: nn.Module,# 消除掩码歧义预测的掩码数量,默认为3num_multimask_outputs: int = 3,# 激活函数,默认为GELUactivation: Type[nn.Module] = nn.GELU,# MLP用于预测掩模质量的深度iou_head_depth: int = 3,# MLP的隐藏层通道数iou_head_hidden_dim: int = 256,
) -> None:super().__init__()self.transformer_dim = transformer_dim #存储传入的transformer_dim# 存储传入的transformer模块self.transformer = transformer# 存储掩码预测的输出数量self.num_multimask_outputs = num_multimask_outputs# 用于表示IoU(Intersection over Union)的嵌入层,大小为1×transformer_dim# 可学习的iou tokens:1x256self.iou_token = nn.Embedding(1, transformer_dim)# 包含IoU token在内的总mask token数量# # num_mask_tokens = 3 + 1 = 4, transformer_dim = 256# 输出一个4x256的矩阵self.num_mask_tokens = num_multimask_outputs + 1# 存储所有mask token的嵌入层,大小为num_mask_tokens×transformer_dimself.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)#----- upscaled -----# 用于4倍上采样的序列,包含两个转置卷积层,每个上采样2倍,中间夹着LayerNorm和激活函数self.output_upscaling = nn.Sequential(nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),     #转置卷积 上采样2倍LayerNorm2d(transformer_dim // 4),activation(),nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),activation(),)# ----- upscaled -----# 多层感知机(MLP)模块#  一个模块列表,包含了num_mask_tokens个MLP,每个MLP用于处理不同mask的输出self.output_hypernetworks_mlps = nn.ModuleList([MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)for i in range(self.num_mask_tokens)])# ----- MLP -----# ----- MLP -----# 一个MLP,用于预测IoU,输入是transformer_dim,经过iou_head_hidden_dim的隐藏层,输出是num_mask_tokensself.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)# ----- MLP -----

SAM模型Mask Decoder网络结构如下图所示。

在这里插入图片描述

  • spa_pro_emb(sparse embedding)、iou_token、mask_token合并成一个tokens,作为point_embeddings。

  • spa_pro_emb: point、bbox prompt合并后的产物,一般为NxXx256。

  • iou_token:可学习参数,大小为1x256。

  • mask_token:可学习参数,大小为4x256。

原论文中Mask Decoder模块各部分结构示意图如下。

在这里插入图片描述

Mask Decoder网络在特征提取中的基本步骤如下:

  1. transformer:将来自编码器的图像特征与额外的提示信息(如掩码提示或查询向量)融合,以捕捉目标区域的上下文信息。

  2. upscaled:对粗略mask src进行上采样,使其与原始图像尺寸相匹配,以便进行更精细的mask预测。

  3. mask_MLP:通过一系列全连接层,对上采样后的特征进行变换,计算出针对每个像素的mask概率。这些层可以设计为学习如何为每个mask通道分配权重,从而生成最终的mask输出。

  4. iou_MLP:评估生成的mask与真实mask之间的重叠程度,即预测mask的质量。

def forward(self,# image encoder 图像特征image_embeddings: torch.Tensor,# 位置编码# 256x64x64image_pe: torch.Tensor,# 标记点和标记框的嵌入编码sparse_prompt_embeddings: torch.Tensor,# 输入mask的嵌入编码dense_prompt_embeddings: torch.Tensor,# 是否输出多个maskmultimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:# 将这些特征融合,通过Transformer和后续的上采样及MLP层,生成掩膜预测和IoU分数masks, iou_pred = self.predict_masks(image_embeddings=image_embeddings,image_pe=image_pe,sparse_prompt_embeddings=sparse_prompt_embeddings,dense_prompt_embeddings=dense_prompt_embeddings,)# 如果multimask_output为True,表示需要输出多个掩模,选取索引为1到num_multimask_outputs的所有掩模if multimask_output:mask_slice = slice(1, None)# 否则,如果multimask_output为False,仅输出第一个掩模(通常是最高得分的掩模)else:mask_slice = slice(0, 1)# 根据multimask_output选择后的掩模,维度调整为(batch_size, num_selected_masks, height, width)masks = masks[:, mask_slice, :, :]# 根据multimask_output选择后的IoU预测,维度调整为(batch_size, num_selected_masks)iou_pred = iou_pred[:, mask_slice]return masks, iou_pred
def predict_masks(self,# image embedding: 是image encoder的输出,大小为为1x256x64x64image_embeddings: torch.Tensor,# image_pe位置编码也拓展成Nx256x64x64的矩阵image_pe: torch.Tensor,sparse_prompt_embeddings: torch.Tensor,dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:# 首先将iou token和mask token 拼接得到一个5x256的矩阵,再将其拓展到与sparse embedding一个维度Nx5x256# 1,E and 4,E --> 5,Eoutput_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)# 再将拓展后的矩阵与sparse embedding拼接得到tokens,其大小Nx(5+X)x256# 5,E --> B,5,Eoutput_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)# 再与稀疏矩阵拼接,假设稀疏矩阵只有point为Nx2x256,拼接之后则为Nx(5+2)x256# B,5,E and B,N,E -->B,5+N,E       N是点的个数(标记点和标记框的点)tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)# 将image embedding(1x256x64x64)拓展成稠密prompt的维度:Nx256x64x64# B,C,H,Wsrc = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)#将拓展后的image embedding直接与稠密prompt相加:Nx256x64x64# B,C,H,W + 1,C,H,W ---> B,C,H,Wsrc = src + dense_prompt_embeddings# # 将256x64x64的位置编码,拓展成Nx256x64x64# 1,C,H,W---> B,C,H,Wpos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)b, c, h, w = src.shape# ----- transformer -----# Run the transformer:这里使用的TwoWayTransformer,有必要对输入再说明一下# src:image_bedding + dense_prompt(mask),Nx256x64x64# pos_src: 位置编码,Nx256x64x64# tokens: iou_tokens + mask_tokens + sparse_prompt(point/bbox),Nx(5+x)x256# B,N,Chs, src = self.transformer(src, pos_src, tokens)# ----- transformer -----# # 后处理iou_token_out = hs[:, 0, :]mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :]# 通过上采样层将Transformer输出的掩模部分恢复到(batch_size, channels, height, width)的形状# B,N,C-->B,C,H,Wsrc = src.transpose(1, 2).view(b, c, h, w)# ----- upscaled -----# 4倍上采样upscaled_embedding = self.output_upscaling(src)# ----- upscaled -----# 对每个mask token,通过其对应的MLP得到一个权重张量,使用这些权重与上采样后的特征张量进行点乘,得到掩模预测(batch_size, num_mask_tokens, height, width)hyper_in_list: List[torch.Tensor] = []# ----- mlp -----for i in range(self.num_mask_tokens):# mask_tokens_out[:, i, :]: B,1,C# output_hypernetworks_mlps: B,1,chyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))# B,n,chyper_in = torch.stack(hyper_in_list, dim=1)# ----- mlp -----b, c, h, w = upscaled_embedding.shape# B,n,c × B,c,N-->B,n,h,wmasks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)# ----- mlp -----# 通过IoU预测头(MLP)对IoU token的输出进行处理,得到(batch_size, num_mask_tokens)的IoU分数# iou_token_out: B,1,niou_pred = self.iou_prediction_head(iou_token_out)# ----- mlp -----# 返回预测的掩模和IoU分数# masks: B,n,h,w# iou_pred: B,1,nreturn masks, iou_pred
1. transformer

Mask Decoder由多个重复堆叠TwoWayAttention Block和1个Multi-Head Attention组成。

图片

「TwoWayAttention Block」

TwoWayAttention Block由LayerNorm 、Multi-Head Attention和MLP构成。所谓的TwoWay:即是两轮次循环,第一次point_embedding自注意,第二次则加上上一轮输出的queries进行attention。

图片

原论文中TwoWayAttention部分示意图。

图片

class TwoWayAttentionBlock(nn.Module):def __init__(self,embedding_dim: int,         # 输入特征维度num_heads: int,             # 注意力头的数量,决定了注意力机制的并行度mlp_dim: int = 2048,        # MLP(多层感知机)中间层的维度,用于特征变换和非线性增强activation: Type[nn.Module] = nn.ReLU,      # 激活函数类型,默认为ReLUattention_downsample_rate: int = 2,         # 下采样比率# 是否在第一层自注意力中跳过位置编码的残差连接skip_first_layer_pe: bool = False,) -> None:super().__init__()# 自注意力模块,用于增强queries内部的信息交互self.self_attn = Attention(embedding_dim, num_heads)# norm1/2/3/4: LayerNorm层,用于稳定训练和加速收敛self.norm1 = nn.LayerNorm(embedding_dim)# cross_attn_token_to_image和cross_attn_image_to_token: 交叉注意力模块,分别让标记点特征关注图像特征,以及图像特征反过来关注标记点特征self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)self.norm2 = nn.LayerNorm(embedding_dim)# mlp: 多层感知机模块,增加模型的表达能力self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)self.norm3 = nn.LayerNorm(embedding_dim)self.norm4 = nn.LayerNorm(embedding_dim)self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)self.skip_first_layer_pe = skip_first_layer_pe# 前向传播def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:# queries:标记点编码相关(原始标记点编码经过一系列特征提取)# keys:原始图像编码相关(原始图像编码经过一系列特征提取)# query_pe:原始标记点编码# key_pe:原始图像位置编码# 第一轮本身queries==query_pe没比较再"残差"# 首先对queries应用自注意力,若skip_first_layer_pe=True,直接使用queries进行自注意力计算;否则,将queries与query_pe相加后进行自注意力计算,并残差连接回queries,之后进行LayerNormif self.skip_first_layer_pe:queries = self.self_attn(q=queries, k=queries, v=queries)else:q = queries + query_peattn_out = self.self_attn(q=q, k=q, v=queries)queries = queries + attn_outqueries = self.norm1(queries)# 调整queries和keys(图像特征)加上各自的位置编码,然后通过cross_attn_token_to_image交叉注意力层,使标记点特征关注图像特征,结果与原始queries残差连接并进行LayerNormq = queries + query_pek = keys + key_peattn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)queries = queries + attn_outqueries = self.norm2(queries)# MLP block:将更新后的queries通过MLP模块进行非线性变换,结果与原queries残差连接并进行LayerNormmlp_out = self.mlp(queries)queries = queries + mlp_outqueries = self.norm3(queries)# 交叉注意力(图像到标记点):再次调整queries和keys加上位置编码,但这次通过cross_attn_image_to_token让图像特征关注标记点特征,更新后的keys与原始keys残差连接并进行LayerNormq = queries + query_pek = keys + key_peattn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)keys = keys + attn_outkeys = self.norm4(keys)return queries, keys

「Attention」

Mask Decoder的Attention与ViT的Attention有些细微的不同:

  • Mask Decoder的Attention是3个FC层分别接受3个输入获得q、k和v。

  • ViT的Attention是1个FC层接受1个输入后将结果均拆分获得q、k和v。

如下图所示。

图片

原论文中Attention部分示意图。

图片

class Attention(nn.Module):def __init__(self,embedding_dim: int,         # 输入特征的维度num_heads: int,             # attention的head数downsample_rate: int = 1,   # 下采样) -> None:super().__init__()self.embedding_dim = embedding_dim# 内部维度self.internal_dim = embedding_dim // downsample_rateself.num_heads = num_headsassert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."# 四个线性层(全连接层):用于生成query向量、key向量、value向量self.q_proj = nn.Linear(embedding_dim, self.internal_dim)self.k_proj = nn.Linear(embedding_dim, self.internal_dim)self.v_proj = nn.Linear(embedding_dim, self.internal_dim)# 用于将注意力机制后的输出投影回原始的特征维度self.out_proj = nn.Linear(self.internal_dim, embedding_dim)# 将输入张量分解为多头注意力所需的形状def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:b, n, c = x.shapex = x.reshape(b, n, num_heads, c // num_heads)return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head# 在注意力计算后重新组合这些头部def _recombine_heads(self, x: Tensor) -> Tensor:b, n_heads, n_tokens, c_per_head = x.shapex = x.transpose(1, 2)return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x Cdef forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:# 输入投影:分别使用q_proj、k_proj和v_proj对query、key和value进行线性变换q = self.q_proj(q)k = self.k_proj(k)v = self.v_proj(v)# 分离头部:将变换后的query、key和value张量按照num_heads进行重塑,以便进行多头注意力计算# B,N_heads,N_tokens,C_per_headq = self._separate_heads(q, self.num_heads)k = self._separate_heads(k, self.num_heads)v = self._separate_heads(v, self.num_heads)# 注意力计算:# 计算query和key的点积,然后除以c_per_head的平方根进行归一化,以防止数值过大_, _, _, c_per_head = q.shapeattn = q @ k.permute(0, 1, 3, 2)  # B,N_heads,N_tokens,C_per_head# 归一化Scaleattn = attn / math.sqrt(c_per_head)# 应用softmax函数得到注意力权重attn = torch.softmax(attn, dim=-1)# 使用注意力权重对value进行加权求和,得到注意力输出out = attn @ v# # B,N_tokens,C# 重新组合头部:将多头注意力输出合并回原始的特征维度。out = self._recombine_heads(out)# 输出投影:最后,通过out_proj将输出投影回原始的embedding_dimout = self.out_proj(out)return out

「transformer_MLP」

transformer中MLP的结构如下图所示。

图片

# MLPBlock类是一个简单的多层感知机(MLP)模块,由两个全连接层(Linear)和一个激活函数组成
class MLPBlock(nn.Module):def __init__(self,# 输入的维度,通常是特征向量的长度embedding_dim: int,# MLP中间层的宽度,可以设置为比输入维度更大的值以增加模型的表达能力mlp_dim: int,# 激活函数,这里默认使用GELUact: Type[nn.Module] = nn.GELU,) -> None:super().__init__()# 第一个全连接层,将输入从embedding_dim维度变换到mlp_dim维度self.lin1 = nn.Linear(embedding_dim, mlp_dim)# 第二个全连接层,将mlp_dim维度的结果变换回embedding_dim维度,以保持与输入相同的维度self.lin2 = nn.Linear(mlp_dim, embedding_dim)# 激活函数实例,用于在全连接层之间引入非线性self.act = act()# 接收输入张量x,将其传递给lin1,然后应用激活函数act。# 将激活函数的输出传递给lin2,得到最终的输出张量def forward(self, x: torch.Tensor) -> torch.Tensor:return self.lin2(self.act(self.lin1(x)))

「upscaled」

这个上采样过程将Transformer的输出特征图恢复到更接近输入图像的分辨率,以便于生成掩模预测。upscaled的结构如下图所示。

图片

# 在MaskDecoder的__init__定义
# output_upscaling是一个序列模块,用于上采样Transformer输出的特征图
self.output_upscaling = nn.Sequential(# 使用nn.ConvTranspose2d,输入通道数为transformer_dim,输出通道数为transformer_dim // 4,内核大小为2,步长为2# 将特征图的尺寸放大两倍,同时将通道数减半# 内核大小为2的转置卷积相当于上采样2倍,步长为2确保输出尺寸翻倍nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),     #转置卷积 上采样2倍# 层归一化(LayerNorm2d)LayerNorm2d(transformer_dim // 4),# 激活函数activation(),# 再次使用nn.ConvTranspose2d,输入通道数为transformer_dim // 4,输出通道数为transformer_dim // 8,内核大小为2,步长为2。这一步继续将特征图的尺寸放大两倍,同时通道数再次减半nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),# 重复激活函数的过程,以进一步增强非线性表达activation(),
)
# 在MaskDecoder的predict_masks添加位置编码
upscaled_embedding = self.output_upscaling(src)

「mask_MLP」

此处的MLP基础模块不同于ViT的MLP(transformer_MLP)基础模块。

# 在MaskDecoder的__init__定义
# output_hypernetworks_mlps是一个nn.ModuleList,包含了多个多层感知机(MLP)。每个MLP的目的是根据输入的mask_tokens_out生成特定掩模的超网络权重
self.output_hypernetworks_mlps = nn.ModuleList([# transformer_dim: Transformer的输出维度,也是输入到MLP的通道数# transformer_dim // 8: MLP的输出通道数,用于生成超网络的权重# 3: MLP的中间层维度,用于增加模型的表达能力MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)for i in range(self.num_mask_tokens)]
)
# 在MaskDecoder的predict_masks添加位置编码
# 对于self.num_mask_tokens个掩模token,遍历output_hypernetworks_mlps列表
for i in range(self.num_mask_tokens):# mask_tokens_out[:, i, :]: B,1,C# output_hypernetworks_mlps: B,1,c# 对每个掩模token,应用对应的MLP,输入是mask_tokens_out中对应位置的特征,输出为B, 1, c形状的张量,其中c是超网络的输出通道数# 将每个MLP的输出收集到hyper_in_list列表中hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
# B,n,c
# 将hyper_in_list堆叠成一个B, n, c形状的张量hyper_in,其中n是掩模token的数量
hyper_in = torch.stack(hyper_in_list, dim=1)
# 获取upscaled_embedding的形状b, c, h, w,其中b是批次大小,c是通道数,h和w是高度和宽度
b, c, h, w = upscaled_embedding.shape
# B,n,c × B,c,N-->B,n,h,w
# 执行矩阵乘法(@运算符)将hyper_in(B, n, c)与upscaled_embedding(在通道维度上展平为B, c, h * w)相结合
# 计算每个掩模token的超网络权重与上采样特征图的点积,得到B, n, h * w形状的张量
# 通过view操作将结果转换回B, n, h, w形状,生成了masks张量,表示每个掩模token对应的预测掩模
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

「iou_MLP」

此处的MLP基础模块不同于ViT的MLP(transformer_MLP)基础模块。

# 在MaskDecoder的__init__定义
# 一个多层感知机(MLP)模块,其目的是预测每个掩模token对应的IoU(Intersection over Union,交并比)值,以评估预测掩模与真实掩模的重合程度
self.iou_prediction_head = MLP(# transformer_dim: 输入到MLP的特征维度,通常与Transformer的输出维度相同# iou_head_hidden_dim: MLP中间层的维度,用于增强模型的表达能力# self.num_mask_tokens: 输出维度,即预测的掩模令牌数量,每个令牌对应一个IoU预测值transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
)
# 在MaskDecoder的predict_masks添加位置编码
iou_pred = self.iou_prediction_head(iou_token_out)

「MaskDeco_MLP」

Mask Decoder中MLP的结构如下图所示。

图片

'''
定义了一个多层感知机,它包含一个可配置的隐藏层数目、输入和输出维度,并可以选择是否在输出层应用Sigmoid激活函数
'''
class MLP(nn.Module):def __init__(self,input_dim: int,         # 输入特征的维度,即输入张量的通道数hidden_dim: int,        # 隐藏层的通道数,中间层的宽度output_dim: int,        # 输出特征的维度,即输出张量的通道数num_layers: int,        # 多层感知机的层数,包括输入层和输出层sigmoid_output: bool = False, #  一个布尔值,表示是否在输出层应用Sigmoid激活函数,默认为False) -> None:'''内部组件'''super().__init__()# 存储输入的层数self.num_layers = num_layers# 一个列表,包含num_layers - 1个hidden_dim,用于构建中间层的线性变换h = [hidden_dim] * (num_layers - 1)#  一个nn.ModuleList,包含num_layers个线性层(全连接层),每个层的输入和输出通道数由h和input_dim、output_dim决定self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))self.sigmoid_output = sigmoid_outputdef forward(self, x):# 对输入张量x,遍历layers列表中的每个线性层for i, layer in enumerate(self.layers):# 如果当前层不是最后一层,应用ReLU激活函数(F.relu)x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)# 如果sigmoid_output为True,最后对输出应用Sigmoid激活函数if self.sigmoid_output:x = F.sigmoid(x)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/26501.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TF-IDF算法教程

前言 TF-IDF&#xff08;Term Frequency-Inverse Document Frequency&#xff09;是一种常用的文本分析技术&#xff0c;广泛应用于信息检索和文本挖掘领域。它是一种统计方法&#xff0c;用于评估一个词语在一个文档中的重要程度。TF-IDF的核心思想是&#xff1a;如果一个词语…

VS2019+QT5.15调用动态库dll带有命名空间

VS2019QT5.15调用动态库dll带有命名空间 vs创建动态库 参考&#xff1a; QT调用vs2019生成的c动态库-CSDN博客 demo的dll头文件&#xff1a; // 下列 ifdef 块是创建使从 DLL 导出更简单的 // 宏的标准方法。此 DLL 中的所有文件都是用命令行上定义的 DLL3_EXPORTS // 符号…

四十一、openlayers官网示例Flight Animation解析——在地图上绘制飞机航线、牵引线效果、动态动画

官网demo地址&#xff1a; Flight Animation 这篇介绍了如何实现飞机航线动画。 首先加载一张底图&#xff0c;定义一个样式。 const tileLayer new TileLayer({source: new StadiaMaps({layer: "outdoors",}),});const map new Map({layers: [tileLayer],target…

【实例分享】访问后端服务超时,银河麒麟服务器操作系统分析及处理建议

1.服务器环境以及配置 【机型】 处理器&#xff1a; Intel 32核 内存&#xff1a; 128G 整机类型/架构&#xff1a; x86_64虚拟机 【内核版本】 4.19.90-25.22.v2101.kylin.x86_64 【OS镜像版本】 kylin server V10 SP2 【第三方软件】 开阳k8s 2.问题现象描述 …

API工具--Apifox和Postman对比(区别)

&#x1f525; 交流讨论&#xff1a;欢迎加入我们一起学习&#xff01; &#x1f525; 资源分享&#xff1a;耗时200小时精选的「软件测试」资料包 &#x1f525; 教程推荐&#xff1a;火遍全网的《软件测试》教程 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1…

【学习笔记】Linux

Linux 1、 介绍 1.1、概述 1.2、特点 1.3、Linux的发行版2、 基础篇 —— 文件系统 2.1、文件系统 2.2、目录结构3、 基础篇 —— VI/VIM 编辑器 3.1、概述 3.2、编辑器模式及常用命令4、 基础篇 —— 网络配置 4.1、VMware NetWork …

Linux so文件无法找到及某条命令找不到的解决办法

前言 在一些定制软件中可能会自带so文件。或者自带一些二进制命令。 这时会如果运行某个程序会发生 **.so 文件无法找到的错误。 以及 * 某条命令无法找到的错误。 比如像是下面这样 解决办法&#xff1a; so文件无法找到 通过往 LD_LIBRARY_PATH 变量中追加路径来告诉程序…

cdh中的zookeeper怎么配置zoo.cfg

你手动改了zoo.cfg目录是不会生效的&#xff0c;因为是cdh在管控&#xff0c;所以只能通过cdh修改。 首先打开cdh。 xxx:7180 点击zookeeper 选配置&#xff0c;然后选高级 在右边找&#xff0c;有一个就是zoo.cfg&#xff0c;可以点击右边的感叹号。然后在里面编辑的就会直…

LabVIEW RT环境中因字符串拼接导致的系统崩溃问题

在LabVIEW实时操作系统&#xff08;RT&#xff09;环境中运行的应用程序出现字符串拼接后死机的问题&#xff0c;通常涉及内存管理、内存泄漏或其他资源管理问题。以下是一些指导和步骤&#xff0c;帮助解决这个问题&#xff1a; 1. 内存泄漏检测 字符串拼接会在内存中创建新…

Android14音频进阶之CarAudioManager::getOutputDeviceForUsage流程分析(七十七)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒体系统工程师系列【原创干货持续更新中……】🚀 优质视频课程:AAOS车载系统+AOSP…

2024 年 19 种最佳大型语言模型

大型语言模型是 2023 年生成式人工智能热潮背后的推动力。然而&#xff0c;它们已经存在了一段时间了。 LLM是黑盒 AI 系统&#xff0c;它使用深度学习对超大数据集进行处理&#xff0c;以理解和生成新文本。现代 LLM 开始成型于 2014 年&#xff0c;当时一篇题为“通过联合学…

Github2024-06-12 开源项目日报 Top10

根据Github Trendings的统计,今日(2024-06-12统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目4JavaScript项目2Lua项目1PHP项目1Blade项目1非开发语言项目1TypeScript项目1Shell项目1从零开始构建你喜爱的技术 创建周期:2156 天…

C++ 25 之 调用函数调用规则

c25调用函数调用规则.cpp #include<iostream> using namespace std;class Students04{ // 1.创建好类之后&#xff0c;编译器会默认提供三个函数&#xff1a;默认构造函数、构造函数、拷贝构造函数 // 2.自己写了有参构造函数&#xff0c;编译器就不会提供默认构造函数&…

[imx6ull]Linux下的SocketCAN通信

文章目录 一、CAN总线协议1.简介2.电气属性3.通信原理①数据帧的帧格式&#xff1a;②总线同步③总线竞争④数据保护 二、Linux下CAN的操作1.硬件连接①CAN电平转换器②扩展板使用CAN 2.查询 can 信息3.开启/关闭 can4.发送/接收 can 数据5.设置 can 参数 三、CAN的回环测试四、…

【知识整理】软件版本号的定义及规范

版本号简述 在软件开发项目中&#xff0c;版本号是一个非常重要的概念&#xff0c;它能够告诉用户软件的功能、质量和安全性等信息&#xff0c;同时也可以帮助开发者追踪软件的历史和进展&#xff0c;并做好版本控制工作。在本文中&#xff0c;我们将介绍版本号的定义及规范&a…

Java基础面试重点-3

41. 简述线程生命周期(状态) 其它参考《多线程重点》中的说法。三种阻塞&#xff1a; 等待阻塞&#xff1a; 运行的线程执行o.wait()方法&#xff08;该线程已经持有锁&#xff09;&#xff0c;JVM会把该线程放入等待队列中。同步阻塞&#xff1a; 运行的线程在获取对象的同步…

数据挖掘丨轻松应用RapidMiner机器学习内置数据分析案例模板详解(下篇)

RapidMiner 案例模板 RapidMiner 机器学习平台提供了一个可视化的操作界面&#xff0c;允许用户通过拖放的方式构建数据分析流程。RapidMiner目前内置了 13 种案例模板&#xff0c;这些模板是预定义的数据分析流程&#xff0c;可以帮助用户快速启动和执行常见的数据分析任务。 …

jsp 实验20

三、源代码以及执行结果截图&#xff1a; NewFile.jsp <% page import "java.io.*" %> <% page contentType"text/html" %> <% page pageEncoding "utf-8" %> <jsp:useBean id"english" class "web.Engli…

QT--DAY1

不使用图形化界面实现一个登陆界面 #include "widget.h"Widget::Widget(QWidget *parent): QWidget(parent) {//设置窗口标题this->setWindowTitle("登录界面");//设置窗口大小this->resize(535,410);//固定窗口大小this->setFixedSize(535,410)…

181.二叉树:验证二叉树(力扣)

代码解决 /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : val(0), left(nullptr), right(nullptr) {}* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}* Tre…