文献阅读笔记:SAM大模型(Segment Anything)

文献阅读笔记:SAM大模型(Segment Anything)

  • 摘要
  • Abstract
  • 1. SAM大模型
    • 1.1 文献摘要
    • 1.2 引言
    • 1.3 SAM大模型网络结构
    • 1.4 实验
      • 1.4.1 零样本单点有效掩码评估
      • 1.4.2 零样本边缘检测
      • 1.4.3 零样本对象提议
      • 1.4.4 零样本通过文本提示预测mask
    • 1.5 SAM模型代码实现
    • 2. 总结

摘要

本周学习了SAM大模型,该大模型用于图像分割的新任务、模型和数据集。该模型的设计和训练具有快速性,因此它可以将 零样本 转移到新的图像分布和任务。通过评估SAM在众多任务上的能力,该大模型的零样本性能与之前完全监督的结果相媲美甚至优于之前的结果。本文将详细介绍SAM模型。

Abstract

This week, we learned about the SAM Big Model, a new task, model, and dataset for image segmentation. The model is designed and trained with rapidity so that it can transfer zero samples to new image distributions and tasks. By evaluating SAM’s capabilities on a wide range of tasks, the zero-sample performance of this large model matches or even outperforms previous fully supervised results. This paper describes the SAM model in detail.

1. SAM大模型

论文题目:Segment Anything

1.1 文献摘要

作者在这篇文献中介绍了SAM大模型:该大模型用于图像分割的新任务、模型和数据集。作者构建了迄今为止最大的分割数据集,在 1100 万张许可且尊重隐私的图像上包含超过 10 亿个mask。该模型的设计和训练具有快速性,因此它可以将 零样本 转移到新的图像分布和任务。通过评估SAM在众多任务上的能力,该大模型的零样本性能与之前完全监督的结果相媲美甚至优于之前的结果。

1.2 引言

SAM大模型通过大规模的数据集进行预训练后,在各个场景可以通过零样本来进行。此方式可以通过手动文本的提示来进行,也可以通过用户操作后,向大模型传递一个目标点、目标框来进行响应。而大模型的精确度会随着模型的规模、数据集大小和训练次数增多而进一步改善。SAM大模型经过训练后,设计的文本提示可以零样本概括新的视觉概念和数据分布。这种编码器还可以与其他模块有效组合,以实现下游任务,所谓下游任务是指建立在已训练模型之上的任务,通常是使用预训练模型的中间表示(例如隐藏层的输出)作为输入。这些下游任务可以是分类、回归、序列标注、文本生成等,旨在利用预训练模型学到的知识来解决特定的问题。例如图像生成(例如,DALL·E )。

作者在这篇文献中提出了建立图像分割的基础模型。SAM大模型是一个可提示的模型,并使用能够实现强大泛化的任务在广泛的数据集上对其进行预训练。 作者提出了提示分割任务,其目标是在给定任何分割提示的情况下返回有效的分割mask(分割区域和对象)(见下图)。提示只是指定在图像中分割什么,例如,提示可以包括识别对象的区域或文本信息。有效输出mask的要求意味着,即使提示不明确并且可能引用多个对象(例如,衬衫上的点可能表示衬衫或穿着它的人),输出也应该是合理的mask,至少其中一个对象。
在这里插入图片描述
由于该SAM模型必须支持灵活的提示,需要实时计算和使用mask以允许交互使用,并且必须具有模糊性意识。作者通过将 SAM 分成图像编码器和快速提示编码器/mask解码器,可以通过不同的提示重复使用相同的图像嵌入。给定图像嵌入,提示编码器和mask解码器在网络浏览器根据提示预测掩码。提示信息可以通过用户标点、框和用户随机圈出的区域,并且还使用自由格式的文本提示来呈现初步结果。为了使 SAM 能够感知歧义,作者将其设计为预测单个提示的多个mask,从而使 SAM 能够自然地处理歧义,例如衬衫与人的示例。在这里插入图片描述
为了实现对新数据分布的强大泛化能力,需要在大型和多样化的mask数据集上对SAM进行训练,但是mask并不是一种十分丰富的数据,作者针对此问题的解决方案是构建一个“数据引擎”。 即我们与模型在环数据注释相结合(见下图)。数据引擎分为三个阶段:辅助手动、半自动和全自动。

  1. 在第一阶段,SAM协助标注者标注掩码,类似于经典的交互式分割设置。
  2. 在第二阶段,SAM可以通过提示可能的对象位置自动生成一部分对象的掩码,标注者专注于标注剩余的对象,有助于增加掩码的多样性。
  3. 在最后阶段,我们用前景点的常规网格提示SAM,在平均每张图像上产生约100个高质量的掩码。具体来说,作者在图像上应用了一个常规的网格,每个网格点都被视为可能的前景点,然后将这些网格点作为提示输入给SAM模型。通过这种方式,他们可以在平均每张图像上生成约100个高质量的掩码,这些掩码可以用于模型的训练和改进。

