Literature Reading Notes: The Segment Anything Model (SAM)
- Abstract
- 1. The SAM Model
- 1.1 Paper Abstract
- 1.2 Introduction
- 1.3 SAM Network Architecture
- 1.4 Experiments
- 1.4.1 Zero-Shot Single-Point Valid-Mask Evaluation
- 1.4.2 Zero-Shot Edge Detection
- 1.4.3 Zero-Shot Object Proposals
- 1.4.4 Zero-Shot Text-to-Mask Prediction
- 1.5 SAM Code Implementation
- 2. Summary
Abstract
This week I studied the Segment Anything Model (SAM), which introduces a new task, model, and dataset for image segmentation. The model is designed and trained to be promptable, so it can transfer zero-shot to new image distributions and tasks. Evaluated across a wide range of tasks, its zero-shot performance matches or even exceeds prior fully supervised results. This note describes the SAM model in detail.
1. The SAM Model
Paper: Segment Anything
1.1 Paper Abstract
In this paper the authors introduce SAM: a new task, model, and dataset for image segmentation. They built the largest segmentation dataset to date, with over 1 billion masks on 11 million licensed, privacy-respecting images. The model is designed and trained to be promptable, so it can transfer zero-shot to new image distributions and tasks. Evaluating SAM on numerous tasks shows that its zero-shot performance matches or even exceeds prior fully supervised results.
1.2 Introduction
After pre-training on a large-scale dataset, SAM can be applied zero-shot across many scenarios. Prompts can be given as hand-written text, or, after a user interaction, as a target point or target box passed to the model. Accuracy improves further with model scale, dataset size, and the amount of training. Once trained, well-designed text prompts let the model generalize zero-shot to new visual concepts and data distributions. Such an encoder can also be composed effectively with other modules to solve downstream tasks, that is, tasks built on top of an already trained model, typically using the pre-trained model's intermediate representations (e.g., hidden-layer outputs) as input. Downstream tasks include classification, regression, sequence labeling, text generation, and so on, and they aim to exploit the knowledge learned by the pre-trained model to solve a specific problem, for example image generation (e.g., DALL·E).
In this paper the authors set out to build a foundation model for image segmentation. SAM is a promptable model, pre-trained on a broad dataset with a task that enables strong generalization. The authors propose the promptable segmentation task, whose goal is to return a valid segmentation mask (a segmented region or object) given any segmentation prompt (see the figure below). A prompt simply specifies what to segment in the image; for example, it can be a region identifying an object or a piece of text. The requirement of a valid output mask means that even when a prompt is ambiguous and could refer to multiple objects (e.g., a point on a shirt may indicate either the shirt or the person wearing it), the output should be a reasonable mask for at least one of those objects.
Because SAM must support flexible prompts, compute masks in real time to allow interactive use, and be ambiguity-aware, the authors split it into an image encoder and a fast prompt encoder / mask decoder, so that the same image embedding can be reused with different prompts. Given an image embedding, the prompt encoder and mask decoder predict a mask from a prompt directly in a web browser. Prompts can be points clicked by the user, boxes, roughly drawn regions, and, as a preliminary result, free-form text. To make SAM ambiguity-aware, it is designed to predict multiple masks for a single prompt, which lets it handle ambiguity naturally, as in the shirt-versus-person example.
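The sketch below illustrates this "embed once, prompt many times" workflow. It assumes the publicly released segment_anything package and a downloaded ViT-H checkpoint; the image path and prompt coordinates are placeholders.

```python
# Minimal usage sketch (assumptions: the public `segment_anything` package is installed
# and the official ViT-H checkpoint has been downloaded; paths/coordinates are placeholders).
import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

image = np.array(Image.open("example.jpg").convert("RGB"))  # H x W x 3, uint8
predictor.set_image(image)  # runs the heavy ViT image encoder exactly once

# Prompt 1: a single foreground point (label 1 = foreground, 0 = background).
masks, scores, _ = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,  # return 3 candidate masks to handle ambiguity
)

# Prompt 2: a bounding box (XYXY), reusing the cached image embedding.
masks, scores, _ = predictor.predict(box=np.array([425, 600, 700, 875]))
```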
Achieving strong generalization to new data distributions requires training SAM on a large and diverse set of masks, but masks are not an abundant form of data. The authors' solution is to build a "data engine", i.e., model-in-the-loop dataset annotation (see the figure below). The data engine has three stages: assisted-manual, semi-automatic, and fully automatic.
- In the first stage, SAM assists annotators in annotating masks, similar to a classic interactive segmentation setup.
- In the second stage, SAM automatically generates masks for a subset of objects by prompting it with likely object locations, while annotators focus on annotating the remaining objects, which helps increase mask diversity.
- In the final stage, SAM is prompted with a regular grid of foreground points, yielding on average about 100 high-quality masks per image. Concretely, the authors lay a regular grid over the image, treat every grid point as a possible foreground point, and feed these grid points to SAM as prompts. In this way they generate roughly 100 high-quality masks per image on average, which can then be used to train and improve the model.
"前景点"通常指图像分割任务中的感兴趣区域或对象的边界点或像素。在图像分割中,我们通常希望将图像中的不同物体或区域分割开来,其中 前景点指的是这些物体或区域的边缘或轮廓上的像素点。通过使用前景点作为提示,模型可以更好地理解图像中不同物体或区域的位置和形状,从而更准确地生成相应的掩码或分割结果。
1.3 SAM Network Architecture
SAM has three components, as shown in the figure below: an image encoder, a prompt encoder, and a fast mask decoder.
The image encoder is a ViT, minimally adapted to process high-resolution inputs. It runs once per image and can be applied before the model is prompted. Two kinds of prompts are considered: sparse (points, boxes, text) and dense (masks). Points and boxes are represented by positional encodings summed with learned embeddings for each prompt type, while free-form text uses an off-the-shelf text encoder from CLIP. Dense prompts (i.e., masks) are embedded with convolutions and summed element-wise with the image embedding.
CLIP (Contrastive Language-Image Pretraining) is a model architecture proposed by OpenAI for relating images and text. Its key idea is contrastive pre-training between images and text, which teaches the model the semantic correspondence between the two modalities. Because CLIP learns from image-text pairs rather than conventional supervised class labels, it generalizes strongly and can be used for a wide variety of vision and language tasks.
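As a small illustration of the shared embedding space that SAM relies on, the sketch below encodes an image and two captions with the openai/CLIP package and compares them by cosine similarity; the model name and image path are placeholders.

```python
# Illustrative CLIP usage (assumes `pip install git+https://github.com/openai/CLIP.git`).
import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["a photo of a cat", "a photo of a dog"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Cosine similarity after L2 normalization, then softmax over the captions.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
```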
The mask decoder efficiently maps the image embedding, the prompt embeddings, and an output token to a mask. The design modifies a Transformer decoder block, followed by a dynamic mask prediction head. The modified decoder block uses prompt self-attention and cross-attention in two directions (prompt-to-image and image-to-prompt) to update all embeddings. After running two such blocks, the image embedding is upsampled and an MLP maps the output token to a dynamic linear classifier, which computes the mask foreground probability at every image location.
Self-attention: self-attention is a mechanism for processing sequence data (such as natural language or time series) in which every input position interacts with every other position and an attention distribution over them is computed. This lets the model build long-range dependencies between positions and capture the important information in the input sequence.
Cross-attention: cross-attention is a variant used to model interactions between different kinds of inputs, typically to combine different modalities (such as images and text). One input sequence (here the prompt tokens) provides the query vectors, while the other (the image embedding) provides the key and value vectors. Attention weights are computed from the similarity between queries and keys and applied to the values to produce the final representation, allowing effective interaction and information flow between the two inputs.
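The toy sketch below shows the difference between the two mechanisms using plain torch.nn.MultiheadAttention: prompt tokens first attend to each other, then query the flattened image embedding. It illustrates the idea only; it is not SAM's actual two-way transformer decoder block.

```python
# Toy illustration of self-attention vs. cross-attention (not SAM's real decoder block).
import torch
import torch.nn as nn

embed_dim, num_heads = 256, 8
tokens = torch.randn(1, 7, embed_dim)           # e.g. output tokens + sparse prompt embeddings
image_emb = torch.randn(1, 64 * 64, embed_dim)  # flattened 64x64 image embedding

self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

tokens, _ = self_attn(tokens, tokens, tokens)                        # tokens attend to tokens
tokens, _ = cross_attn(query=tokens, key=image_emb, value=image_emb) # tokens attend to the image
```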
With a single output, an ambiguous prompt would force the model to average over multiple valid masks. To address this, the authors modify the model to predict multiple output masks for a single prompt (see the figure below; note the regions circled in red). They found that 3 mask outputs are enough for the most common cases (nested masks are usually at most three levels deep: whole, part, and sub-part). During training, only the minimum loss over the predicted masks is backpropagated. To rank masks, the model also predicts a confidence score (an estimated IoU) for each mask.
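A minimal sketch of the "backpropagate only the minimum loss" idea follows. The tensors and the plain BCE loss are placeholders; the paper's actual training objective combines focal and dice losses.

```python
# Sketch: with 3 predicted masks per prompt, compute a loss for each and backpropagate
# only the smallest one (placeholder BCE loss, random placeholder tensors).
import torch
import torch.nn.functional as F

def min_mask_loss(pred_masks: torch.Tensor, gt_mask: torch.Tensor) -> torch.Tensor:
    """pred_masks: (3, H, W) logits for one prompt; gt_mask: (H, W) binary ground truth."""
    losses = torch.stack([
        F.binary_cross_entropy_with_logits(m, gt_mask) for m in pred_masks
    ])
    return losses.min()  # gradients flow only through the best-matching mask

pred_masks = torch.randn(3, 256, 256, requires_grad=True)
gt_mask = (torch.rand(256, 256) > 0.5).float()
loss = min_mask_loss(pred_masks, gt_mask)
loss.backward()
```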
The overall model design is largely driven by efficiency. Given a precomputed image embedding, the prompt encoder and mask decoder run on a CPU in a web browser in about 50 ms.
1.4 Experiments
[Note] mIoU stands for mean Intersection over Union, a metric commonly used to evaluate image segmentation models. To compute mIoU, the IoU (Intersection over Union) is first computed for each class (or each mask), and the IoUs are then averaged.
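A small sketch of these two quantities for binary masks (hypothetical helpers, NumPy only):

```python
# IoU of two boolean masks, and mIoU as the average over a set of mask pairs.
import numpy as np

def iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """pred, gt: boolean masks of the same shape."""
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return float(inter / union) if union > 0 else 1.0

def miou(preds, gts) -> float:
    """Mean IoU over paired lists of predicted and ground-truth masks."""
    return float(np.mean([iou(p, g) for p, g in zip(preds, gts)]))
```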
The experiments consider five tasks, four of which differ significantly from the promptable segmentation task used to train SAM. The evaluation first tests the core goal of promptable segmentation: producing a valid mask from any prompt. SAM is then prompted to perform edge detection, to segment everything (object proposal generation), to segment detected objects (instance segmentation), and to segment objects from free-form text. These four tasks differ significantly from the task SAM was trained on and are approached through prompt engineering.
1.4.1 Zero-Shot Single-Point Valid-Mask Evaluation
The standard mIoU metric (the mean of all IoUs between predicted and ground-truth masks) is supplemented with a human study in which annotators rate mask quality from 1 (nonsense) to 10 (pixel-perfect). Because SAM can predict multiple masks, by default only its most confident mask is evaluated. All baselines are single-mask methods. The main comparison is against RITM, a strong interactive segmenter that performs best on this benchmark among the other strong baselines.
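Selecting that most confident mask is straightforward with SAM's predicted IoU scores; a short sketch (reusing the SamPredictor from the earlier example) follows.

```python
# Keep only the mask with the highest predicted IoU score (SAM's confidence estimate).
import numpy as np

masks, scores, _ = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
best_mask = masks[np.argmax(scores)]
```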
The evaluation uses a newly compiled suite of 23 datasets with diverse image distributions. The figure below lists the datasets and shows a sample from each; all 23 datasets are used for the mIoU evaluation.
The figure below also shows the additional baselines SimpleClick and FocalClick, whose single-point performance is below that of RITM and SAM. As the number of points increases from 1 to 9, the gap between methods shrinks. Moreover, SAM is not optimized for the very-high-IoU regime. Finally, in panel (d), random point sampling replaces the default center-point sampling; the gap between SAM and the baselines widens, and SAM achieves comparable results under either sampling method.
1.4.2 Zero-Shot Edge Detection
Representative edge maps are visualized in the figure below. Qualitatively, even though SAM was never trained for edge detection, it produces reasonable edge maps. Compared with the ground truth, SAM predicts more edges, including plausible edges that are not annotated in BSDS500.
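The paper's edge-detection protocol prompts SAM with a regular grid of foreground points, filters the resulting masks, and then Sobel-filters the unthresholded mask probability maps. The fragment below only sketches that last Sobel step with OpenCV on a placeholder probability map; it is not the full protocol.

```python
# Sobel-filter a mask probability map to obtain an edge map (placeholder input;
# the paper's pipeline adds mask NMS, edge thinning, and other post-processing).
import cv2
import numpy as np

prob_map = np.random.rand(512, 512).astype(np.float32)  # placeholder mask probability map

gx = cv2.Sobel(prob_map, cv2.CV_32F, 1, 0, ksize=3)
gy = cv2.Sobel(prob_map, cv2.CV_32F, 0, 1, ksize=3)
edge_map = np.sqrt(gx ** 2 + gy ** 2)
edge_map /= edge_map.max() + 1e-8  # normalize to [0, 1]
```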
1.4.3 Zero-Shot Object Proposals
In the results below, using ViTDet-H detections as object proposals performs best overall. However, SAM does remarkably well on several metrics. Notably, it outperforms ViTDet-H on medium and large objects, as well as on rare and common objects. In fact, SAM only underperforms ViTDet-H on small and frequent objects, where ViTDet-H can easily learn LVIS-specific annotation biases because, unlike SAM, it was trained on LVIS.
1.4.4 Zero-Shot Text-to-Mask Prediction
This task segments objects from free-form text and serves as a proof of concept of SAM's ability to process text prompts. Specifically, for every manually collected mask with an area larger than 100² pixels, a CLIP image embedding is extracted. During training, SAM is then prompted with the extracted CLIP image embedding as its first interaction. The key observation is that because CLIP's image embeddings are trained to align with its text embeddings, the model can be trained with image embeddings but run with text embeddings at inference. That is, at inference time the text is passed through CLIP's text encoder and the resulting text embedding is given to SAM as the prompt.
The figure below shows qualitative results. SAM can segment objects from simple text prompts such as "a wheel" as well as phrases such as "beaver tooth grille".
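Since the released SAM checkpoints do not include this text-prompt head, the sketch below only shows the CLIP side of the recipe: producing a text embedding that, by construction, lives in the same space as the CLIP image embeddings used as prompts during training. How that embedding would be projected into SAM's prompt space is assumed and not shown.

```python
# Conceptual sketch: encode a text prompt with CLIP's text encoder; at inference this
# embedding would stand in for the CLIP image embedding used during training.
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-L/14", device=device)

with torch.no_grad():
    tokens = clip.tokenize(["a wheel"]).to(device)
    text_embedding = clip_model.encode_text(tokens)
    text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
```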
[Note] "Ablations" refers to dissecting or weakening a model or experiment in order to evaluate how its individual components contribute to overall performance. In machine learning and deep learning, an ablation usually means removing or disabling part of a model to study its effect on performance. Ablation experiments help researchers understand how a model works and identify which components matter most for its effectiveness and robustness.
1.5 SAM Code Implementation
Sam model
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
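Below is a minimal sketch of calling the Sam module directly with the batched-input format its forward() expects (a list of per-image dicts). It assumes the public segment_anything package's sam_model_registry; the checkpoint path, image tensor, and prompt coordinates are placeholders.

```python
# Calling Sam.forward() directly (placeholder checkpoint path, random placeholder image).
import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

image = torch.randint(0, 255, (3, 768, 1024), dtype=torch.uint8)  # C x H x W, long side = 1024
batched_input = [
    {
        "image": image.float(),                            # preprocess() will normalize and pad it
        "original_size": (768, 1024),                      # size the output masks are resized back to
        "point_coords": torch.tensor([[[512.0, 384.0]]]),  # B x N x 2, in input-image coordinates
        "point_labels": torch.tensor([[1]]),               # 1 = foreground point
    }
]

with torch.no_grad():
    outputs = sam(batched_input, multimask_output=True)

masks = outputs[0]["masks"]                  # B x 3 x H x W boolean masks
iou_scores = outputs[0]["iou_predictions"]   # SAM's confidence (estimated IoU) per mask
```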
image_encoder implementation
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple, Type

from .common import LayerNorm2d, MLPBlock


# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        x = self.neck(x.permute(0, 3, 1, 2))

        return x


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


class PatchEmbed(nn.Module):
    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
mask_decoder implementation
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import List, Tuple, Type

from .common import LayerNorm2d


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
prompt_encoder implementation
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch import nn

from typing import Any, Optional, Tuple, Type

from .common import LayerNorm2d


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """Gets the batch size of the output given the batch size of the input prompts."""
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings


class PositionEmbeddingRandom(nn.Module):
    """Positional encoding using random spatial frequencies."""

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
2. Summary
This week I studied SAM. The model is designed and trained to be promptable, it transfers zero-shot to new image distributions and tasks, and it segments from point, box, and mask prompts. I also gained a basic understanding of the SAM code implementation, and I will continue this work next week.