sam代码简析

  • Segment Anything:建立了迄今为止最大的分割数据集,在1100万张图像上有超过1亿个掩码,模型的设计和训练是灵活的,其重要的特点是Zero-shot(零样本迁移性)转移到新的图像分布和任务,一个图像分割新的任务、模型和数据集。SAM由三个部分组成:一个强大的图像编码器(Image encoder)计算图像嵌入,一个提示编码器(Prompt encoder)嵌入提示,然后将两个信息源组合在一个轻量级掩码解码器(Mask decoder)中来预测分割掩码。在视觉领域通过Prompt+基础大模型的套路来解决目标分割的问题。

  • 需要下载官方给的权重pth下载链接,权重文件可以在给的readme.md上的链接下载。下载好权重文件之后,我们就开始配置并调用SAM,主要的文件其实就在amg.py上面进行配置运行即可,其他文件大家有兴趣的可以仔细阅读一下了解。

  • 主要我们就需要一个input文件,放入我们需要分割的文件路径,最好是jpg,png格式的,可以看官方支持什么格式,还有一个output文件路径,放入我们结果生成的文件。model-type就是刚才说的权重文件的类型。checkpoint就是权重文件路径,刚才下载的文件,把路径放进去即可。

    • parser.add_argument("--input",type=str,required=False,default=r'.\JPEGImages',help="Path to either a single input image or folder of images.",
      )
      parser.add_argument("--output",type=str,required=False,default=r'.\JPEGImages\result',help=("Path to the directory where masks will be output. Output will be either a folder of PNGs per image or a single json with COCO-style masks."),
      )
      parser.add_argument("--model-type",type=str,required=False,default='vit_h',help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
      )
      parser.add_argument("--checkpoint",type=str,required=False,default=r'.\segment-anything-main\sam_vit_h_4b8939.pth',help="The path to the SAM checkpoint to use for mask generation.",
      )
      
  • SAM 源码提供了3种不同大小的模型。sam_model_registry函数在segment_anything/build_sam.py文件内定义,SAM的3种模型通过字典形式保存。

    • sam_model_registry = {"default": build_sam_vit_h,"vit_h": build_sam_vit_h,"vit_l": build_sam_vit_l,"vit_b": build_sam_vit_b,
      }# 选择合适的模型以及加载对应权重
      sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
      sam.to(device=device)
      
    • sam_model_registry中的 3 种模型结构是一致的,部分参数不同导致模型的大小有别

    • def build_sam_vit_h(checkpoint=None):return _build_sam(encoder_embed_dim=1280,encoder_depth=32,encoder_num_heads=16,encoder_global_attn_indexes=[7, 15, 23, 31],checkpoint=checkpoint,)
      def build_sam_vit_l(checkpoint=None):return _build_sam(encoder_embed_dim=1024,encoder_depth=24,encoder_num_heads=16,encoder_global_attn_indexes=[5, 11, 17, 23],checkpoint=checkpoint,)
      def build_sam_vit_b(checkpoint=None):return _build_sam(encoder_embed_dim=768,encoder_depth=12,encoder_num_heads=12,encoder_global_attn_indexes=[2, 5, 8, 11],checkpoint=checkpoint,)
      
  • 最后是_build_sam方法,完成了sam模型的初始化以及权重的加载,这里可以注意到sam模型由三个神经网络模块组成:ImageEncoderViT(Image encoder)、PromptEncoder和MaskDecoder

    • def _build_sam(encoder_embed_dim,encoder_depth,encoder_num_heads,encoder_global_attn_indexes,checkpoint=None,
      ):prompt_embed_dim = 256image_size = 1024vit_patch_size = 16image_embedding_size = image_size // vit_patch_sizesam = Sam(image_encoder=ImageEncoderViT(depth=encoder_depth,embed_dim=encoder_embed_dim,img_size=image_size,mlp_ratio=4,norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),num_heads=encoder_num_heads,patch_size=vit_patch_size,qkv_bias=True,use_rel_pos=True,global_attn_indexes=encoder_global_attn_indexes,window_size=14,out_chans=prompt_embed_dim,),prompt_encoder=PromptEncoder(embed_dim=prompt_embed_dim,image_embedding_size=(image_embedding_size, image_embedding_size),input_image_size=(image_size, image_size),mask_in_chans=16,),mask_decoder=MaskDecoder(num_multimask_outputs=3,transformer=TwoWayTransformer(depth=2,embedding_dim=prompt_embed_dim,mlp_dim=2048,num_heads=8,),transformer_dim=prompt_embed_dim,iou_head_depth=3,iou_head_hidden_dim=256,),pixel_mean=[123.675, 116.28, 103.53],pixel_std=[58.395, 57.12, 57.375],)sam.eval()if checkpoint is not None:with open(checkpoint, "rb") as f:state_dict = torch.load(f)sam.load_state_dict(state_dict)return sam
      
  • SamPredictor类,sam模型被封装在SamPredictor类的对象中,方便使用。SamPredictor类在segment_anything/predictor.py文件

    • predictor = SamPredictor(sam)
      predictor.set_image(image) # image_encoder操作在set_image时就已经执行了,而不是在predic时
      
  • 首先确认输入是否是RGB或BGR三通道图像,将BGR图像统一为RGB,而后并对图像尺寸和channel顺序作出调整满足神经网络的输入要求。

    • def set_image(self, image: np.ndarray, image_format: str = "RGB",) -> None:# 图像不是['RGB', 'BGR']格式则报错assert image_format in ["RGB","BGR",], f"image_format must be in ['RGB', 'BGR'], is {image_format}."# H,W,Cif image_format != self.model.image_format:image = image[..., ::-1]            # H,W,C中 C通道的逆序RGB-->BGR# Transform the image to the form expected by the model 改变图像尺寸input_image = self.transform.apply_image(image)# torch 浅拷贝 转tensorinput_image_torch = torch.as_tensor(input_image, device=self.device)# permute H,W,C-->C,H,W# contiguous 连续内存# [None, :, :, :] C,H,W -->1,C,H,Winput_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]self.set_torch_image(input_image_torch, image.shape[:2])
      
  • set_torch_image:用padding填补缩放后的图片,在 H 和 W 满足神经网络需要的标准尺寸,而后通过image_encoder模型获得图像特征数据并保存在self.features中,同时self.is_image_set设为true。注意image_encoder过程不是在predict_torch时与Prompt encoder过程和Mask decoder过程一同执行的,而是在set_image时就已经执行了。

    • def set_torch_image(self,transformed_image: torch.Tensor,original_image_size: Tuple[int, ...],
      ) -> None:# 满足输入是四个维度且为B,C,H,Wassert (len(transformed_image.shape) == 4and transformed_image.shape[1] == 3and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."self.reset_image()# 原始图像的尺寸self.original_size = original_image_size# torch图像的尺寸self.input_size = tuple(transformed_image.shape[-2:])# torch图像进行paddinginput_image = self.model.preprocess(transformed_image)# image_encoder网络模块对图像进行编码self.features = self.model.image_encoder(input_image)# 图像设置flagself.is_image_set = True
      
  • predict对输入到模型中进行预测的数据(标记点 apply_coords 和标记框 apply_boxes )进行一个预处理,并接受和处理模型返回的预测结果

    • def predict(self,# 标记点的坐标point_coords: Optional[np.ndarray] = None,# 标记点的标签point_labels: Optional[np.ndarray] = None,# 标记框的坐标box: Optional[np.ndarray] = None,# 输入的maskmask_input: Optional[np.ndarray] = None,# 输出多个mask供选择multimask_output: bool = True,# ture 返回掩码logits, false返回阈值处理的二进制掩码。return_logits: bool = False,
      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:# 假设没有设置图像,报错if not self.is_image_set:raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")# Transform input prompts # 输入提示转换为torchcoords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, Noneif point_coords is not None:# 标记点坐标对应的标记点标签不能为空assert (point_labels is not None), "point_labels must be supplied if point_coords is supplied."# 图像改变了原始尺寸,所以对应的点位置也会发生改变point_coords = self.transform.apply_coords(point_coords, self.original_size)# 标记点坐标和标记点标签 np-->tensorcoords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)# 增加维度# coords_torch:N,2-->1,N,2# labels_torch: N-->1,Ncoords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]if box is not None:# 图像改变了原始尺寸,所以对应的框坐标位置也会发生改变box = self.transform.apply_boxes(box, self.original_size)# 标记框坐标 np-->tensorbox_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)# 增加维度 N,4-->1,N,4box_torch = box_torch[None, :]if mask_input is not None:# mask np-->tensormask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)# 增加维度 1,H,W-->B,1,H,Wmask_input_torch = mask_input_torch[None, :, :, :]# 输入数据预处理完毕,可以输入到网络中 masks, iou_predictions, low_res_masks = self.predict_torch(coords_torch,labels_torch,box_torch,mask_input_torch,multimask_output,return_logits=return_logits,)# 因为batchsize为1,压缩维度# maskmasks = masks[0].detach().cpu().numpy()# scoreiou_predictions = iou_predictions[0].detach().cpu().numpy()low_res_masks = low_res_masks[0].detach().cpu().numpy()return masks, iou_predictions, low_res_masks
      def postprocess_masks(self,masks: torch.Tensor,input_size: Tuple[int, ...],original_size: Tuple[int, ...],
      ) -> torch.Tensor:# mask上采样到与输入到模型中的图片尺寸一致masks = F.interpolate(masks,(self.image_encoder.img_size, self.image_encoder.img_size),mode="bilinear",align_corners=False,)masks = masks[..., : input_size[0], : input_size[1]]# mask resize 到与未做处理的原始图片尺寸一致masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)return masks
      
  • predict_torch:输入数据经过预处理后输入到模型中预测结果。Prompt encoder过程和Mask decoder过程是在predict_torch时执行的

    • def predict_torch(self,point_coords: Optional[torch.Tensor],point_labels: Optional[torch.Tensor],boxes: Optional[torch.Tensor] = None,mask_input: Optional[torch.Tensor] = None,multimask_output: bool = True,return_logits: bool = False,
      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:# 假设没有设置图像,报错if not self.is_image_set:raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")# 绑定标记点和标记点标签if point_coords is not None:points = (point_coords, point_labels)else:points = None# ----- Prompt encoder -----sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points,boxes=boxes,masks=mask_input,)# ----- Prompt encoder -----# ----- Mask decoder -----low_res_masks, iou_predictions = self.model.mask_decoder(image_embeddings=self.features,image_pe=self.model.prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=multimask_output,)#  ----- Mask decoder -----# 上采样mask掩膜到原始图片尺寸# Upscale the masks to the original image resolutionmasks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)if not return_logits:masks = masks > self.model.mask_thresholdreturn masks, iou_predictions, low_res_masks
      
  • get_image_embedding:获得图像image_encoder的特征。

    • def get_image_embedding(self) -> torch.Tensor:if not self.is_image_set:raise RuntimeError("An image must be set with .set_image(...) to generate an embedding.")assert self.features is not None, "Features must exist if an image has been set."return self.features
      
  • ResizeLongestSide是专门用来处理图片、标记点和标记框的工具类。ResizeLongestSide类在segment_anything/utils/transforms.py文件

    • apply_image:原图尺寸根据标准尺寸计算调整(get_preprocess_shape)得新尺寸

    • def apply_image(self, image: np.ndarray) -> np.ndarray:target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)# to_pil_image将numpy装变为PIL.Image,而后resizereturn np.array(resize(to_pil_image(image), target_size))
      
    • 不直接使用resize的目的是为了不破坏原图片中各个物体的比例关系。通过计算获得与标准尺寸对应的缩放比例并缩放图片,后续通过padding补零操作(虚线部分),将所有图片的尺寸都变成标准尺寸

    • apply_coords:图像改变了原始尺寸,对应的标记点坐标位置也要改变[get_preprocess_shape]。

    • def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:old_h, old_w = original_size# 图像改变了原始尺寸,所以对应的标记点坐标位置也会发生改变new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)# 深拷贝coordscoords = deepcopy(coords).astype(float)# 改变对应标记点坐标coords[..., 0] = coords[..., 0] * (new_w / old_w)coords[..., 1] = coords[..., 1] * (new_h / old_h)return coords
      
    • apply_boxes:图像改变了原始尺寸,对应的标记框坐标位置也要改变[get_preprocess_shape]。

    • def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:# 图像改变了原始尺寸,所以对应的框坐标位置也会发生改变# reshape: N,4-->N,2,2boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)# reshape: N,2,2-->N,4return boxes.reshape(-1, 4)
      
    • get_preprocess_shape

    • def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:# H和W的长边(大值)作为基准,计算比例,缩放H W的大小scale = long_side_length * 1.0 / max(oldh, oldw)newh, neww = oldh * scale, oldw * scale# 四舍五入neww = int(neww + 0.5)newh = int(newh + 0.5)return (newh, neww)
      
  • 图像编码器

    • SAM模型关于ViT网络的配置,以sam_vit_b为例,分析ViT网络的结构。

    • def build_sam_vit_b(checkpoint=None):return _build_sam(# 图像编码channelencoder_embed_dim=768,# 主体编码器的个数encoder_depth=12,# attention中head的个数encoder_num_heads=12,# 需要将相对位置嵌入添加到注意力图的编码器( Encoder Block)encoder_global_attn_indexes=[2, 5, 8, 11],# 权重checkpoint=checkpoint,)
      
    • sam模型中image_encoder模块初始化

    • image_encoder=ImageEncoderViT(# 主体编码器的个数depth=encoder_depth,# 图像编码channelembed_dim=encoder_embed_dim,# 输入图像的标准尺寸img_size=image_size,# mlp中channel缩放的比例mlp_ratio=4,# 归一化层norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),# attention中head的个数num_heads=encoder_num_heads,# patch的大小patch_size=vit_patch_size,# qkv全连接层的偏置qkv_bias=True,# 是否需要将相对位置嵌入添加到注意力图use_rel_pos=True,# 需要将相对位置嵌入添加到注意力图的编码器序号(Encoder Block)global_attn_indexes=encoder_global_attn_indexes,# attention中的窗口大小window_size=14,# 输出的channelout_chans=prompt_embed_dim,
      ),
      
    • ViT网络(ImageEncoderViT类)结构参数配置。

    • def __init__(self,img_size: int = 1024,       # 输入图像的标准尺寸patch_size: int = 16,       # patch的大小in_chans: int = 3,          # 输入图像channelembed_dim: int = 768,       # 图像编码channeldepth: int = 12,            # 主体编码器的个数num_heads: int = 12,        # attention中head的个数mlp_ratio: float = 4.0,     # mlp中channel缩放的比例out_chans: int = 256,       # 输出特征的channelqkv_bias: bool = True,      # qkv全连接层的偏置flagnorm_layer: Type[nn.Module] = nn.LayerNorm,     # 归一化层act_layer: Type[nn.Module] = nn.GELU,           # 激活层use_abs_pos: bool = True,               # 是否使用绝对位置嵌入use_rel_pos: bool = False,              # 是否需要将相对位置嵌入添加到注意力图rel_pos_zero_init: bool = True,         # 源码暂时没有用到window_size: int = 0,                   # attention中的窗口大小global_attn_indexes: Tuple[int, ...] = (),      # 需要将相对位置嵌入添加到注意力图的编码器序号(Encoder Block)
      ) -> None:super().__init__()self.img_size = img_size# -----patch embedding-----self.patch_embed = PatchEmbed(kernel_size=(patch_size, patch_size),stride=(patch_size, patch_size),in_chans=in_chans,embed_dim=embed_dim,)# -----patch embedding-----# -----positional embedding-----self.pos_embed: Optional[nn.Parameter] = Noneif use_abs_pos:# Initialize absolute positional embedding with pretrain image size.# 使用预训练图像大小初始化绝对位置嵌入。self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))# -----positional embedding-----# -----Transformer Encoder-----self.blocks = nn.ModuleList()for i in range(depth):block = Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,norm_layer=norm_layer,act_layer=act_layer,use_rel_pos=use_rel_pos,rel_pos_zero_init=rel_pos_zero_init,window_size=window_size if i not in global_attn_indexes else 0,input_size=(img_size // patch_size, img_size // patch_size),)self.blocks.append(block)# -----Transformer Encoder-----# -----Neck-----self.neck = nn.Sequential(nn.Conv2d(embed_dim,out_chans,kernel_size=1,bias=False,),LayerNorm2d(out_chans),nn.Conv2d(out_chans,out_chans,kernel_size=3,padding=1,bias=False,),LayerNorm2d(out_chans),)# -----Neck----- 
      
  • ViT网络(ImageEncoderViT类)在特征提取中的几个基本步骤:

    • patch embedding:将图片切分成图片序列块,再经过维度映射后展平成一维向量

    • positional embedding:嵌入位置编码(用于保留位置信息)

    • Transformer Encoder:主体编码器

    • Neck:过渡层

    • def forward(self, x: torch.Tensor) -> torch.Tensor:# patch embedding过程x = self.patch_embed(x)# positional embedding过程if self.pos_embed is not None:x = x + self.pos_embed# Transformer Encoder过程for blk in self.blocks:x = blk(x)# Neck过程 B H W C -> B C H Wx = self.neck(x.permute(0, 3, 1, 2))return x
      
    • PatchEmbed类: 源码其实就是卷积核大小16x16(巧妙切分成固定大小16x16的patch),卷积核通道3×768的卷积操作。图像大小决定了patch的数量

    • 在这里插入图片描述

    • class PatchEmbed(nn.Module):def __init__(self,kernel_size: Tuple[int, int] = (16, 16),    # 卷积核大小stride: Tuple[int, int] = (16, 16),         # 步长padding: Tuple[int, int] = (0, 0),          # paddingin_chans: int = 3,                          # 输入channelembed_dim: int = 768,                       # 输出channel) -> None:super().__init__()self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.proj(x)# B C H W -> B H W Cx = x.permute(0, 2, 3, 1)return x
      
    • 经过patch embedding后输出tokens需要加入位置编码,位置编码可以理解为一张map,map的行数与输入序列个数相同,每一行代表一个向量,向量的维度和输入序列tokens的维度相同,位置编码的操作是sum,所以维度依旧保持不变。图像尺寸是1024的,因此patch数量是64(=1024/16)

    • # 在ImageEncoderViT的__init__定义
      if use_abs_pos:# Initialize absolute positional embedding with pretrain image size.# 使用预训练图像大小初始化绝对位置嵌入。self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
      # 在ImageEncoderViT的forward添加位置编码
      if self.pos_embed is not None:x = x + self.pos_embed
      
    • Transformer Encoder多个重复堆叠Encoder Block组成。

    • # 在ImageEncoderViT的__init__定义
      # -----Transformer Encoder-----
      self.blocks = nn.ModuleList()
      for i in range(depth):block = Block(dim=embed_dim,                  # 输入channelnum_heads=num_heads,            # attention中head的个数mlp_ratio=mlp_ratio,            # mlp中channel缩放的比例qkv_bias=qkv_bias,              # qkv全连接层的偏置flagnorm_layer=norm_layer,          # 归一化层act_layer=act_layer,            # 激活层use_rel_pos=use_rel_pos,        # 是否需要将相对位置嵌入添加到注意力图rel_pos_zero_init=rel_pos_zero_init,        # 源码暂时没有用到window_size=window_size if i not in global_attn_indexes else 0,      # attention中的窗口大小input_size=(img_size // patch_size, img_size // patch_size),         # 输入特征的尺寸)self.blocks.append(block)
      # -----Transformer Encoder-----
      
    • Encoder Block从低到高由LayerNorm 、Multi-Head Attention和MLP构成。

    • class Block(nn.Module):def __init__(self,dim: int,                           # 输入channelnum_heads: int,                     # attention中head的个数mlp_ratio: float = 4.0,             # mlp中channel缩放的比例qkv_bias: bool = True,              # qkv全连接层的偏置flagnorm_layer: Type[nn.Module] = nn.LayerNorm,     # 归一化层act_layer: Type[nn.Module] = nn.GELU,           # 激活层use_rel_pos: bool = False,                      # 是否需要将相对位置嵌入添加到注意力图rel_pos_zero_init: bool = True,                 # 源码暂时没有用到window_size: int = 0,                           # attention中的窗口大小input_size: Optional[Tuple[int, int]] = None,   # 输入特征的尺寸) -> None:super().__init__()self.norm1 = norm_layer(dim)         # 激活层self.attn = Attention(               # Multi-Head Attentiondim,num_heads=num_heads,qkv_bias=qkv_bias,use_rel_pos=use_rel_pos,rel_pos_zero_init=rel_pos_zero_init,input_size=input_size if window_size == 0 else (window_size, window_size),)self.norm2 = norm_layer(dim)self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)     # MLPself.window_size = window_size              #def forward(self, x: torch.Tensor) -> torch.Tensor:shortcut = xx = self.norm1(x)# Window partition 对X进行paddingif self.window_size > 0:H, W = x.shape[1], x.shape[2]x, pad_hw = window_partition(x, self.window_size)x = self.attn(x)# Reverse window partition 去除X的padding部分if self.window_size > 0:x = window_unpartition(x, self.window_size, pad_hw, (H, W))x = shortcut + xx = x + self.mlp(self.norm2(x))return x
      
    • Partition操作

    • def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:B, H, W, C = x.shapepad_h = (window_size - H % window_size) % window_sizepad_w = (window_size - W % window_size) % window_sizeif pad_h > 0 or pad_w > 0:x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))Hp, Wp = H + pad_h, W + pad_w# B,Hp/S,S,Wp/S,S,Cx = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)# B,Hp/S,Wp/S,S,S,C-->BHpWp/SS,S,S,Cwindows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windows, (Hp, Wp)
      
    • Unpartition操作

    • def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
      ) -> torch.Tensor:Hp, Wp = pad_hwH, W = hwB = windows.shape[0] // (Hp * Wp // window_size // window_size)# BHpWp/SS,S,S,C-->B,Hp/S,Wp/S,S,S,Cx = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)# B,Hp/S,Wp/S,S,S,C-->B,Hp,Wp,Cx = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)if Hp > H or Wp > W:x = x[:, :H, :W, :].contiguous()# B,H,W,Creturn x
      
    • 在这里插入图片描述

    • window_partition调整了原始特征尺寸为(H×W–>S×S),目的是了在后续的Multi-Head Attention过程中将相对位置嵌入添加到注意力图(attn),并不是所有Block都需要在注意力图中嵌入相对位置信息;window_unpartition则是恢复特征的原始尺寸(S×S–>H×W)。

    • Multi-Head Attention:先从Attention讲解,再到Multi-Head Attention,最后再讲注意力特征嵌入了相对位置特征的Multi-Head Attention。

    • class Attention(nn.Module):"""Multi-head Attention block with relative position embeddings."""def __init__(self,dim: int,               # 输入channelnum_heads: int = 8,     # head数目qkv_bias: bool = True,use_rel_pos: bool = False,rel_pos_zero_init: bool = True,input_size: Optional[Tuple[int, int]] = None, # 嵌入相对位置注意力特征的尺寸) -> None:super().__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim**-0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.proj = nn.Linear(dim, dim)self.use_rel_pos = use_rel_posif self.use_rel_pos:        # 使用相对位置编码assert (input_size is not None), "Input size must be provided if using relative positional encoding."# initialize relative positional embeddings# 2S-1,Eposself.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))def forward(self, x: torch.Tensor) -> torch.Tensor:B, H, W, _ = x.shape# qkv with shape (3, B, nHead, H * W, C)qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)# q, k, v with shape (B * nHead, H * W, C)q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)# attn with shape (B * nHead, H * W,  H * W)attn = (q * self.scale) @ k.transpose(-2, -1)if self.use_rel_pos:# 假设use_rel_pos是true (H, W)是 S×Sattn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))attn = attn.softmax(dim=-1)x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)x = self.proj(x)return x
      
    • 对于输入到Multi-head attention模块的特征 F(N×E) ,通过attention模块的nn.Linear进一步提取特征获得输出特征 v(value) 。为了考虑 N 个特征之间存在的亲疏和位置关系对于 v 的影响,所以需要一个额外 attn(attention) 或者理解为权重 w(weight) 对 v 进行加权操作,这引出了计算 w 所需的 q(query) 与 k(key) ,因此可以看到任何V都考虑了N 个token特征之间相互的影响。Multi-head attention的流程如下图所示(不考虑batchsize):

      • 首先将每个token的qkv特征维度embed_dim均拆分到每个head的上
      • 每个head分别通过q和k计算得到权重w,权重w和v得到输出output,合并所有head的output得到最终的output
    • get_rel_pos用于计算h和w的相对位置的嵌入特征

    • def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:max_rel_dist = int(2 * max(q_size, k_size) - 1)# Interpolate rel pos if needed.if rel_pos.shape[0] != max_rel_dist:# Interpolate rel pos.  相关位置进行插值rel_pos_resized = F.interpolate(# 1,N,Ep --> 1,Ep,N --> 1,Ep,2S-1rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),size=max_rel_dist,mode="linear",)# Ep,2S-1 --> 2S-1,Eprel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)else:rel_pos_resized = rel_pos# Scale the coords with short length if shapes for q and k are different.# 如果q和k长度值不同,则用短边长度缩放坐标。q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)# S,Srelative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)# tensor索引是tensor时,即tensor1[tensor2]# 假设tensor2某个具体位置值是2,则tensor1[2]位置的tensor1切片替换tensor2中的2# tensor1->shape 5,5,3 tensor2->shape 2,2,3 tensor1切片->shape 5,3 tensor1[tensor2]->shape 2,2,3,5,3# tensor1->shape 5,5 tensor2->shape 3,2,3 tensor1切片->shape 5 tensor1[tensor2]->shape 3,2,3,5# 2S-1,Ep-->S,S,Epreturn rel_pos_resized[relative_coords.long()]
      
    • 在这里插入图片描述

    • add_decomposed_rel_pos为atten注意力特征添加相对位置的嵌入特征。

    • def add_decomposed_rel_pos(attn: torch.Tensor,q: torch.Tensor,rel_pos_h: torch.Tensor,rel_pos_w: torch.Tensor,q_size: Tuple[int, int],k_size: Tuple[int, int],
      ) -> torch.Tensor:# S,Sq_h, q_w = q_sizek_h, k_w = k_size# rel_pos_h -> 2S-1×EposRh = get_rel_pos(q_h, k_h, rel_pos_h)Rw = get_rel_pos(q_w, k_w, rel_pos_w)B, _, dim = q.shaper_q = q.reshape(B, q_h, q_w, dim)# torch.einsum用于简洁的表示乘积、点积、转置等方法# B,q_h, q_w, k_hrel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)# B,q_h, q_w, k_wrel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)attn = (# B,q_h, q_w, k_h, k_wattn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)return attn
      
    • MLP

    • class MLPBlock(nn.Module):def __init__(self,embedding_dim: int,mlp_dim: int,act: Type[nn.Module] = nn.GELU,) -> None:super().__init__()self.lin1 = nn.Linear(embedding_dim, mlp_dim)self.lin2 = nn.Linear(mlp_dim, embedding_dim)self.act = act()def forward(self, x: torch.Tensor) -> torch.Tensor:return self.lin2(self.act(self.lin1(x)))
      
    • Neck

    • # 在ImageEncoderViT的__init__定义
      # -----Neck-----
      self.neck = nn.Sequential(nn.Conv2d(embed_dim,out_chans,kernel_size=1,bias=False,),LayerNorm2d(out_chans),nn.Conv2d(out_chans,out_chans,kernel_size=3,padding=1,bias=False,),LayerNorm2d(out_chans),
      )
      # -----Neck-----
      class LayerNorm2d(nn.Module):def __init__(self, num_channels: int, eps: float = 1e-6) -> None:super().__init__()self.weight = nn.Parameter(torch.ones(num_channels))self.bias = nn.Parameter(torch.zeros(num_channels))self.eps = epsdef forward(self, x: torch.Tensor) -> torch.Tensor:u = x.mean(1, keepdim=True)       # dim=1维度求均值并保留通道s = (x - u).pow(2).mean(1, keepdim=True)x = (x - u) / torch.sqrt(s + self.eps)x = self.weight[:, None, None] * x + self.bias[:, None, None]return x
      
  • sam模型中prompt_encoder模块初始化

    • prompt_encoder=PromptEncoder(# 提示编码channel(和image_encoder输出channel一致,后续会融合)embed_dim=prompt_embed_dim,# mask的编码尺寸(和image_encoder输出尺寸一致)image_embedding_size=(image_embedding_size, image_embedding_size),# 输入图像的标准尺寸input_image_size=(image_size, image_size),# 对输入掩码编码的通道数mask_in_chans=16,
      ),
      
  • ProEnco网络结构与执行流程,ProEnco网络(PromptEncoder类)结构参数配置。

    • def __init__(self,embed_dim: int,                         # 提示编码channelimage_embedding_size: Tuple[int, int],  # mask的编码尺寸input_image_size: Tuple[int, int],      # 输入图像的标准尺寸mask_in_chans: int,                     # 输入掩码编码的通道数activation: Type[nn.Module] = nn.GELU,  # 激活层
      ) -> None:super().__init__()self.embed_dim = embed_dim              # 提示编码channelself.input_image_size = input_image_size          # 输入图像的标准尺寸self.image_embedding_size = image_embedding_size  # mask的编码尺寸self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)self.num_point_embeddings: int = 4                # 4个点:正负点,框的俩个点point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]   # 4个点的嵌入向量# nn.ModuleList它是一个存储不同module,# 并自动将每个module的parameters添加到网络之中的容器self.point_embeddings = nn.ModuleList(point_embeddings)                     # 4个点的嵌入向量添加到网络self.not_a_point_embed = nn.Embedding(1, embed_dim)                         # 不是点的嵌入向量self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])           # mask的输入尺寸self.mask_downscaling = nn.Sequential( # 输入mask时 4倍下采样nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),LayerNorm2d(mask_in_chans // 4),activation(),nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),LayerNorm2d(mask_in_chans),activation(),nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),)self.no_mask_embed = nn.Embedding(1, embed_dim) # 没有mask输入时 嵌入向量
      
    • SAM模型中ProEnco网络结构如下图所示:

    • 在这里插入图片描述

    • ProEnco网络(PromptEncoder类)在特征提取中的几个基本步骤:

      • Embed_Points:标记点编码(标记点由点转变为向量)
      • Embed_Boxes:标记框编码(标记框由点转变为向量)
      • Embed_Masks:mask编码(mask下采样保证与Image encoder输出一致)
    • def forward(self,points: Optional[Tuple[torch.Tensor, torch.Tensor]],boxes: Optional[torch.Tensor],masks: Optional[torch.Tensor],
      ) -> Tuple[torch.Tensor, torch.Tensor]:# 获得 batchsize  当前predict为1bs = self._get_batch_size(points, boxes, masks)# -----sparse_embeddings----sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())if points is not None:coords, labels = pointspoint_embeddings = self._embed_points(coords, labels, pad=(boxes is None))sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)if boxes is not None:box_embeddings = self._embed_boxes(boxes)sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)# -----sparse_embeddings----# -----dense_embeddings----if masks is not None:dense_embeddings = self._embed_masks(masks)else:dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])# -----dense_embeddings----return sparse_embeddings, dense_embeddings
      def _get_batch_size(self,points: Optional[Tuple[torch.Tensor, torch.Tensor]],boxes: Optional[torch.Tensor],masks: Optional[torch.Tensor],
      ) -> int:if points is not None:return points[0].shape[0]elif boxes is not None:return boxes.shape[0]elif masks is not None:return masks.shape[0]else:return 1
      def _get_device(self) -> torch.device:return self.point_embeddings[0].weight.device
      
    • Embed_Points:标记点预处理,将channel由2变成embed_dim(MatMul:forward_with_coords),然后再加上位置编码权重。

    • def _embed_points(self,points: torch.Tensor,labels: torch.Tensor,pad: bool,
      ) -> torch.Tensor:# 移到像素中心points = points + 0.5# points和boxes联合则不需要padif pad:padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)  # B,1,2padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)     # B,1points = torch.cat([points, padding_point], dim=1) # B,N+1,2labels = torch.cat([labels, padding_label], dim=1) # B,N+1point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)  # B,N+1,2f# labels为-1是非标记点,设为非标记点权重point_embedding[labels == -1] = 0.0point_embedding[labels == -1] += self.not_a_point_embed.weight# labels为0是背景点,加上背景点权重point_embedding[labels == 0] += self.point_embeddings[0].weight# labels为1的目标点,加上目标点权重point_embedding[labels == 1] += self.point_embeddings[1].weightreturn point_embedding
      
    • pad的作用相当于box占位符号,box和points可以联合标定完成图像分割的,但是此时的box只能有一个,不能有多个。

    • Embed_Boxes:标记框预处理,将channel由4到2再变成embed_dim(MatMul:forward_with_coords),然后再加上位置编码权重。

    • def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:# 移到像素中心boxes = boxes + 0.5coords = boxes.reshape(-1, 2, 2)corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)    ## 目标框起始点的和末位点分别加上权重corner_embedding[:, 0, :] += self.point_embeddings[2].weightcorner_embedding[:, 1, :] += self.point_embeddings[3].weightreturn corner_embedding
      
    • boxes reshape 后 batchsize 是会增加的,B,N,4–>BN,2,2;因此这里可以得出box和points联合标定时,box为什么只能是一个,而不能是多个

    • Embed_Masks:mask的输出尺寸是Image encoder模块输出的图像编码尺寸的4倍,因此为了保持一致,需要4倍下采样。

    • def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:# mask下采样4倍mask_embedding = self.mask_downscaling(masks)return mask_embedding
      # 在PromptEncoder的__init__定义
      self.mask_downscaling = nn.Sequential(                                                      # 输入mask时 4倍下采样nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),LayerNorm2d(mask_in_chans // 4),activation(),nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),LayerNorm2d(mask_in_chans),activation(),nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),)
      
    • 假设没有mask输入,则将no_mask_embed编码扩展到与图像编码一致的尺寸代替mask

    • # 在PromptEncoder的forward定义
      dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
      )
      
    • PositionEmbeddingRandom:用于将标记点和标记框的坐标进行提示编码预处理。

    • def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:super().__init__()if scale is None or scale <= 0.0:scale = 1.0# 理解为模型的常数 [2,f]self.register_buffer("positional_encoding_gaussian_matrix",scale * torch.randn((2, num_pos_feats)),)
      
    • 将标记点的坐标具体的位置转变为[0~1]之间的比例位置

    • def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]
      ) -> torch.Tensor:coords = coords_input.clone()# 将坐标位置缩放到[0~1]之间coords[:, :, 0] = coords[:, :, 0] / image_size[1]coords[:, :, 1] = coords[:, :, 1] / image_size[0]# B,N+1,2-->B,N+1,2freturn self._pe_encoding(coords.to(torch.float))
      
    • 标记点位置编码,因为sin和cos,编码的值归一化至 [-1,1]

    • def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shapecoords = 2 * coords - 1# B,N+1,2 × 2,f --> B,N+1,fcoords = coords @ self.positional_encoding_gaussian_matrixcoords = 2 * np.pi * coords# outputs d_1 x ... x d_n x C shape# B,N+1,2freturn torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
      
  • MaskDecoder网络简述,sam模型中Mask_decoder模块初始化

    • mask_decoder=MaskDecoder(# 消除掩码歧义预测的掩码数num_multimask_outputs=3,# 用于预测mask的网咯transformertransformer=TwoWayTransformer(# 层数depth=2,# 输入channelembedding_dim=prompt_embed_dim,# MLP内部channelmlp_dim=2048,# attention的head数num_heads=8,),# transformer的channeltransformer_dim=prompt_embed_dim,# MLP的深度,MLP用于预测掩模质量的iou_head_depth=3,# MLP隐藏channeliou_head_hidden_dim=256,
      ),
      
    • MaskDeco网络(MaskDecoder类)结构参数配置。

    • def __init__(self,*,# transformer的channeltransformer_dim: int,# 用于预测mask的网咯transformertransformer: nn.Module,# 消除掩码歧义预测的掩码数num_multimask_outputs: int = 3,# 激活层activation: Type[nn.Module] = nn.GELU,# MLP深度,MLP用于预测掩模质量的iou_head_depth: int = 3,# MLP隐藏channeliou_head_hidden_dim: int = 256,
      ) -> None:super().__init__()self.transformer_dim = transformer_dim  # transformer的channel#----- transformer -----self.transformer = transformer       # 用于预测mask的网咯transformer# ----- transformer -----self.num_multimask_outputs = num_multimask_outputs  # 消除掩码歧义预测的掩码数self.iou_token = nn.Embedding(1, transformer_dim)   # iou的takenself.num_mask_tokens = num_multimask_outputs + 1    # mask数self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)      # mask的tokens数#----- upscaled -----# 4倍上采样self.output_upscaling = nn.Sequential(nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),     #转置卷积 上采样2倍LayerNorm2d(transformer_dim // 4),activation(),nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),activation(),)# ----- upscaled -----# ----- MLP -----# 对应mask数的MLPself.output_hypernetworks_mlps = nn.ModuleList([MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)for i in range(self.num_mask_tokens)])# ----- MLP -----# ----- MLP -----# 对应iou的MLPself.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)# ----- MLP -----
    • SAM模型中MaskDeco网络结构如下图所示:

    • 在这里插入图片描述

    • MaskDeco网络(MaskDecoder类)在特征提取中的几个基本步骤:

      • transformer:融合特征(提示信息特征与图像特征)获得粗略掩膜src
      • upscaled:对粗略掩膜src上采样
      • mask_MLP:全连接层组(计算加权权重,使粗掩膜src转变为掩膜mask)
      • iou_MLP:全连接层组(计算掩膜mask的Score)
    • def forward(self,# image encoder 图像特征image_embeddings: torch.Tensor,# 位置编码image_pe: torch.Tensor,# 标记点和标记框的嵌入编码sparse_prompt_embeddings: torch.Tensor,# 输入mask的嵌入编码dense_prompt_embeddings: torch.Tensor,# 是否输出多个maskmultimask_output: bool,
      ) -> Tuple[torch.Tensor, torch.Tensor]:masks, iou_pred = self.predict_masks(image_embeddings=image_embeddings,image_pe=image_pe,sparse_prompt_embeddings=sparse_prompt_embeddings,dense_prompt_embeddings=dense_prompt_embeddings,)# Select the correct mask or masks for outputif multimask_output:mask_slice = slice(1, None)else:mask_slice = slice(0, 1)masks = masks[:, mask_slice, :, :]iou_pred = iou_pred[:, mask_slice]return masks, iou_pred
      def predict_masks(self,image_embeddings: torch.Tensor,image_pe: torch.Tensor,sparse_prompt_embeddings: torch.Tensor,dense_prompt_embeddings: torch.Tensor,
      ) -> Tuple[torch.Tensor, torch.Tensor]:# Concatenate output tokens# 1,E and 4,E --> 5,Eoutput_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)# 5,E --> B,5,Eoutput_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)# B,5,E and B,N,E -->B,5+N,E       N是点的个数(标记点和标记框的点)tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)# 扩展image_embeddings的B维度,因为boxes标记分割时,n个box时batchsize=batchsize*n# Expand per-image data in batch direction to be per-mask# B,C,H,Wsrc = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)# B,C,H,W + 1,C,H,W ---> B,C,H,Wsrc = src + dense_prompt_embeddings# 1,C,H,W---> B,C,H,Wpos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)b, c, h, w = src.shape# ----- transformer -----# Run the transformer# B,N,Chs, src = self.transformer(src, pos_src, tokens)# ----- transformer -----iou_token_out = hs[:, 0, :]mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :]# Upscale mask embeddings and predict masks using the mask tokens# B,N,C-->B,C,H,W src = src.transpose(1, 2).view(b, c, h, w)# ----- upscaled -----# 4倍上采样upscaled_embedding = self.output_upscaling(src)# ----- upscaled -----hyper_in_list: List[torch.Tensor] = []# ----- mlp -----for i in range(self.num_mask_tokens):# mask_tokens_out[:, i, :]: B,1,C# output_hypernetworks_mlps: B,1,chyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))# B,n,chyper_in = torch.stack(hyper_in_list, dim=1)# ----- mlp -----b, c, h, w = upscaled_embedding.shape# B,n,c × B,c,N-->B,n,h,wmasks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)# ----- mlp -----# Generate mask quality predictions# iou_token_out: B,1,niou_pred = self.iou_prediction_head(iou_token_out)# ----- mlp -----# masks: B,n,h,w# iou_pred: B,1,nreturn masks, iou_pred
      
    • MaskDeco由多个重复堆叠TwoWayAttention Block和1个Multi-Head Attention组成。

    • class TwoWayTransformer(nn.Module):def __init__(self,# 层数depth: int,# 输入channelembedding_dim: int,# attention的head数num_heads: int,# MLP内部channelmlp_dim: int,activation: Type[nn.Module] = nn.ReLU,attention_downsample_rate: int = 2,) -> None:super().__init__()self.depth = depth      # 层数self.embedding_dim = embedding_dim          # 输入channelself.num_heads = num_heads                  # attention的head数self.mlp_dim = mlp_dim                      # MLP内部隐藏channelself.layers = nn.ModuleList()for i in range(depth):self.layers.append(TwoWayAttentionBlock(embedding_dim=embedding_dim,    # 输入channelnum_heads=num_heads,            # attention的head数mlp_dim=mlp_dim,                # MLP中间channelactivation=activation,          # 激活层attention_downsample_rate=attention_downsample_rate,      # 下采样skip_first_layer_pe=(i == 0),))self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)self.norm_final_attn = nn.LayerNorm(embedding_dim)def forward(self,image_embedding: Tensor,image_pe: Tensor,point_embedding: Tensor,) -> Tuple[Tensor, Tensor]:# BxCxHxW -> BxHWxC == B x N_image_tokens x Cbs, c, h, w = image_embedding.shape# 图像编码(image_encoder的输出)# BxHWxC=>B,N,Cimage_embedding = image_embedding.flatten(2).permute(0, 2, 1)# 图像位置编码# BxHWxC=>B,N,Cimage_pe = image_pe.flatten(2).permute(0, 2, 1)# 标记点编码# B,N,Cqueries = point_embeddingkeys = image_embedding# -----TwoWayAttention-----for layer in self.layers:queries, keys = layer(queries=queries,keys=keys,query_pe=point_embedding,key_pe=image_pe,)# -----TwoWayAttention-----q = queries + point_embeddingk = keys + image_pe# -----Attention-----attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)# -----Attention-----queries = queries + attn_outqueries = self.norm_final_attn(queries)return queries, keys
      
    • TwoWayAttention Block由LayerNorm 、Multi-Head Attention和MLP构成。

    • class TwoWayAttentionBlock(nn.Module):def __init__(self,embedding_dim: int,         # 输入channelnum_heads: int,             # attention的head数mlp_dim: int = 2048,        # MLP中间channelactivation: Type[nn.Module] = nn.ReLU,      # 激活层attention_downsample_rate: int = 2,         # 下采样skip_first_layer_pe: bool = False,) -> None:super().__init__()self.self_attn = Attention(embedding_dim, num_heads)self.norm1 = nn.LayerNorm(embedding_dim)self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)self.norm2 = nn.LayerNorm(embedding_dim)self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)self.norm3 = nn.LayerNorm(embedding_dim)self.norm4 = nn.LayerNorm(embedding_dim)self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)self.skip_first_layer_pe = skip_first_layer_pedef forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:# queries:标记点编码相关(原始标记点编码经过一系列特征提取)# keys:原始图像编码相关(原始图像编码经过一系列特征提取)# query_pe:原始标记点编码# key_pe:原始图像位置编码# 第一轮本身queries==query_pe没比较再"残差"if self.skip_first_layer_pe:queries = self.self_attn(q=queries, k=queries, v=queries)else:q = queries + query_peattn_out = self.self_attn(q=q, k=q, v=queries)queries = queries + attn_outqueries = self.norm1(queries)# Cross attention block, tokens attending to image embeddingq = queries + query_pek = keys + key_peattn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)queries = queries + attn_outqueries = self.norm2(queries)# MLP blockmlp_out = self.mlp(queries)queries = queries + mlp_outqueries = self.norm3(queries)# Cross attention block, image embedding attending to tokensq = queries + query_pek = keys + key_peattn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)keys = keys + attn_outkeys = self.norm4(keys)return queries, keys
      
    • TwoWayAttentionBlock是Prompt encoder的提示信息特征与Image encoder的图像特征的融合过程,而Prompt encoder对提示信息没有过多处理,因此TwoWayAttentionBlock的目的是边对提示信息特征做进一步处理边与图像特征融合

    • MaskDeco的Attention与ViT的Attention有些细微的不同:MaskDeco的Attention是3个FC层分别接受3个输入获得q、k和v,而ViT的Attention是1个FC层接受1个输入后将结果均拆分获得q、k和v

    • class Attention(nn.Module):def __init__(self,embedding_dim: int,         # 输入channelnum_heads: int,             # attention的head数downsample_rate: int = 1,   # 下采样) -> None:super().__init__()self.embedding_dim = embedding_dimself.internal_dim = embedding_dim // downsample_rateself.num_heads = num_headsassert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."# qkv获取self.q_proj = nn.Linear(embedding_dim, self.internal_dim)self.k_proj = nn.Linear(embedding_dim, self.internal_dim)self.v_proj = nn.Linear(embedding_dim, self.internal_dim)self.out_proj = nn.Linear(self.internal_dim, embedding_dim)def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:b, n, c = x.shapex = x.reshape(b, n, num_heads, c // num_heads)return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_headdef _recombine_heads(self, x: Tensor) -> Tensor:b, n_heads, n_tokens, c_per_head = x.shapex = x.transpose(1, 2)return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x Cdef forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:# Input projectionsq = self.q_proj(q)k = self.k_proj(k)v = self.v_proj(v)# Separate into heads# B,N_heads,N_tokens,C_per_headq = self._separate_heads(q, self.num_heads)k = self._separate_heads(k, self.num_heads)v = self._separate_heads(v, self.num_heads)# Attention_, _, _, c_per_head = q.shapeattn = q @ k.permute(0, 1, 3, 2)  # B,N_heads,N_tokens,C_per_head# Scaleattn = attn / math.sqrt(c_per_head)attn = torch.softmax(attn, dim=-1)# Get outputout = attn @ v# # B,N_tokens,Cout = self._recombine_heads(out)out = self.out_proj(out)return out
      
    • MaskDeco的Attention和ViT的Attention的结构对比示意图:

    • 在这里插入图片描述

    • transformer_MLP

    • class MLPBlock(nn.Module):def __init__(self,embedding_dim: int,mlp_dim: int,act: Type[nn.Module] = nn.GELU,) -> None:super().__init__()self.lin1 = nn.Linear(embedding_dim, mlp_dim)self.lin2 = nn.Linear(mlp_dim, embedding_dim)self.act = act()def forward(self, x: torch.Tensor) -> torch.Tensor:return self.lin2(self.act(self.lin1(x)))
      
    • upscaled

    • # 在MaskDecoder的__init__定义
      self.output_upscaling = nn.Sequential(nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),     #转置卷积 上采样2倍LayerNorm2d(transformer_dim // 4),activation(),nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),activation(),
      )
      # 在MaskDecoder的predict_masks添加位置编码
      upscaled_embedding = self.output_upscaling(src)
      
    • mask_MLP

    • # 在MaskDecoder的__init__定义
      self.output_hypernetworks_mlps = nn.ModuleList([MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)for i in range(self.num_mask_tokens)]
      )
      # 在MaskDecoder的predict_masks添加位置编码for i in range(self.num_mask_tokens):# mask_tokens_out[:, i, :]: B,1,C# output_hypernetworks_mlps: B,1,chyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))# B,n,chyper_in = torch.stack(hyper_in_list, dim=1)b, c, h, w = upscaled_embedding.shape# B,n,c × B,c,N-->B,n,h,wmasks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
      
    • iou_MLP

    • # 在MaskDecoder的__init__定义
      self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
      )
      # 在MaskDecoder的predict_masks添加位置编码
      iou_pred = self.iou_prediction_head(iou_token_out)
      
    • MaskDeco_MLP

    • class MLP(nn.Module):def __init__(self,input_dim: int,         # 输入channelhidden_dim: int,        # 中间channeloutput_dim: int,        # 输出channelnum_layers: int,        # fc的层数sigmoid_output: bool = False,) -> None:super().__init__()self.num_layers = num_layersh = [hidden_dim] * (num_layers - 1)self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))self.sigmoid_output = sigmoid_outputdef forward(self, x):for i, layer in enumerate(self.layers):x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)if self.sigmoid_output:x = F.sigmoid(x)return x
      
    • iou_MLP

    • # 在MaskDecoder的__init__定义
      self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
      )
      # 在MaskDecoder的predict_masks添加位置编码
      iou_pred = self.iou_prediction_head(iou_token_out)
      
    • MaskDeco_MLP

    • class MLP(nn.Module):def __init__(self,input_dim: int,         # 输入channelhidden_dim: int,        # 中间channeloutput_dim: int,        # 输出channelnum_layers: int,        # fc的层数sigmoid_output: bool = False,) -> None:super().__init__()self.num_layers = num_layersh = [hidden_dim] * (num_layers - 1)self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))self.sigmoid_output = sigmoid_outputdef forward(self, x):for i, layer in enumerate(self.layers):x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)if self.sigmoid_output:x = F.sigmoid(x)return x
      

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/839211.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

记录centos中操作(查找、结束、批量)进程以及crontab定时写法的知识

环境&#xff1a;vps&#xff0c;centos7&#xff0c;python3。 近期写了个python程序&#xff0c;用青龙面板在centos上运行。程序中有while无限循环&#xff0c;但是我在青龙中设置了定时任务&#xff08;每隔半小时运行一次&#xff09;&#xff0c;于是造成了进程中有多个…

Java进阶16 单元测试XML注解

Java进阶16 一、单元测试 单元测试就是针对最小的功能单元编写测试代码&#xff0c;Java程序最小的功能单元是方法&#xff0c;因此&#xff0c;单原测试就是针对Java方法的测试&#xff0c;进而检查方法的正确性。简单理解就是测试代码的工具。 1、Junit 1.1 Junit引入 目…

全面了解CC攻击和防范策略

前言 “ CC攻击的原理就是攻击者控制某些主机不停地发大量数据包给对方服务器造成服务器资源耗尽&#xff0c;一直到宕机崩溃。” 什么是CC攻击&#xff1f; CC攻击前身是一个名为Fatboy的攻击程序&#xff0c;而之所以后来人们会称之为CC&#xff0c;也叫HTTP-FLOOD&#xff…

程序语言基础知识

文章目录 1.程序设计语言2. 程序设计语言的特点和分类3. 编译程序&#xff08;编译器&#xff09;的工作原理4. 程序语言的数据成分4.1 数据成分4.2 运算成分4.3 控制成分4.4 传输成分 1.程序设计语言 低级语言&#xff1a;机器语言和汇编语言。 机器语言&#xff1a;二进制代…

Java面向对象-常用类 (包装类)

常用类 – 包装类 基本数据类型的包装类 理解&#xff1a;包装类是8种基本数据类型对应的类 出现原因&#xff1a;Java是一种纯面向对象语言&#xff0c;但是java中有8种基本数据类型&#xff0c;破坏了java为纯面向对象的特征。为了承诺在java中一切皆对象&#xff0c;java…

c/c++ 判断质数(素数)

目录 一.常规方法 二.进阶方法 三.代码示例&#xff08;运用进阶方法&#xff09; 质数是整数且仅能被自身和1整除 一.常规方法 所以我们根据质数的这个定义便可用以下思路判断&#xff1a;设需要检测的数为x。y为除1和自己的除数 逐步检测x是否可被y整除&#xff0c;如x…

MySQL之架构设计与历史(一)

架构设计与历史 概述 和其他数据库系统相比&#xff0c;MySQL有点与众不同&#xff0c;它的架构可以在多种不同场景中应用并发挥好的作用&#xff0c;但同时也会带来一点选择上的困难。MySQL并不完美&#xff0c;却足够灵活&#xff0c;能够适应高要求的环境&#xff0c;例如…

Android 逆向学习【1】——版本/体系结构/代码学习

#Android 历史版本 参考链接&#xff1a;一篇文章让你了解Android各个版本的历程 - 知乎 (zhihu.com) 三个部分&#xff1a;api等级、版本号、代号&#xff08;这三个东西都是指的同一个系统&#xff09; API等级&#xff1a;在APP开发的时候写在清单列表里面的 版本号&…

Vitis HLS 学习笔记--控制驱动TLP - Dataflow视图

目录 1. 简介 2. 功能特性 2.1 Dataflow Viewer 的功能 2.2 Dataflow 和 Pipeline 的区别 3. 具体演示 4. 总结 1. 简介 Dataflow视图&#xff0c;即数据流查看器。 DATAFLOW优化属于一种动态优化过程&#xff0c;其完整性依赖于与RTL协同仿真的完成。因此&#xff0c;…

力扣第206题-反转链表

反转链表的效果示意图 要改变链表结构时&#xff0c;通常加入一个创建的临时头结点会更容易操作 时间复杂度&#xff1a;遍历2遍&#xff0c;2n 空间复杂度&#xff1a;额外创建一个栈&#xff0c;n (空间创建一个数组长度最大为5000&#xff0c;你说这个数组是栈也可以&…

【C++】详解多态

目录 初识多态 多态的条件 接口继承和实现继承 override 和 final 多态原理 继承与虚函数表 析构函数与多态 抽象类 本篇内容关联知识的链接 【C】详解C的继承-CSDN博客 【C】详解C的模板-CSDN博客 【C】C的内存管理-CSDN博客 初识多态 父类被不同子类继承后&#…

报名开启!2024 开源之夏丨Serverless Devs 课题已上线!

Serverless 是近年来云计算领域热门话题&#xff0c;凭借极致弹性、按量付费、降本提效等众多优势受到很多人的追捧&#xff0c;各云厂商也在不断地布局 Serverless 领域。 Serverless Devs 是一个由阿里巴巴发起的 Serverless 领域的开源项目&#xff0c;其目的是要和开发者们…

leetcode以及牛客网单链表相关的题、移除链表元素、链表的中间节点、合并两个有序链表、反转链表、链表分割、倒数第k个节点等的介绍

文章目录 前言一、移除链表元素二、链表的中间节点三、合并两个有序链表四、反转链表五、链表分割六、倒数第k个节点总结 前言 leetcode以及牛客网单链表相关的题、移除链表元素、链表的中间节点、合并两个有序链表、反转链表、链表分割、倒数第k个节点等的介绍 一、移除链表元…

扫盲:如何提升医学图像分割性能-to do list

导读&#xff1a;本文主要讨论了如何改进图像分割项目中的分割性能&#xff0c;包括一般性和具体性的问题解决方案&#xff0c;以及如何通过调整模型参数、改善数据集质量、优化模型架构、调整超参数、增加训练时长、改善图像分辨率和后处理技术等方法来提高分割效果。 图像分…

拼多多暂时超越阿里成为电商第一

关注卢松松&#xff0c;会经常给你分享一些我的经验和观点。 拼多多的财报又炸裂了&#xff1a; 拼多多发布了第一季度财报&#xff0c;营收868亿&#xff0c;增长了131%&#xff0c;净利润279亿&#xff0c;增长了246%&#xff0c;营销服务收入424亿&#xff0c;也就是商家的…

小林coding笔记

MySQL执行流程 MySQL 的架构共分为两层&#xff1a;Server 层和存储引擎层。Server 层负责建立连接、分析和执行 SQL。存储引擎层负责数据的存储和提取。 Mysql执行 启动Mysql net start mysql登陆 mysql -u root -p输入密码

SwiftUI中的动画.animation和withAnimation

动画是通过改变视图的状态来给视图添加平滑视图变化的能力。SwiftUI中有两种类型的动画:隐式动画和显式动画。 不管是哪种动画&#xff0c;我们都需要一个被State包装的状态属性值&#xff0c;通过这个值的改变来促使与之相关的UI刷新&#xff0c;继而执行动画。 隐式动画.ani…

正点原子[第二期]Linux之ARM(MX6U)裸机篇学习笔记-19讲 串口实验UART

前言&#xff1a; 本文是根据哔哩哔哩网站上“正点原子[第二期]Linux之ARM&#xff08;MX6U&#xff09;裸机篇”视频的学习笔记&#xff0c;在这里会记录下正点原子 I.MX6ULL 开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了正点原子教学视频和链接中的内容。…

Vivado IP核的快速入门 官方手册和例程

在IP Catalog中选择要使用的IP核&#xff0c;可以查看支持的器件与资料。 在设计源sources页面中选中配置完成的IP核点击右键选择 Open IP Example Design&#xff0c;等待工程加载完成即可&#xff0c;可以点击Run Simulation进行功能仿真进行IP核的学习。 参考&#xff1…

Mac Pro中的开源虚拟机UTM安装ubuntu(Applce M1,M2芯片)(1)

MacPro安装UTM 1 UTM 下载UTM虚拟机链接: https://mac.getutm.app/ 建议官网下载&#xff1a; 下载 Ubuntu Arm 64版 下载 Ubuntu Arm 64版链接: https://cn.ubuntu.com/download/server/arm 2 安装UTM 2.1 安装在mac上 2.2 点Open 2.3 建虚拟机### 2.4 点出虚拟机 2.5 O…