"前景点"通常指图像分割任务中的感兴趣区域或对象的边界点或像素。在图像分割中,我们通常希望将图像中的不同物体或区域分割开来,其中 前景点指的是这些物体或区域的边缘或轮廓上的像素点。通过使用前景点作为提示,模型可以更好地理解图像中不同物体或区域的位置和形状,从而更准确地生成相应的掩码或分割结果。

在这里插入图片描述

1.3 SAM大模型网络结构

SAM 具有三个组件,如下图所示:图像编码器、提示编码器和快速mask解码器。
在这里插入图片描述

作者使用ViT模型最低限度地处理高分辨率的图像输入。图像编码器每张图像运行一次,并且可以在提示模型之前应用。考虑到有两组提示:稀疏(点、框、文本)和密集(mask)。作者通过位置编码 来表示点和框,并使用 CLIP 中现成的文本编码器对每种提示类型和自由格式文本进行学习嵌入。使用卷积嵌入密集提示(即mask),并与图像嵌入按元素求和。

CLIP(Contrastive Language-Image Pretraining)是一种由OpenAI提出的模型架构,用于处理图像和文本之间的交互。CLIP的独特之处在于,它通过对图像和文本之间的对比学习进行预训练,从而使模型能够理解图像和文本之间的语义关联。CLIP可以在没有任何监督标签的情况下学习,因此具有很强的泛化能力,可以用于各种各样的视觉和语言任务。

mask解码器有效地将图像嵌入、提示嵌入和输出标记映射到mask。该设计对 Transformer 解码器块的修改,后跟动态mask预测头。修改后的解码器块使用两个方向的即时自注意力和交叉注意力来更新所有嵌入。运行两个块后,我们对图像嵌入进行上采样,并且 MLP 将输出标记映射到动态线性分类器,然后计算每个图像位置的mask前景概率。

即时自注意力:即时自注意力(self-attention)是一种机制,用于处理序列数据(如自然语言或时间序列数据),在该机制中,每个输入位置都与其他位置进行交互,并计算出它们之间的注意力分布。这使得模型能够在不同位置之间建立长程依赖关系,并且能够捕捉输入序列中的重要信息。
交叉自注意力:交叉自注意力(cross-attention)是一种变体,用于处理不同类型的输入之间的交互,通常用于将不同模态的输入(如图像和文本)结合起来。在交叉自注意力中,一个输入序列(如文本)被用作查询向量,而另一个输入序列(如图像)被用作键和值向量。然后,根据查询向量和键向量之间的相似度,计算出注意力权重,并将这些权重应用于值向量,以生成最终的表示。这使得模型能够在不同模态的输入之间进行有效的交互和信息传递。

对于一个输出,如果给出模糊的提示,模型将对多个有效mask进行平均。为了解决这个问题,作者修改模型以预测单个提示的多个输出mask(见下图,注意每张图中红线圈住的区域)。我们发现 3 个mask输出足以解决最常见的情况(嵌套mask通常最多三层深度:整体、部分和子部分)。在训练期间,我们仅反向传播mask上的最小损失。为了对mask进行排名,模型预测每个mask的置信度得分(即估计的 IoU)。
在这里插入图片描述
整体模型设计很大程度上是出于效率的考虑。给定预先计算的图像嵌入,提示编码器和掩码解码器在 Web 浏览器中的 CPU 上运行,耗时约 50 毫秒。

1.4 实验

【注】mIoU指的是平均交并比(mean Intersection over Union),是一种常用于评估图像分割模型性能的指标。 在计算mIoU时,首先计算每个类别的IoU(Intersection over Union),然后将所有类别的IoU取平均值。

作者的实验考虑了五个任务,其中四个与用于训练 SAM 的提示分割任务显着不同。首先测试可提示分割的核心目标:根据任何提示生成有效的掩码。提示 SAM 执行边缘检测、分割所有内容,即对象提议生成,分割检测到的对象,即实例分割,分割来自自由格式文本的对象。这四个任务与 SAM 接受训练并通过即时工程实现的即时分割任务显着不同。

1.4.1 零样本单点有效掩码评估

我们通过人类研究来补充标准 mIoU 指标(即预测掩模和真实掩模之间所有 IoU 的平均值),其中标注者将掩模质量从 1(无意义)到 10(像素完美)进行评级。由于 SAM 能够预测多个掩模,因此默认情况下我们仅评估模型最置信度的掩模。基线都是单掩模方法。我们主要与 RITM 进行比较,RITM 是一种强大的交互式分段器,与其他强大的基线相比,它在我们的基准上表现最好。

数据集采用新编译的包含 23 个具有不同图像分布的数据集的套件。下图列出了数据集并显示了每个数据集的样本。作者使用所有 23 个数据集进行 mIoU 评估。
在这里插入图片描述
下图显示了额外的基线 SimpleClick 和 FocalClick,它们的单点性能低于 RITM 和 SAM。随着点数从 1 增加到 9,我们可以观察到方法之间的差距缩小。此外,SAM 并未针对非常高的 IoU 状态进行优化。最后,在图 d 中,采用了随机点采样替换默认的中心点采样。我们观察到 SAM 与基线之间的差距越来越大,并且 SAM 在任一采样方法下都能够获得可比较的结果。
在这里插入图片描述

1.4.2 零样本边缘检测

作者在下图中可视化了代表性边缘图。定性地,我们观察到即使 SAM 没有经过边缘检测训练,它也会产生合理的边缘图。与真实值相比,SAM 预测了更多边缘,包括 BSDS500 中未注释的合理边缘。在这里插入图片描述

1.4.3 零样本对象提议

在下中,使用 ViTDet-H 的检测作为对象建议总体表现最佳。然而,SAM 在几个指标上都表现得非常好。值得注意的是,它在中型和大型物体以及稀有和常见物体上的性能优于 ViTDet-H。事实上,SAM 仅在小对象和频繁对象上表现不佳 ViTDet-H,其中 ViTDet-H 可以轻松学习 LVIS 特定注释偏差,因为它是在 LVIS 上训练的,与 SAM 不同。

1.4.4 零样本通过文本提示预测mask

从自由格式文本中分割对象。该实验是 SAM 处理文本提示能力的概念验证。具体来说,对于每个手动收集的面积大于 1002 的掩模,我们提取 CLIP 图像嵌入。然后,在训练过程中,我们提示 SAM 使用提取的 CLIP 图像嵌入作为其第一次交互。这里的关键观察是,由于 CLIP 的图像嵌入经过训练以与其文本嵌入对齐,因此我们可以使用图像嵌入进行训练,但使用文本嵌入进行推理。也就是说,在推理时,我们通过 CLIP 的文本编码器运行文本,然后将生成的文本嵌入作为 SAM 的提示。

下图中显示了定性结果。SAM 可以根据简单的文本提示(如“轮子”)以及短语(如“海狸齿格栅”)来分割对象。

在这里插入图片描述
【注】"Ablations"通常指的是在科学研究中对模型或实验进行的剖析或削弱,以便评估其各个组成部分对整体性能的影响。在机器学习和深度学习领域,ablation通常指通过移除或禁用模型的某些部分来研究它们对模型性能的影响。通过ablations实验,研究人员可以更好地理解模型的工作原理,并确定哪些组件对于模型的有效性和鲁棒性最为关键。

1.5 SAM模型代码实现

SAM模型

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.import torch
from torch import nn
from torch.nn import functional as Ffrom typing import Any, Dict, List, Tuplefrom .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoderclass Sam(nn.Module):mask_threshold: float = 0.0image_format: str = "RGB"def __init__(self,image_encoder: ImageEncoderViT,prompt_encoder: PromptEncoder,mask_decoder: MaskDecoder,pixel_mean: List[float] = [123.675, 116.28, 103.53],pixel_std: List[float] = [58.395, 57.12, 57.375],) -> None:"""SAM predicts object masks from an image and input prompts.Arguments:image_encoder (ImageEncoderViT): The backbone used to encode theimage into image embeddings that allow for efficient mask prediction.prompt_encoder (PromptEncoder): Encodes various types of input prompts.mask_decoder (MaskDecoder): Predicts masks from the image embeddingsand encoded prompts.pixel_mean (list(float)): Mean values for normalizing pixels in the input image.pixel_std (list(float)): Std values for normalizing pixels in the input image."""super().__init__()self.image_encoder = image_encoderself.prompt_encoder = prompt_encoderself.mask_decoder = mask_decoderself.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)@propertydef device(self) -> Any:return self.pixel_mean.device@torch.no_grad()def forward(self,batched_input: List[Dict[str, Any]],multimask_output: bool,) -> List[Dict[str, torch.Tensor]]:input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)image_embeddings = self.image_encoder(input_images)outputs = []for image_record, curr_embedding in zip(batched_input, image_embeddings):if "point_coords" in image_record:points = (image_record["point_coords"], image_record["point_labels"])else:points = Nonesparse_embeddings, dense_embeddings = self.prompt_encoder(points=points,boxes=image_record.get("boxes", None),masks=image_record.get("mask_inputs", None),)low_res_masks, iou_predictions = self.mask_decoder(image_embeddings=curr_embedding.unsqueeze(0),image_pe=self.prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=multimask_output,)masks = self.postprocess_masks(low_res_masks,input_size=image_record["image"].shape[-2:],original_size=image_record["original_size"],)masks = masks > self.mask_thresholdoutputs.append({"masks": masks,"iou_predictions": iou_predictions,"low_res_logits": low_res_masks,})return outputsdef postprocess_masks(self,masks: torch.Tensor,input_size: Tuple[int, ...],original_size: Tuple[int, ...],) -> torch.Tensor:masks = F.interpolate(masks,(self.image_encoder.img_size, self.image_encoder.img_size),mode="bilinear",align_corners=False,)masks = masks[..., : input_size[0], : input_size[1]]masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)return masksdef preprocess(self, x: torch.Tensor) -> torch.Tensor:# Normalize colorsx = (x - self.pixel_mean) / self.pixel_std# Padh, w = x.shape[-2:]padh = self.image_encoder.img_size - hpadw = self.image_encoder.img_size - wx = F.pad(x, (0, padw, 0, padh))return x

image_encoder代码实现

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.import torch
import torch.nn as nn
import torch.nn.functional as Ffrom typing import Optional, Tuple, Typefrom .common import LayerNorm2d, MLPBlock# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):def __init__(self,img_size: int = 1024,patch_size: int = 16,in_chans: int = 3,embed_dim: int = 768,depth: int = 12,num_heads: int = 12,mlp_ratio: float = 4.0,out_chans: int = 256,qkv_bias: bool = True,norm_layer: Type[nn.Module] = nn.LayerNorm,act_layer: Type[nn.Module] = nn.GELU,use_abs_pos: bool = True,use_rel_pos: bool = False,rel_pos_zero_init: bool = True,window_size: int = 0,global_attn_indexes: Tuple[int, ...] = (),) -> None:super().__init__()self.img_size = img_sizeself.patch_embed = PatchEmbed(kernel_size=(patch_size, patch_size),stride=(patch_size, patch_size),in_chans=in_chans,embed_dim=embed_dim,)self.pos_embed: Optional[nn.Parameter] = Noneif use_abs_pos:# Initialize absolute positional embedding with pretrain image size.self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))self.blocks = nn.ModuleList()for i in range(depth):block = Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,norm_layer=norm_layer,act_layer=act_layer,use_rel_pos=use_rel_pos,rel_pos_zero_init=rel_pos_zero_init,window_size=window_size if i not in global_attn_indexes else 0,input_size=(img_size // patch_size, img_size // patch_size),)self.blocks.append(block)self.neck = nn.Sequential(nn.Conv2d(embed_dim,out_chans,kernel_size=1,bias=False,),LayerNorm2d(out_chans),nn.Conv2d(out_chans,out_chans,kernel_size=3,padding=1,bias=False,),LayerNorm2d(out_chans),)def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.patch_embed(x)if self.pos_embed is not None:x = x + self.pos_embedfor blk in self.blocks:x = blk(x)x = self.neck(x.permute(0, 3, 1, 2))return xclass Block(nn.Module):def __init__(self,dim: int,num_heads: int,mlp_ratio: float = 4.0,qkv_bias: bool = True,norm_layer: Type[nn.Module] = nn.LayerNorm,act_layer: Type[nn.Module] = nn.GELU,use_rel_pos: bool = False,rel_pos_zero_init: bool = True,window_size: int = 0,input_size: Optional[Tuple[int, int]] = None,) -> None:super().__init__()self.norm1 = norm_layer(dim)self.attn = Attention(dim,num_heads=num_heads,qkv_bias=qkv_bias,use_rel_pos=use_rel_pos,rel_pos_zero_init=rel_pos_zero_init,input_size=input_size if window_size == 0 else (window_size, window_size),)self.norm2 = norm_layer(dim)self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)self.window_size = window_sizedef forward(self, x: torch.Tensor) -> torch.Tensor:shortcut = xx = self.norm1(x)# Window partitionif self.window_size > 0:H, W = x.shape[1], x.shape[2]x, pad_hw = window_partition(x, self.window_size)x = self.attn(x)# Reverse window partitionif self.window_size > 0:x = window_unpartition(x, self.window_size, pad_hw, (H, W))x = shortcut + xx = x + self.mlp(self.norm2(x))return xclass Attention(nn.Module):def __init__(self,dim: int,num_heads: int = 8,qkv_bias: bool = True,use_rel_pos: bool = False,rel_pos_zero_init: bool = True,input_size: Optional[Tuple[int, int]] = None,) -> None:super().__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim**-0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.proj = nn.Linear(dim, dim)self.use_rel_pos = use_rel_posif self.use_rel_pos:assert (input_size is not None), # initialize relative positional embeddingsself.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))def forward(self, x: torch.Tensor) -> torch.Tensor:B, H, W, _ = x.shape# qkv with shape (3, B, nHead, H * W, C)qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)# q, k, v with shape (B * nHead, H * W, C)q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)attn = (q * self.scale) @ k.transpose(-2, -1)if self.use_rel_pos:attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))attn = attn.softmax(dim=-1)x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)x = self.proj(x)return xdef window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:B, H, W, C = x.shapepad_h = (window_size - H % window_size) % window_sizepad_w = (window_size - W % window_size) % window_sizeif pad_h > 0 or pad_w > 0:x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))Hp, Wp = H + pad_h, W + pad_wx = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windows, (Hp, Wp)def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:Hp, Wp = pad_hwH, W = hwB = windows.shape[0] // (Hp * Wp // window_size // window_size)x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)if Hp > H or Wp > W:x = x[:, :H, :W, :].contiguous()return xdef get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:max_rel_dist = int(2 * max(q_size, k_size) - 1)# Interpolate rel pos if needed.if rel_pos.shape[0] != max_rel_dist:# Interpolate rel pos.rel_pos_resized = F.interpolate(rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),size=max_rel_dist,mode="linear",)rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)else:rel_pos_resized = rel_pos# Scale the coords with short length if shapes for q and k are different.q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)return rel_pos_resized[relative_coords.long()]def add_decomposed_rel_pos(attn: torch.Tensor,q: torch.Tensor,rel_pos_h: torch.Tensor,rel_pos_w: torch.Tensor,q_size: Tuple[int, int],k_size: Tuple[int, int],
) -> torch.Tensor:q_h, q_w = q_sizek_h, k_w = k_sizeRh = get_rel_pos(q_h, k_h, rel_pos_h)Rw = get_rel_pos(q_w, k_w, rel_pos_w)B, _, dim = q.shaper_q = q.reshape(B, q_h, q_w, dim)rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)return attnclass PatchEmbed(nn.Module):def __init__(self,kernel_size: Tuple[int, int] = (16, 16),stride: Tuple[int, int] = (16, 16),padding: Tuple[int, int] = (0, 0),in_chans: int = 3,embed_dim: int = 768,) -> None:super().__init__()self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.proj(x)# B C H W -> B H W Cx = x.permute(0, 2, 3, 1)return x

mask_encoder代码实现

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.import torch
from torch import nn
from torch.nn import functional as Ffrom typing import List, Tuple, Typefrom .common import LayerNorm2dclass MaskDecoder(nn.Module):def __init__(self,*,transformer_dim: int,transformer: nn.Module,num_multimask_outputs: int = 3,activation: Type[nn.Module] = nn.GELU,iou_head_depth: int = 3,iou_head_hidden_dim: int = 256,) -> None:super().__init__()self.transformer_dim = transformer_dimself.transformer = transformerself.num_multimask_outputs = num_multimask_outputsself.iou_token = nn.Embedding(1, transformer_dim)self.num_mask_tokens = num_multimask_outputs + 1self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)self.output_upscaling = nn.Sequential(nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),LayerNorm2d(transformer_dim // 4),activation(),nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),activation(),)self.output_hypernetworks_mlps = nn.ModuleList([MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)for i in range(self.num_mask_tokens)])self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)def forward(self,image_embeddings: torch.Tensor,image_pe: torch.Tensor,sparse_prompt_embeddings: torch.Tensor,dense_prompt_embeddings: torch.Tensor,multimask_output: bool,) -> Tuple[torch.Tensor, torch.Tensor]:masks, iou_pred = self.predict_masks(image_embeddings=image_embeddings,image_pe=image_pe,sparse_prompt_embeddings=sparse_prompt_embeddings,dense_prompt_embeddings=dense_prompt_embeddings,)# Select the correct mask or masks for outputif multimask_output:mask_slice = slice(1, None)else:mask_slice = slice(0, 1)masks = masks[:, mask_slice, :, :]iou_pred = iou_pred[:, mask_slice]# Prepare outputreturn masks, iou_preddef predict_masks(self,image_embeddings: torch.Tensor,image_pe: torch.Tensor,sparse_prompt_embeddings: torch.Tensor,dense_prompt_embeddings: torch.Tensor,) -> Tuple[torch.Tensor, torch.Tensor]:"""Predicts masks. See 'forward' for more details."""# Concatenate output tokensoutput_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)# Expand per-image data in batch direction to be per-masksrc = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)src = src + dense_prompt_embeddingspos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)b, c, h, w = src.shape# Run the transformerhs, src = self.transformer(src, pos_src, tokens)iou_token_out = hs[:, 0, :]mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]# Upscale mask embeddings and predict masks using the mask tokenssrc = src.transpose(1, 2).view(b, c, h, w)upscaled_embedding = self.output_upscaling(src)hyper_in_list: List[torch.Tensor] = []for i in range(self.num_mask_tokens):hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))hyper_in = torch.stack(hyper_in_list, dim=1)b, c, h, w = upscaled_embedding.shapemasks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)# Generate mask quality predictionsiou_pred = self.iou_prediction_head(iou_token_out)return masks, iou_pred# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):def __init__(self,input_dim: int,hidden_dim: int,output_dim: int,num_layers: int,sigmoid_output: bool = False,) -> None:super().__init__()self.num_layers = num_layersh = [hidden_dim] * (num_layers - 1)self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))self.sigmoid_output = sigmoid_outputdef forward(self, x):for i, layer in enumerate(self.layers):x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)if self.sigmoid_output:x = F.sigmoid(x)return x

prompt_encoder代码实现

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.import numpy as np
import torch
from torch import nnfrom typing import Any, Optional, Tuple, Typefrom .common import LayerNorm2dclass PromptEncoder(nn.Module):def __init__(self,embed_dim: int,image_embedding_size: Tuple[int, int],input_image_size: Tuple[int, int],mask_in_chans: int,activation: Type[nn.Module] = nn.GELU,) -> None:super().__init__()self.embed_dim = embed_dimself.input_image_size = input_image_sizeself.image_embedding_size = image_embedding_sizeself.pe_layer = PositionEmbeddingRandom(embed_dim // 2)self.num_point_embeddings: int = 4  # pos/neg point + 2 box cornerspoint_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]self.point_embeddings = nn.ModuleList(point_embeddings)self.not_a_point_embed = nn.Embedding(1, embed_dim)self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])self.mask_downscaling = nn.Sequential(nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),LayerNorm2d(mask_in_chans // 4),activation(),nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),LayerNorm2d(mask_in_chans),activation(),nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),)self.no_mask_embed = nn.Embedding(1, embed_dim)def get_dense_pe(self) -> torch.Tensor:return self.pe_layer(self.image_embedding_size).unsqueeze(0)def _embed_points(self,points: torch.Tensor,labels: torch.Tensor,pad: bool,) -> torch.Tensor:"""Embeds point prompts."""points = points + 0.5  # Shift to center of pixelif pad:padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)points = torch.cat([points, padding_point], dim=1)labels = torch.cat([labels, padding_label], dim=1)point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)point_embedding[labels == -1] = 0.0point_embedding[labels == -1] += self.not_a_point_embed.weightpoint_embedding[labels == 0] += self.point_embeddings[0].weightpoint_embedding[labels == 1] += self.point_embeddings[1].weightreturn point_embeddingdef _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:"""Embeds box prompts."""boxes = boxes + 0.5  # Shift to center of pixelcoords = boxes.reshape(-1, 2, 2)corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)corner_embedding[:, 0, :] += self.point_embeddings[2].weightcorner_embedding[:, 1, :] += self.point_embeddings[3].weightreturn corner_embeddingdef _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:"""Embeds mask inputs."""mask_embedding = self.mask_downscaling(masks)return mask_embeddingdef _get_batch_size(self,points: Optional[Tuple[torch.Tensor, torch.Tensor]],boxes: Optional[torch.Tensor],masks: Optional[torch.Tensor],) -> int:"""Gets the batch size of the output given the batch size of the input prompts."""if points is not None:return points[0].shape[0]elif boxes is not None:return boxes.shape[0]elif masks is not None:return masks.shape[0]else:return 1def _get_device(self) -> torch.device:return self.point_embeddings[0].weight.devicedef forward(self,points: Optional[Tuple[torch.Tensor, torch.Tensor]],boxes: Optional[torch.Tensor],masks: Optional[torch.Tensor],) -> Tuple[torch.Tensor, torch.Tensor]:bs = self._get_batch_size(points, boxes, masks)sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())if points is not None:coords, labels = pointspoint_embeddings = self._embed_points(coords, labels, pad=(boxes is None))sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)if boxes is not None:box_embeddings = self._embed_boxes(boxes)sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)if masks is not None:dense_embeddings = self._embed_masks(masks)else:dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])return sparse_embeddings, dense_embeddingsclass PositionEmbeddingRandom(nn.Module):"""Positional encoding using random spatial frequencies."""def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:super().__init__()if scale is None or scale <= 0.0:scale = 1.0self.register_buffer("positional_encoding_gaussian_matrix",scale * torch.randn((2, num_pos_feats)),)def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:"""Positionally encode points that are normalized to [0,1]."""# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shapecoords = 2 * coords - 1coords = coords @ self.positional_encoding_gaussian_matrixcoords = 2 * np.pi * coords# outputs d_1 x ... x d_n x C shapereturn torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)def forward(self, size: Tuple[int, int]) -> torch.Tensor:"""Generate positional encoding for a grid of the specified size."""h, w = sizedevice: Any = self.positional_encoding_gaussian_matrix.devicegrid = torch.ones((h, w), device=device, dtype=torch.float32)y_embed = grid.cumsum(dim=0) - 0.5x_embed = grid.cumsum(dim=1) - 0.5y_embed = y_embed / hx_embed = x_embed / wpe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))return pe.permute(2, 0, 1)  # C x H x Wdef forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:"""Positionally encode points that are not normalized to [0,1]."""coords = coords_input.clone()coords[:, :, 0] = coords[:, :, 0] / image_size[1]coords[:, :, 1] = coords[:, :, 1] / image_size[0]return self._pe_encoding(coords.to(torch.float))  # B x N x C

2. 总结

本周学习了SAM大模型,该模型的设计和训练具有快速性,它可以将 零样本 转移到新的图像分布和任务,通过点、框、mask进行分割。同时也对SAM代码的实现有了一定的理解,下周我会继续努力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/751848.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

27-2 文件上传漏洞 - 前端绕过

环境准备:构建完善的安全渗透测试环境:推荐工具、资源和下载链接_渗透测试靶机下载-CSDN博客 前端绕过思路 - 禁用 JavaScript: 背景: 当前开发行业大多采用前后端分离模式,后端使用多种开发语言如 PHP、Java 等,而前端主要使用 JavaScript(JS)。因此,禁用 JavaScrip…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的火焰与烟雾检测系统详解(深度学习模型+UI界面升级版+训练数据集)

摘要&#xff1a;本研究详细介绍了一种集成了最新YOLOv8算法的火焰与烟雾检测系统&#xff0c;并与YOLOv7、YOLOv6、YOLOv5等早期算法进行性能评估对比。该系统能够在包括图像、视频文件、实时视频流及批量文件中准确识别火焰与烟雾。文章深入探讨了YOLOv8算法的原理&#xff0…

误删电脑C盘要重装系统吗 误删电脑C盘文件怎么恢复 误删c盘系统文件怎么修复 不小心删除C盘的东西恢复

C盘通常是操作系统(如Windows)的默认安装目录。它包含了操作系统的核心文件、驱动程序及系统所需的各种支持文件。这些文件对于计算机的正常运行至关重要。如果我们不小心将C盘的重要文件删除&#xff0c;会导致应用无法打开。本篇文章&#xff0c;我们将学习误删电脑C盘要重装…

面试算法-39-删除链表的倒数第 N 个结点

题目 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], n 2 输出&#xff1a;[1,2,3,5] 解 class Solution {public ListNode removeNthFromEnd(ListNode head, int n) {L…

基于支持向量机SVM的沉降预测,SVM详细原理,Libsvm详解

目录 支持向量机SVM的详细原理 SVM的定义 SVM理论 Libsvm工具箱详解 简介 参数说明 易错及常见问题 完整代码和数据下载链接:基于支持向量机SVM的沉降预测资源-CSDN文库 https://download.csdn.net/download/abc991835105/88947544 SVM应用实例,基于支持向量机SVM的沉降预测…

指挥航空公司架次与延误率占比

打开前端Vue项目kongguan_web&#xff0c;创建前端 src/components/Delay.vue 页面&#xff0c;并添加柱状图与折线图叠加&#xff0c;设置双Y轴。 页面div设计&#xff0c;代码如下&#xff1a; <template><div><div class"home"><div id&qu…

关于volatile与指令重排序的探讨

写在开头 在之前的学习我们了解到&#xff0c;为了充分利用缓存&#xff0c;提高程序的执行速度&#xff0c;编译器在底层执行的时候&#xff0c;会进行指令重排序的优化操作&#xff0c;但这种优化&#xff0c;在有些时候会带来 有序性 的问题。 那何为有序性呢&#xff1f;…

Halcon OCR文字识别

1、OCR文字识别 OCR&#xff08;Optical Character Recognition&#xff0c;光学字符识别&#xff09;工具对图像中的文字进行识别和分析。 FontFile : Universal_0-9_NoRej dev_update_window (off) read_image (bottle, bottle2) get_image_size (bottle, Width, Height) dev…

JavaScript 中实现请求并发控制

文章目录 浏览器并发请求限制数&#xff08;图&#xff09;实现代码三方插件 假设有 30 个待办任务要执行&#xff0c;而我们希望限制同时执行的任务个数&#xff0c;即最多只有 3 个任务能同时执行。当正在执行任务列表 中的任何 1 个任务完成后&#xff0c;程序会自动从 待办…

VMware安装Ubuntu 18.04.2

下载Ubuntu映像 下载地址&#xff1a;http://old-releases.ubuntu.com/releases/18.04/ 下载名称&#xff1a; ubuntu-18.04.2-desktop-amd64.iso 清华镜像站&#xff1a;https://mirrors.tuna.tsinghua.edu.cn/ubuntu-releases/ 阿里云镜像站&#xff1a;https://mirrors.ali…

python 统计中国观鸟记录中心官网已观测的鸟类种类

python 统计中国观鸟记录中心官网已观测的鸟类种类 中国观鸟记录中心网站&#xff1a;https://www.birdreport.cn/ 先下载官网 Excel 文件 文件放置目录如下&#xff1a; home dataset xxx.xlsxxxx.xlsxxxx.xlsx Excelgrep.py &#xff08;进行文件内容提取的程序&#xff…

关于Ubuntu虚拟机突然上不了网的问题

今天刚重新把Ubuntu虚拟机下回来准备大干一场&#xff0c;结果去吃饭回来虚拟机就上不去网了&#xff0c;具体体现为右上角没有网络的图标&#xff0c;下图是有网络的情况&#xff0c;废话不多说&#xff0c;直接给出解决方案&#xff1a;博客在此 我就是运行了这三行代码就成功…

设计模式 — — 单例模式

一、是什么 单例模式只会在全局作用域下创建一次实例对象&#xff0c;让所有需要调用的地方都共享这一单例对象 二、实现 // 单例构造函数 function CreateSingleton (name) {this.name name;this.getName(); };// 获取实例的名字 CreateSingleton.prototype.getName func…

✅技术社区—跨域问题及解决方案

一、什么是跨域、为什么会跨域&#xff1f; 我们把问题分解 谁出现的跨域&#xff1f; 》 浏览器&#xff01; 为何出现&#xff1f; 》 同源策略 什么是同源策略&#xff1f; 根据百度百科 同源策略/SOP&#xff08;Same origin policy&#xff09;是一种约定&#xff0…

Linux 时间系统调用

UNIX及LinuxQ的时间系统是由「新纪元时间」Epoch开始计算起。Epoch是指定为1970年1月1日凌晨零点零分零秒&#xff0c;格林威治时间。目前大部份的UNX系统都是用32位来记录时间&#xff0c;正值表示为1970以后&#xff0c;负值则表示1970年以前。 对于当前时间到Epoch 我们用两…

代码算法训练营day10 | 232.用栈实现队列、225. 用队列实现栈

day10: 232.用栈实现队列225. 用队列实现栈 232.用栈实现队列 题目链接 状态&#xff1a; 文档&#xff1a;programmercarl.com 思路&#xff1a; 用栈实现队列。要先明白两者的区别。 栈&#xff1a;单开门&#xff0c;先进后出&#xff0c;只有一端能进出。 队列&#xff1a;…

继承 ResponseEntityExceptionHandler

目录 作用概述 示例-HttpRequestMethodNotSupportedException 示例-自定义异常处理 总示例 使用了ResponseEntityExceptionHandler后&#xff0c;为什么发生了异常后返回体为空 方法执行顺序 作用概述 这是一个方便的基类&#xff0c;用于希望通过 ExceptionHandler 方法…

Vue项目的搭建

Node.js 下载 Node.js — Download (nodejs.org)https://nodejs.org/en/download/ 安装 测试 winR->cmd执行 node -v配置 在安装目录下创建两个子文件夹node_cache和node_global,我的就是 D:\nodejs\node_cache D:\nodejs\node_global 在node_global文件下再创建一个…

并查集(详解+例题)

1、作用 将两个集合合并 询问两个元素是否在一个集合中 2、基本原理 每个集合用一颗树表示。树根的编号就是整个集合的编号。每个节点存储它的父节点&#xff0c;p[x]表示x的父节点。 3、实现 问题1&#xff1a;如何判断树根&#xff1a;if(p[x]x); 问题2&#xff1a;如何求…

C++ 特殊类及单例模式

文章目录 1. 前言2. 不能被拷贝的类3. 不能被继承的类4. 只能在堆上创建对象的类5. 只能在栈上创建对象的类6. 只能创建一个对象的类&#xff08;单例模式&#xff09; 1. 前言 在实际场景中&#xff0c;我们在编写类的过程中总会遇到一些特殊情况&#xff0c;比如设计一个类不…