-
Segment Anything:建立了迄今为止最大的分割数据集,在1100万张图像上有超过1亿个掩码,模型的设计和训练是灵活的,其重要的特点是Zero-shot(零样本迁移性)转移到新的图像分布和任务,一个图像分割新的任务、模型和数据集。SAM由三个部分组成:一个强大的图像编码器(Image encoder)计算图像嵌入,一个提示编码器(Prompt encoder)嵌入提示,然后将两个信息源组合在一个轻量级掩码解码器(Mask decoder)中来预测分割掩码。在视觉领域通过Prompt+基础大模型的套路来解决目标分割的问题。
-
需要下载官方给的权重pth下载链接,权重文件可以在给的readme.md上的链接下载。下载好权重文件之后,我们就开始配置并调用SAM,主要的文件其实就在amg.py上面进行配置运行即可,其他文件大家有兴趣的可以仔细阅读一下了解。
-
主要我们就需要一个input文件,放入我们需要分割的文件路径,最好是jpg,png格式的,可以看官方支持什么格式,还有一个output文件路径,放入我们结果生成的文件。model-type就是刚才说的权重文件的类型。checkpoint就是权重文件路径,刚才下载的文件,把路径放进去即可。
-
parser.add_argument("--input",type=str,required=False,default=r'.\JPEGImages',help="Path to either a single input image or folder of images.", ) parser.add_argument("--output",type=str,required=False,default=r'.\JPEGImages\result',help=("Path to the directory where masks will be output. Output will be either a folder of PNGs per image or a single json with COCO-style masks."), ) parser.add_argument("--model-type",type=str,required=False,default='vit_h',help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']", ) parser.add_argument("--checkpoint",type=str,required=False,default=r'.\segment-anything-main\sam_vit_h_4b8939.pth',help="The path to the SAM checkpoint to use for mask generation.", )
-
-
SAM 源码提供了3种不同大小的模型。sam_model_registry函数在segment_anything/build_sam.py文件内定义,SAM的3种模型通过字典形式保存。
-
-
sam_model_registry = {"default": build_sam_vit_h,"vit_h": build_sam_vit_h,"vit_l": build_sam_vit_l,"vit_b": build_sam_vit_b, }# 选择合适的模型以及加载对应权重 sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(device=device)
-
sam_model_registry中的 3 种模型结构是一致的,部分参数不同导致模型的大小有别。
-
def build_sam_vit_h(checkpoint=None):return _build_sam(encoder_embed_dim=1280,encoder_depth=32,encoder_num_heads=16,encoder_global_attn_indexes=[7, 15, 23, 31],checkpoint=checkpoint,) def build_sam_vit_l(checkpoint=None):return _build_sam(encoder_embed_dim=1024,encoder_depth=24,encoder_num_heads=16,encoder_global_attn_indexes=[5, 11, 17, 23],checkpoint=checkpoint,) def build_sam_vit_b(checkpoint=None):return _build_sam(encoder_embed_dim=768,encoder_depth=12,encoder_num_heads=12,encoder_global_attn_indexes=[2, 5, 8, 11],checkpoint=checkpoint,)
-
-
最后是_build_sam方法,完成了sam模型的初始化以及权重的加载,这里可以注意到sam模型由三个神经网络模块组成:ImageEncoderViT(Image encoder)、PromptEncoder和MaskDecoder。
-
def _build_sam(encoder_embed_dim,encoder_depth,encoder_num_heads,encoder_global_attn_indexes,checkpoint=None, ):prompt_embed_dim = 256image_size = 1024vit_patch_size = 16image_embedding_size = image_size // vit_patch_sizesam = Sam(image_encoder=ImageEncoderViT(depth=encoder_depth,embed_dim=encoder_embed_dim,img_size=image_size,mlp_ratio=4,norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),num_heads=encoder_num_heads,patch_size=vit_patch_size,qkv_bias=True,use_rel_pos=True,global_attn_indexes=encoder_global_attn_indexes,window_size=14,out_chans=prompt_embed_dim,),prompt_encoder=PromptEncoder(embed_dim=prompt_embed_dim,image_embedding_size=(image_embedding_size, image_embedding_size),input_image_size=(image_size, image_size),mask_in_chans=16,),mask_decoder=MaskDecoder(num_multimask_outputs=3,transformer=TwoWayTransformer(depth=2,embedding_dim=prompt_embed_dim,mlp_dim=2048,num_heads=8,),transformer_dim=prompt_embed_dim,iou_head_depth=3,iou_head_hidden_dim=256,),pixel_mean=[123.675, 116.28, 103.53],pixel_std=[58.395, 57.12, 57.375],)sam.eval()if checkpoint is not None:with open(checkpoint, "rb") as f:state_dict = torch.load(f)sam.load_state_dict(state_dict)return sam
-
-
SamPredictor类,sam模型被封装在SamPredictor类的对象中,方便使用。SamPredictor类在segment_anything/predictor.py文件。
-
predictor = SamPredictor(sam) predictor.set_image(image) # image_encoder操作在set_image时就已经执行了,而不是在predic时
-
-
首先确认输入是否是RGB或BGR三通道图像,将BGR图像统一为RGB,而后并对图像尺寸和channel顺序作出调整满足神经网络的输入要求。
-
def set_image(self, image: np.ndarray, image_format: str = "RGB",) -> None:# 图像不是['RGB', 'BGR']格式则报错assert image_format in ["RGB","BGR",], f"image_format must be in ['RGB', 'BGR'], is {image_format}."# H,W,Cif image_format != self.model.image_format:image = image[..., ::-1] # H,W,C中 C通道的逆序RGB-->BGR# Transform the image to the form expected by the model 改变图像尺寸input_image = self.transform.apply_image(image)# torch 浅拷贝 转tensorinput_image_torch = torch.as_tensor(input_image, device=self.device)# permute H,W,C-->C,H,W# contiguous 连续内存# [None, :, :, :] C,H,W -->1,C,H,Winput_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]self.set_torch_image(input_image_torch, image.shape[:2])
-
-
set_torch_image:用padding填补缩放后的图片,在 H 和 W 满足神经网络需要的标准尺寸,而后通过image_encoder模型获得图像特征数据并保存在self.features中,同时self.is_image_set设为true。注意image_encoder过程不是在predict_torch时与Prompt encoder过程和Mask decoder过程一同执行的,而是在set_image时就已经执行了。
-
def set_torch_image(self,transformed_image: torch.Tensor,original_image_size: Tuple[int, ...], ) -> None:# 满足输入是四个维度且为B,C,H,Wassert (len(transformed_image.shape) == 4and transformed_image.shape[1] == 3and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."self.reset_image()# 原始图像的尺寸self.original_size = original_image_size# torch图像的尺寸self.input_size = tuple(transformed_image.shape[-2:])# torch图像进行paddinginput_image = self.model.preprocess(transformed_image)# image_encoder网络模块对图像进行编码self.features = self.model.image_encoder(input_image)# 图像设置flagself.is_image_set = True
-
-
predict对输入到模型中进行预测的数据(标记点 apply_coords 和标记框 apply_boxes )进行一个预处理,并接受和处理模型返回的预测结果。
-
def predict(self,# 标记点的坐标point_coords: Optional[np.ndarray] = None,# 标记点的标签point_labels: Optional[np.ndarray] = None,# 标记框的坐标box: Optional[np.ndarray] = None,# 输入的maskmask_input: Optional[np.ndarray] = None,# 输出多个mask供选择multimask_output: bool = True,# ture 返回掩码logits, false返回阈值处理的二进制掩码。return_logits: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:# 假设没有设置图像,报错if not self.is_image_set:raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")# Transform input prompts # 输入提示转换为torchcoords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, Noneif point_coords is not None:# 标记点坐标对应的标记点标签不能为空assert (point_labels is not None), "point_labels must be supplied if point_coords is supplied."# 图像改变了原始尺寸,所以对应的点位置也会发生改变point_coords = self.transform.apply_coords(point_coords, self.original_size)# 标记点坐标和标记点标签 np-->tensorcoords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)# 增加维度# coords_torch:N,2-->1,N,2# labels_torch: N-->1,Ncoords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]if box is not None:# 图像改变了原始尺寸,所以对应的框坐标位置也会发生改变box = self.transform.apply_boxes(box, self.original_size)# 标记框坐标 np-->tensorbox_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)# 增加维度 N,4-->1,N,4box_torch = box_torch[None, :]if mask_input is not None:# mask np-->tensormask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)# 增加维度 1,H,W-->B,1,H,Wmask_input_torch = mask_input_torch[None, :, :, :]# 输入数据预处理完毕,可以输入到网络中 masks, iou_predictions, low_res_masks = self.predict_torch(coords_torch,labels_torch,box_torch,mask_input_torch,multimask_output,return_logits=return_logits,)# 因为batchsize为1,压缩维度# maskmasks = masks[0].detach().cpu().numpy()# scoreiou_predictions = iou_predictions[0].detach().cpu().numpy()low_res_masks = low_res_masks[0].detach().cpu().numpy()return masks, iou_predictions, low_res_masks def postprocess_masks(self,masks: torch.Tensor,input_size: Tuple[int, ...],original_size: Tuple[int, ...], ) -> torch.Tensor:# mask上采样到与输入到模型中的图片尺寸一致masks = F.interpolate(masks,(self.image_encoder.img_size, self.image_encoder.img_size),mode="bilinear",align_corners=False,)masks = masks[..., : input_size[0], : input_size[1]]# mask resize 到与未做处理的原始图片尺寸一致masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)return masks
-
-
predict_torch:输入数据经过预处理后输入到模型中预测结果。Prompt encoder过程和Mask decoder过程是在predict_torch时执行的。
-
def predict_torch(self,point_coords: Optional[torch.Tensor],point_labels: Optional[torch.Tensor],boxes: Optional[torch.Tensor] = None,mask_input: Optional[torch.Tensor] = None,multimask_output: bool = True,return_logits: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:# 假设没有设置图像,报错if not self.is_image_set:raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")# 绑定标记点和标记点标签if point_coords is not None:points = (point_coords, point_labels)else:points = None# ----- Prompt encoder -----sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points,boxes=boxes,masks=mask_input,)# ----- Prompt encoder -----# ----- Mask decoder -----low_res_masks, iou_predictions = self.model.mask_decoder(image_embeddings=self.features,image_pe=self.model.prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=multimask_output,)# ----- Mask decoder -----# 上采样mask掩膜到原始图片尺寸# Upscale the masks to the original image resolutionmasks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)if not return_logits:masks = masks > self.model.mask_thresholdreturn masks, iou_predictions, low_res_masks
-
-
get_image_embedding:获得图像image_encoder的特征。
-
def get_image_embedding(self) -> torch.Tensor:if not self.is_image_set:raise RuntimeError("An image must be set with .set_image(...) to generate an embedding.")assert self.features is not None, "Features must exist if an image has been set."return self.features
-
-
ResizeLongestSide是专门用来处理图片、标记点和标记框的工具类。ResizeLongestSide类在segment_anything/utils/transforms.py文件。
-
apply_image:原图尺寸根据标准尺寸计算调整(get_preprocess_shape)得新尺寸。
-
def apply_image(self, image: np.ndarray) -> np.ndarray:target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)# to_pil_image将numpy装变为PIL.Image,而后resizereturn np.array(resize(to_pil_image(image), target_size))
-
不直接使用resize的目的是为了不破坏原图片中各个物体的比例关系。通过计算获得与标准尺寸对应的缩放比例并缩放图片,后续通过padding补零操作(虚线部分),将所有图片的尺寸都变成标准尺寸。
-
apply_coords:图像改变了原始尺寸,对应的标记点坐标位置也要改变[get_preprocess_shape]。
-
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:old_h, old_w = original_size# 图像改变了原始尺寸,所以对应的标记点坐标位置也会发生改变new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)# 深拷贝coordscoords = deepcopy(coords).astype(float)# 改变对应标记点坐标coords[..., 0] = coords[..., 0] * (new_w / old_w)coords[..., 1] = coords[..., 1] * (new_h / old_h)return coords
-
apply_boxes:图像改变了原始尺寸,对应的标记框坐标位置也要改变[get_preprocess_shape]。
-
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:# 图像改变了原始尺寸,所以对应的框坐标位置也会发生改变# reshape: N,4-->N,2,2boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)# reshape: N,2,2-->N,4return boxes.reshape(-1, 4)
-
get_preprocess_shape
-
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:# H和W的长边(大值)作为基准,计算比例,缩放H W的大小scale = long_side_length * 1.0 / max(oldh, oldw)newh, neww = oldh * scale, oldw * scale# 四舍五入neww = int(neww + 0.5)newh = int(newh + 0.5)return (newh, neww)
-
-
图像编码器
-
SAM模型关于ViT网络的配置,以sam_vit_b为例,分析ViT网络的结构。
-
def build_sam_vit_b(checkpoint=None):return _build_sam(# 图像编码channelencoder_embed_dim=768,# 主体编码器的个数encoder_depth=12,# attention中head的个数encoder_num_heads=12,# 需要将相对位置嵌入添加到注意力图的编码器( Encoder Block)encoder_global_attn_indexes=[2, 5, 8, 11],# 权重checkpoint=checkpoint,)
-
sam模型中image_encoder模块初始化
-
image_encoder=ImageEncoderViT(# 主体编码器的个数depth=encoder_depth,# 图像编码channelembed_dim=encoder_embed_dim,# 输入图像的标准尺寸img_size=image_size,# mlp中channel缩放的比例mlp_ratio=4,# 归一化层norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),# attention中head的个数num_heads=encoder_num_heads,# patch的大小patch_size=vit_patch_size,# qkv全连接层的偏置qkv_bias=True,# 是否需要将相对位置嵌入添加到注意力图use_rel_pos=True,# 需要将相对位置嵌入添加到注意力图的编码器序号(Encoder Block)global_attn_indexes=encoder_global_attn_indexes,# attention中的窗口大小window_size=14,# 输出的channelout_chans=prompt_embed_dim, ),
-
ViT网络(ImageEncoderViT类)结构参数配置。
-
def __init__(self,img_size: int = 1024, # 输入图像的标准尺寸patch_size: int = 16, # patch的大小in_chans: int = 3, # 输入图像channelembed_dim: int = 768, # 图像编码channeldepth: int = 12, # 主体编码器的个数num_heads: int = 12, # attention中head的个数mlp_ratio: float = 4.0, # mlp中channel缩放的比例out_chans: int = 256, # 输出特征的channelqkv_bias: bool = True, # qkv全连接层的偏置flagnorm_layer: Type[nn.Module] = nn.LayerNorm, # 归一化层act_layer: Type[nn.Module] = nn.GELU, # 激活层use_abs_pos: bool = True, # 是否使用绝对位置嵌入use_rel_pos: bool = False, # 是否需要将相对位置嵌入添加到注意力图rel_pos_zero_init: bool = True, # 源码暂时没有用到window_size: int = 0, # attention中的窗口大小global_attn_indexes: Tuple[int, ...] = (), # 需要将相对位置嵌入添加到注意力图的编码器序号(Encoder Block) ) -> None:super().__init__()self.img_size = img_size# -----patch embedding-----self.patch_embed = PatchEmbed(kernel_size=(patch_size, patch_size),stride=(patch_size, patch_size),in_chans=in_chans,embed_dim=embed_dim,)# -----patch embedding-----# -----positional embedding-----self.pos_embed: Optional[nn.Parameter] = Noneif use_abs_pos:# Initialize absolute positional embedding with pretrain image size.# 使用预训练图像大小初始化绝对位置嵌入。self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))# -----positional embedding-----# -----Transformer Encoder-----self.blocks = nn.ModuleList()for i in range(depth):block = Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,norm_layer=norm_layer,act_layer=act_layer,use_rel_pos=use_rel_pos,rel_pos_zero_init=rel_pos_zero_init,window_size=window_size if i not in global_attn_indexes else 0,input_size=(img_size // patch_size, img_size // patch_size),)self.blocks.append(block)# -----Transformer Encoder-----# -----Neck-----self.neck = nn.Sequential(nn.Conv2d(embed_dim,out_chans,kernel_size=1,bias=False,),LayerNorm2d(out_chans),nn.Conv2d(out_chans,out_chans,kernel_size=3,padding=1,bias=False,),LayerNorm2d(out_chans),)# -----Neck-----
-
-
ViT网络(ImageEncoderViT类)在特征提取中的几个基本步骤:
-
patch embedding:将图片切分成图片序列块,再经过维度映射后展平成一维向量
-
positional embedding:嵌入位置编码(用于保留位置信息)
-
Transformer Encoder:主体编码器
-
Neck:过渡层
-
def forward(self, x: torch.Tensor) -> torch.Tensor:# patch embedding过程x = self.patch_embed(x)# positional embedding过程if self.pos_embed is not None:x = x + self.pos_embed# Transformer Encoder过程for blk in self.blocks:x = blk(x)# Neck过程 B H W C -> B C H Wx = self.neck(x.permute(0, 3, 1, 2))return x
-
PatchEmbed类: 源码其实就是卷积核大小16x16(巧妙切分成固定大小16x16的patch),卷积核通道3×768的卷积操作。图像大小决定了patch的数量
-
class PatchEmbed(nn.Module):def __init__(self,kernel_size: Tuple[int, int] = (16, 16), # 卷积核大小stride: Tuple[int, int] = (16, 16), # 步长padding: Tuple[int, int] = (0, 0), # paddingin_chans: int = 3, # 输入channelembed_dim: int = 768, # 输出channel) -> None:super().__init__()self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.proj(x)# B C H W -> B H W Cx = x.permute(0, 2, 3, 1)return x
-
经过patch embedding后输出tokens需要加入位置编码,位置编码可以理解为一张map,map的行数与输入序列个数相同,每一行代表一个向量,向量的维度和输入序列tokens的维度相同,位置编码的操作是sum,所以维度依旧保持不变。图像尺寸是1024的,因此patch数量是64(=1024/16)
-
# 在ImageEncoderViT的__init__定义 if use_abs_pos:# Initialize absolute positional embedding with pretrain image size.# 使用预训练图像大小初始化绝对位置嵌入。self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)) # 在ImageEncoderViT的forward添加位置编码 if self.pos_embed is not None:x = x + self.pos_embed
-
Transformer Encoder多个重复堆叠Encoder Block组成。
-
# 在ImageEncoderViT的__init__定义 # -----Transformer Encoder----- self.blocks = nn.ModuleList() for i in range(depth):block = Block(dim=embed_dim, # 输入channelnum_heads=num_heads, # attention中head的个数mlp_ratio=mlp_ratio, # mlp中channel缩放的比例qkv_bias=qkv_bias, # qkv全连接层的偏置flagnorm_layer=norm_layer, # 归一化层act_layer=act_layer, # 激活层use_rel_pos=use_rel_pos, # 是否需要将相对位置嵌入添加到注意力图rel_pos_zero_init=rel_pos_zero_init, # 源码暂时没有用到window_size=window_size if i not in global_attn_indexes else 0, # attention中的窗口大小input_size=(img_size // patch_size, img_size // patch_size), # 输入特征的尺寸)self.blocks.append(block) # -----Transformer Encoder-----
-
Encoder Block从低到高由LayerNorm 、Multi-Head Attention和MLP构成。
-
class Block(nn.Module):def __init__(self,dim: int, # 输入channelnum_heads: int, # attention中head的个数mlp_ratio: float = 4.0, # mlp中channel缩放的比例qkv_bias: bool = True, # qkv全连接层的偏置flagnorm_layer: Type[nn.Module] = nn.LayerNorm, # 归一化层act_layer: Type[nn.Module] = nn.GELU, # 激活层use_rel_pos: bool = False, # 是否需要将相对位置嵌入添加到注意力图rel_pos_zero_init: bool = True, # 源码暂时没有用到window_size: int = 0, # attention中的窗口大小input_size: Optional[Tuple[int, int]] = None, # 输入特征的尺寸) -> None:super().__init__()self.norm1 = norm_layer(dim) # 激活层self.attn = Attention( # Multi-Head Attentiondim,num_heads=num_heads,qkv_bias=qkv_bias,use_rel_pos=use_rel_pos,rel_pos_zero_init=rel_pos_zero_init,input_size=input_size if window_size == 0 else (window_size, window_size),)self.norm2 = norm_layer(dim)self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) # MLPself.window_size = window_size #def forward(self, x: torch.Tensor) -> torch.Tensor:shortcut = xx = self.norm1(x)# Window partition 对X进行paddingif self.window_size > 0:H, W = x.shape[1], x.shape[2]x, pad_hw = window_partition(x, self.window_size)x = self.attn(x)# Reverse window partition 去除X的padding部分if self.window_size > 0:x = window_unpartition(x, self.window_size, pad_hw, (H, W))x = shortcut + xx = x + self.mlp(self.norm2(x))return x
-
Partition操作
-
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:B, H, W, C = x.shapepad_h = (window_size - H % window_size) % window_sizepad_w = (window_size - W % window_size) % window_sizeif pad_h > 0 or pad_w > 0:x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))Hp, Wp = H + pad_h, W + pad_w# B,Hp/S,S,Wp/S,S,Cx = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)# B,Hp/S,Wp/S,S,S,C-->BHpWp/SS,S,S,Cwindows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windows, (Hp, Wp)
-
Unpartition操作
-
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor:Hp, Wp = pad_hwH, W = hwB = windows.shape[0] // (Hp * Wp // window_size // window_size)# BHpWp/SS,S,S,C-->B,Hp/S,Wp/S,S,S,Cx = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)# B,Hp/S,Wp/S,S,S,C-->B,Hp,Wp,Cx = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)if Hp > H or Wp > W:x = x[:, :H, :W, :].contiguous()# B,H,W,Creturn x
-
window_partition调整了原始特征尺寸为(H×W–>S×S),目的是了在后续的Multi-Head Attention过程中将相对位置嵌入添加到注意力图(attn),并不是所有Block都需要在注意力图中嵌入相对位置信息;window_unpartition则是恢复特征的原始尺寸(S×S–>H×W)。
-
Multi-Head Attention:先从Attention讲解,再到Multi-Head Attention,最后再讲注意力特征嵌入了相对位置特征的Multi-Head Attention。
-
class Attention(nn.Module):"""Multi-head Attention block with relative position embeddings."""def __init__(self,dim: int, # 输入channelnum_heads: int = 8, # head数目qkv_bias: bool = True,use_rel_pos: bool = False,rel_pos_zero_init: bool = True,input_size: Optional[Tuple[int, int]] = None, # 嵌入相对位置注意力特征的尺寸) -> None:super().__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim**-0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.proj = nn.Linear(dim, dim)self.use_rel_pos = use_rel_posif self.use_rel_pos: # 使用相对位置编码assert (input_size is not None), "Input size must be provided if using relative positional encoding."# initialize relative positional embeddings# 2S-1,Eposself.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))def forward(self, x: torch.Tensor) -> torch.Tensor:B, H, W, _ = x.shape# qkv with shape (3, B, nHead, H * W, C)qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)# q, k, v with shape (B * nHead, H * W, C)q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)# attn with shape (B * nHead, H * W, H * W)attn = (q * self.scale) @ k.transpose(-2, -1)if self.use_rel_pos:# 假设use_rel_pos是true (H, W)是 S×Sattn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))attn = attn.softmax(dim=-1)x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)x = self.proj(x)return x
-
对于输入到Multi-head attention模块的特征 F(N×E) ,通过attention模块的nn.Linear进一步提取特征获得输出特征 v(value) 。为了考虑 N 个特征之间存在的亲疏和位置关系对于 v 的影响,所以需要一个额外 attn(attention) 或者理解为权重 w(weight) 对 v 进行加权操作,这引出了计算 w 所需的 q(query) 与 k(key) ,因此可以看到任何V都考虑了N 个token特征之间相互的影响。Multi-head attention的流程如下图所示(不考虑batchsize):
- 首先将每个token的qkv特征维度embed_dim均拆分到每个head的上
- 每个head分别通过q和k计算得到权重w,权重w和v得到输出output,合并所有head的output得到最终的output
-
get_rel_pos用于计算h和w的相对位置的嵌入特征
-
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:max_rel_dist = int(2 * max(q_size, k_size) - 1)# Interpolate rel pos if needed.if rel_pos.shape[0] != max_rel_dist:# Interpolate rel pos. 相关位置进行插值rel_pos_resized = F.interpolate(# 1,N,Ep --> 1,Ep,N --> 1,Ep,2S-1rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),size=max_rel_dist,mode="linear",)# Ep,2S-1 --> 2S-1,Eprel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)else:rel_pos_resized = rel_pos# Scale the coords with short length if shapes for q and k are different.# 如果q和k长度值不同,则用短边长度缩放坐标。q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)# S,Srelative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)# tensor索引是tensor时,即tensor1[tensor2]# 假设tensor2某个具体位置值是2,则tensor1[2]位置的tensor1切片替换tensor2中的2# tensor1->shape 5,5,3 tensor2->shape 2,2,3 tensor1切片->shape 5,3 tensor1[tensor2]->shape 2,2,3,5,3# tensor1->shape 5,5 tensor2->shape 3,2,3 tensor1切片->shape 5 tensor1[tensor2]->shape 3,2,3,5# 2S-1,Ep-->S,S,Epreturn rel_pos_resized[relative_coords.long()]
-
add_decomposed_rel_pos为atten注意力特征添加相对位置的嵌入特征。
-
def add_decomposed_rel_pos(attn: torch.Tensor,q: torch.Tensor,rel_pos_h: torch.Tensor,rel_pos_w: torch.Tensor,q_size: Tuple[int, int],k_size: Tuple[int, int], ) -> torch.Tensor:# S,Sq_h, q_w = q_sizek_h, k_w = k_size# rel_pos_h -> 2S-1×EposRh = get_rel_pos(q_h, k_h, rel_pos_h)Rw = get_rel_pos(q_w, k_w, rel_pos_w)B, _, dim = q.shaper_q = q.reshape(B, q_h, q_w, dim)# torch.einsum用于简洁的表示乘积、点积、转置等方法# B,q_h, q_w, k_hrel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)# B,q_h, q_w, k_wrel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)attn = (# B,q_h, q_w, k_h, k_wattn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)return attn
-
MLP
-
class MLPBlock(nn.Module):def __init__(self,embedding_dim: int,mlp_dim: int,act: Type[nn.Module] = nn.GELU,) -> None:super().__init__()self.lin1 = nn.Linear(embedding_dim, mlp_dim)self.lin2 = nn.Linear(mlp_dim, embedding_dim)self.act = act()def forward(self, x: torch.Tensor) -> torch.Tensor:return self.lin2(self.act(self.lin1(x)))
-
Neck
-
# 在ImageEncoderViT的__init__定义 # -----Neck----- self.neck = nn.Sequential(nn.Conv2d(embed_dim,out_chans,kernel_size=1,bias=False,),LayerNorm2d(out_chans),nn.Conv2d(out_chans,out_chans,kernel_size=3,padding=1,bias=False,),LayerNorm2d(out_chans), ) # -----Neck----- class LayerNorm2d(nn.Module):def __init__(self, num_channels: int, eps: float = 1e-6) -> None:super().__init__()self.weight = nn.Parameter(torch.ones(num_channels))self.bias = nn.Parameter(torch.zeros(num_channels))self.eps = epsdef forward(self, x: torch.Tensor) -> torch.Tensor:u = x.mean(1, keepdim=True) # dim=1维度求均值并保留通道s = (x - u).pow(2).mean(1, keepdim=True)x = (x - u) / torch.sqrt(s + self.eps)x = self.weight[:, None, None] * x + self.bias[:, None, None]return x
-
-
sam模型中prompt_encoder模块初始化
-
prompt_encoder=PromptEncoder(# 提示编码channel(和image_encoder输出channel一致,后续会融合)embed_dim=prompt_embed_dim,# mask的编码尺寸(和image_encoder输出尺寸一致)image_embedding_size=(image_embedding_size, image_embedding_size),# 输入图像的标准尺寸input_image_size=(image_size, image_size),# 对输入掩码编码的通道数mask_in_chans=16, ),
-
-
ProEnco网络结构与执行流程,ProEnco网络(PromptEncoder类)结构参数配置。
-
def __init__(self,embed_dim: int, # 提示编码channelimage_embedding_size: Tuple[int, int], # mask的编码尺寸input_image_size: Tuple[int, int], # 输入图像的标准尺寸mask_in_chans: int, # 输入掩码编码的通道数activation: Type[nn.Module] = nn.GELU, # 激活层 ) -> None:super().__init__()self.embed_dim = embed_dim # 提示编码channelself.input_image_size = input_image_size # 输入图像的标准尺寸self.image_embedding_size = image_embedding_size # mask的编码尺寸self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)self.num_point_embeddings: int = 4 # 4个点:正负点,框的俩个点point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] # 4个点的嵌入向量# nn.ModuleList它是一个存储不同module,# 并自动将每个module的parameters添加到网络之中的容器self.point_embeddings = nn.ModuleList(point_embeddings) # 4个点的嵌入向量添加到网络self.not_a_point_embed = nn.Embedding(1, embed_dim) # 不是点的嵌入向量self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) # mask的输入尺寸self.mask_downscaling = nn.Sequential( # 输入mask时 4倍下采样nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),LayerNorm2d(mask_in_chans // 4),activation(),nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),LayerNorm2d(mask_in_chans),activation(),nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),)self.no_mask_embed = nn.Embedding(1, embed_dim) # 没有mask输入时 嵌入向量
-
SAM模型中ProEnco网络结构如下图所示:
-
ProEnco网络(PromptEncoder类)在特征提取中的几个基本步骤:
- Embed_Points:标记点编码(标记点由点转变为向量)
- Embed_Boxes:标记框编码(标记框由点转变为向量)
- Embed_Masks:mask编码(mask下采样保证与Image encoder输出一致)
-
def forward(self,points: Optional[Tuple[torch.Tensor, torch.Tensor]],boxes: Optional[torch.Tensor],masks: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]:# 获得 batchsize 当前predict为1bs = self._get_batch_size(points, boxes, masks)# -----sparse_embeddings----sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())if points is not None:coords, labels = pointspoint_embeddings = self._embed_points(coords, labels, pad=(boxes is None))sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)if boxes is not None:box_embeddings = self._embed_boxes(boxes)sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)# -----sparse_embeddings----# -----dense_embeddings----if masks is not None:dense_embeddings = self._embed_masks(masks)else:dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])# -----dense_embeddings----return sparse_embeddings, dense_embeddings def _get_batch_size(self,points: Optional[Tuple[torch.Tensor, torch.Tensor]],boxes: Optional[torch.Tensor],masks: Optional[torch.Tensor], ) -> int:if points is not None:return points[0].shape[0]elif boxes is not None:return boxes.shape[0]elif masks is not None:return masks.shape[0]else:return 1 def _get_device(self) -> torch.device:return self.point_embeddings[0].weight.device
-
Embed_Points:标记点预处理,将channel由2变成embed_dim(MatMul:forward_with_coords),然后再加上位置编码权重。
-
def _embed_points(self,points: torch.Tensor,labels: torch.Tensor,pad: bool, ) -> torch.Tensor:# 移到像素中心points = points + 0.5# points和boxes联合则不需要padif pad:padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) # B,1,2padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) # B,1points = torch.cat([points, padding_point], dim=1) # B,N+1,2labels = torch.cat([labels, padding_label], dim=1) # B,N+1point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) # B,N+1,2f# labels为-1是非标记点,设为非标记点权重point_embedding[labels == -1] = 0.0point_embedding[labels == -1] += self.not_a_point_embed.weight# labels为0是背景点,加上背景点权重point_embedding[labels == 0] += self.point_embeddings[0].weight# labels为1的目标点,加上目标点权重point_embedding[labels == 1] += self.point_embeddings[1].weightreturn point_embedding
-
pad的作用相当于box占位符号,box和points可以联合标定完成图像分割的,但是此时的box只能有一个,不能有多个。
-
Embed_Boxes:标记框预处理,将channel由4到2再变成embed_dim(MatMul:forward_with_coords),然后再加上位置编码权重。
-
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:# 移到像素中心boxes = boxes + 0.5coords = boxes.reshape(-1, 2, 2)corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) ## 目标框起始点的和末位点分别加上权重corner_embedding[:, 0, :] += self.point_embeddings[2].weightcorner_embedding[:, 1, :] += self.point_embeddings[3].weightreturn corner_embedding
-
boxes reshape 后 batchsize 是会增加的,B,N,4–>BN,2,2;因此这里可以得出box和points联合标定时,box为什么只能是一个,而不能是多个。
-
Embed_Masks:mask的输出尺寸是Image encoder模块输出的图像编码尺寸的4倍,因此为了保持一致,需要4倍下采样。
-
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:# mask下采样4倍mask_embedding = self.mask_downscaling(masks)return mask_embedding # 在PromptEncoder的__init__定义 self.mask_downscaling = nn.Sequential( # 输入mask时 4倍下采样nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),LayerNorm2d(mask_in_chans // 4),activation(),nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),LayerNorm2d(mask_in_chans),activation(),nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),)
-
假设没有mask输入,则将no_mask_embed编码扩展到与图像编码一致的尺寸代替mask。
-
# 在PromptEncoder的forward定义 dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] )
-
PositionEmbeddingRandom:用于将标记点和标记框的坐标进行提示编码预处理。
-
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:super().__init__()if scale is None or scale <= 0.0:scale = 1.0# 理解为模型的常数 [2,f]self.register_buffer("positional_encoding_gaussian_matrix",scale * torch.randn((2, num_pos_feats)),)
-
将标记点的坐标具体的位置转变为[0~1]之间的比例位置
-
def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int] ) -> torch.Tensor:coords = coords_input.clone()# 将坐标位置缩放到[0~1]之间coords[:, :, 0] = coords[:, :, 0] / image_size[1]coords[:, :, 1] = coords[:, :, 1] / image_size[0]# B,N+1,2-->B,N+1,2freturn self._pe_encoding(coords.to(torch.float))
-
标记点位置编码,因为sin和cos,编码的值归一化至 [-1,1]
-
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shapecoords = 2 * coords - 1# B,N+1,2 × 2,f --> B,N+1,fcoords = coords @ self.positional_encoding_gaussian_matrixcoords = 2 * np.pi * coords# outputs d_1 x ... x d_n x C shape# B,N+1,2freturn torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-
-
MaskDecoder网络简述,sam模型中Mask_decoder模块初始化
-
mask_decoder=MaskDecoder(# 消除掩码歧义预测的掩码数num_multimask_outputs=3,# 用于预测mask的网咯transformertransformer=TwoWayTransformer(# 层数depth=2,# 输入channelembedding_dim=prompt_embed_dim,# MLP内部channelmlp_dim=2048,# attention的head数num_heads=8,),# transformer的channeltransformer_dim=prompt_embed_dim,# MLP的深度,MLP用于预测掩模质量的iou_head_depth=3,# MLP隐藏channeliou_head_hidden_dim=256, ),
-
MaskDeco网络(MaskDecoder类)结构参数配置。
-
def __init__(self,*,# transformer的channeltransformer_dim: int,# 用于预测mask的网咯transformertransformer: nn.Module,# 消除掩码歧义预测的掩码数num_multimask_outputs: int = 3,# 激活层activation: Type[nn.Module] = nn.GELU,# MLP深度,MLP用于预测掩模质量的iou_head_depth: int = 3,# MLP隐藏channeliou_head_hidden_dim: int = 256, ) -> None:super().__init__()self.transformer_dim = transformer_dim # transformer的channel#----- transformer -----self.transformer = transformer # 用于预测mask的网咯transformer# ----- transformer -----self.num_multimask_outputs = num_multimask_outputs # 消除掩码歧义预测的掩码数self.iou_token = nn.Embedding(1, transformer_dim) # iou的takenself.num_mask_tokens = num_multimask_outputs + 1 # mask数self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) # mask的tokens数#----- upscaled -----# 4倍上采样self.output_upscaling = nn.Sequential(nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), #转置卷积 上采样2倍LayerNorm2d(transformer_dim // 4),activation(),nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),activation(),)# ----- upscaled -----# ----- MLP -----# 对应mask数的MLPself.output_hypernetworks_mlps = nn.ModuleList([MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)for i in range(self.num_mask_tokens)])# ----- MLP -----# ----- MLP -----# 对应iou的MLPself.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)# ----- MLP -----
-
SAM模型中MaskDeco网络结构如下图所示:
-
MaskDeco网络(MaskDecoder类)在特征提取中的几个基本步骤:
- transformer:融合特征(提示信息特征与图像特征)获得粗略掩膜src
- upscaled:对粗略掩膜src上采样
- mask_MLP:全连接层组(计算加权权重,使粗掩膜src转变为掩膜mask)
- iou_MLP:全连接层组(计算掩膜mask的Score)
-
def forward(self,# image encoder 图像特征image_embeddings: torch.Tensor,# 位置编码image_pe: torch.Tensor,# 标记点和标记框的嵌入编码sparse_prompt_embeddings: torch.Tensor,# 输入mask的嵌入编码dense_prompt_embeddings: torch.Tensor,# 是否输出多个maskmultimask_output: bool, ) -> Tuple[torch.Tensor, torch.Tensor]:masks, iou_pred = self.predict_masks(image_embeddings=image_embeddings,image_pe=image_pe,sparse_prompt_embeddings=sparse_prompt_embeddings,dense_prompt_embeddings=dense_prompt_embeddings,)# Select the correct mask or masks for outputif multimask_output:mask_slice = slice(1, None)else:mask_slice = slice(0, 1)masks = masks[:, mask_slice, :, :]iou_pred = iou_pred[:, mask_slice]return masks, iou_pred def predict_masks(self,image_embeddings: torch.Tensor,image_pe: torch.Tensor,sparse_prompt_embeddings: torch.Tensor,dense_prompt_embeddings: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]:# Concatenate output tokens# 1,E and 4,E --> 5,Eoutput_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)# 5,E --> B,5,Eoutput_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)# B,5,E and B,N,E -->B,5+N,E N是点的个数(标记点和标记框的点)tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)# 扩展image_embeddings的B维度,因为boxes标记分割时,n个box时batchsize=batchsize*n# Expand per-image data in batch direction to be per-mask# B,C,H,Wsrc = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)# B,C,H,W + 1,C,H,W ---> B,C,H,Wsrc = src + dense_prompt_embeddings# 1,C,H,W---> B,C,H,Wpos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)b, c, h, w = src.shape# ----- transformer -----# Run the transformer# B,N,Chs, src = self.transformer(src, pos_src, tokens)# ----- transformer -----iou_token_out = hs[:, 0, :]mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :]# Upscale mask embeddings and predict masks using the mask tokens# B,N,C-->B,C,H,W src = src.transpose(1, 2).view(b, c, h, w)# ----- upscaled -----# 4倍上采样upscaled_embedding = self.output_upscaling(src)# ----- upscaled -----hyper_in_list: List[torch.Tensor] = []# ----- mlp -----for i in range(self.num_mask_tokens):# mask_tokens_out[:, i, :]: B,1,C# output_hypernetworks_mlps: B,1,chyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))# B,n,chyper_in = torch.stack(hyper_in_list, dim=1)# ----- mlp -----b, c, h, w = upscaled_embedding.shape# B,n,c × B,c,N-->B,n,h,wmasks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)# ----- mlp -----# Generate mask quality predictions# iou_token_out: B,1,niou_pred = self.iou_prediction_head(iou_token_out)# ----- mlp -----# masks: B,n,h,w# iou_pred: B,1,nreturn masks, iou_pred
-
MaskDeco由多个重复堆叠TwoWayAttention Block和1个Multi-Head Attention组成。
-
class TwoWayTransformer(nn.Module):def __init__(self,# 层数depth: int,# 输入channelembedding_dim: int,# attention的head数num_heads: int,# MLP内部channelmlp_dim: int,activation: Type[nn.Module] = nn.ReLU,attention_downsample_rate: int = 2,) -> None:super().__init__()self.depth = depth # 层数self.embedding_dim = embedding_dim # 输入channelself.num_heads = num_heads # attention的head数self.mlp_dim = mlp_dim # MLP内部隐藏channelself.layers = nn.ModuleList()for i in range(depth):self.layers.append(TwoWayAttentionBlock(embedding_dim=embedding_dim, # 输入channelnum_heads=num_heads, # attention的head数mlp_dim=mlp_dim, # MLP中间channelactivation=activation, # 激活层attention_downsample_rate=attention_downsample_rate, # 下采样skip_first_layer_pe=(i == 0),))self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)self.norm_final_attn = nn.LayerNorm(embedding_dim)def forward(self,image_embedding: Tensor,image_pe: Tensor,point_embedding: Tensor,) -> Tuple[Tensor, Tensor]:# BxCxHxW -> BxHWxC == B x N_image_tokens x Cbs, c, h, w = image_embedding.shape# 图像编码(image_encoder的输出)# BxHWxC=>B,N,Cimage_embedding = image_embedding.flatten(2).permute(0, 2, 1)# 图像位置编码# BxHWxC=>B,N,Cimage_pe = image_pe.flatten(2).permute(0, 2, 1)# 标记点编码# B,N,Cqueries = point_embeddingkeys = image_embedding# -----TwoWayAttention-----for layer in self.layers:queries, keys = layer(queries=queries,keys=keys,query_pe=point_embedding,key_pe=image_pe,)# -----TwoWayAttention-----q = queries + point_embeddingk = keys + image_pe# -----Attention-----attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)# -----Attention-----queries = queries + attn_outqueries = self.norm_final_attn(queries)return queries, keys
-
TwoWayAttention Block由LayerNorm 、Multi-Head Attention和MLP构成。
-
class TwoWayAttentionBlock(nn.Module):def __init__(self,embedding_dim: int, # 输入channelnum_heads: int, # attention的head数mlp_dim: int = 2048, # MLP中间channelactivation: Type[nn.Module] = nn.ReLU, # 激活层attention_downsample_rate: int = 2, # 下采样skip_first_layer_pe: bool = False,) -> None:super().__init__()self.self_attn = Attention(embedding_dim, num_heads)self.norm1 = nn.LayerNorm(embedding_dim)self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)self.norm2 = nn.LayerNorm(embedding_dim)self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)self.norm3 = nn.LayerNorm(embedding_dim)self.norm4 = nn.LayerNorm(embedding_dim)self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)self.skip_first_layer_pe = skip_first_layer_pedef forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:# queries:标记点编码相关(原始标记点编码经过一系列特征提取)# keys:原始图像编码相关(原始图像编码经过一系列特征提取)# query_pe:原始标记点编码# key_pe:原始图像位置编码# 第一轮本身queries==query_pe没比较再"残差"if self.skip_first_layer_pe:queries = self.self_attn(q=queries, k=queries, v=queries)else:q = queries + query_peattn_out = self.self_attn(q=q, k=q, v=queries)queries = queries + attn_outqueries = self.norm1(queries)# Cross attention block, tokens attending to image embeddingq = queries + query_pek = keys + key_peattn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)queries = queries + attn_outqueries = self.norm2(queries)# MLP blockmlp_out = self.mlp(queries)queries = queries + mlp_outqueries = self.norm3(queries)# Cross attention block, image embedding attending to tokensq = queries + query_pek = keys + key_peattn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)keys = keys + attn_outkeys = self.norm4(keys)return queries, keys
-
TwoWayAttentionBlock是Prompt encoder的提示信息特征与Image encoder的图像特征的融合过程,而Prompt encoder对提示信息没有过多处理,因此TwoWayAttentionBlock的目的是边对提示信息特征做进一步处理边与图像特征融合。
-
MaskDeco的Attention与ViT的Attention有些细微的不同:MaskDeco的Attention是3个FC层分别接受3个输入获得q、k和v,而ViT的Attention是1个FC层接受1个输入后将结果均拆分获得q、k和v。
-
class Attention(nn.Module):def __init__(self,embedding_dim: int, # 输入channelnum_heads: int, # attention的head数downsample_rate: int = 1, # 下采样) -> None:super().__init__()self.embedding_dim = embedding_dimself.internal_dim = embedding_dim // downsample_rateself.num_heads = num_headsassert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."# qkv获取self.q_proj = nn.Linear(embedding_dim, self.internal_dim)self.k_proj = nn.Linear(embedding_dim, self.internal_dim)self.v_proj = nn.Linear(embedding_dim, self.internal_dim)self.out_proj = nn.Linear(self.internal_dim, embedding_dim)def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:b, n, c = x.shapex = x.reshape(b, n, num_heads, c // num_heads)return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_headdef _recombine_heads(self, x: Tensor) -> Tensor:b, n_heads, n_tokens, c_per_head = x.shapex = x.transpose(1, 2)return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x Cdef forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:# Input projectionsq = self.q_proj(q)k = self.k_proj(k)v = self.v_proj(v)# Separate into heads# B,N_heads,N_tokens,C_per_headq = self._separate_heads(q, self.num_heads)k = self._separate_heads(k, self.num_heads)v = self._separate_heads(v, self.num_heads)# Attention_, _, _, c_per_head = q.shapeattn = q @ k.permute(0, 1, 3, 2) # B,N_heads,N_tokens,C_per_head# Scaleattn = attn / math.sqrt(c_per_head)attn = torch.softmax(attn, dim=-1)# Get outputout = attn @ v# # B,N_tokens,Cout = self._recombine_heads(out)out = self.out_proj(out)return out
-
MaskDeco的Attention和ViT的Attention的结构对比示意图:
-
transformer_MLP
-
class MLPBlock(nn.Module):def __init__(self,embedding_dim: int,mlp_dim: int,act: Type[nn.Module] = nn.GELU,) -> None:super().__init__()self.lin1 = nn.Linear(embedding_dim, mlp_dim)self.lin2 = nn.Linear(mlp_dim, embedding_dim)self.act = act()def forward(self, x: torch.Tensor) -> torch.Tensor:return self.lin2(self.act(self.lin1(x)))
-
upscaled
-
# 在MaskDecoder的__init__定义 self.output_upscaling = nn.Sequential(nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), #转置卷积 上采样2倍LayerNorm2d(transformer_dim // 4),activation(),nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),activation(), ) # 在MaskDecoder的predict_masks添加位置编码 upscaled_embedding = self.output_upscaling(src)
-
mask_MLP
-
# 在MaskDecoder的__init__定义 self.output_hypernetworks_mlps = nn.ModuleList([MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)for i in range(self.num_mask_tokens)] ) # 在MaskDecoder的predict_masks添加位置编码for i in range(self.num_mask_tokens):# mask_tokens_out[:, i, :]: B,1,C# output_hypernetworks_mlps: B,1,chyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))# B,n,chyper_in = torch.stack(hyper_in_list, dim=1)b, c, h, w = upscaled_embedding.shape# B,n,c × B,c,N-->B,n,h,wmasks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
-
iou_MLP
-
# 在MaskDecoder的__init__定义 self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth ) # 在MaskDecoder的predict_masks添加位置编码 iou_pred = self.iou_prediction_head(iou_token_out)
-
MaskDeco_MLP
-
class MLP(nn.Module):def __init__(self,input_dim: int, # 输入channelhidden_dim: int, # 中间channeloutput_dim: int, # 输出channelnum_layers: int, # fc的层数sigmoid_output: bool = False,) -> None:super().__init__()self.num_layers = num_layersh = [hidden_dim] * (num_layers - 1)self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))self.sigmoid_output = sigmoid_outputdef forward(self, x):for i, layer in enumerate(self.layers):x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)if self.sigmoid_output:x = F.sigmoid(x)return x
-
iou_MLP
-
# 在MaskDecoder的__init__定义 self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth ) # 在MaskDecoder的predict_masks添加位置编码 iou_pred = self.iou_prediction_head(iou_token_out)
-
MaskDeco_MLP
-
class MLP(nn.Module):def __init__(self,input_dim: int, # 输入channelhidden_dim: int, # 中间channeloutput_dim: int, # 输出channelnum_layers: int, # fc的层数sigmoid_output: bool = False,) -> None:super().__init__()self.num_layers = num_layersh = [hidden_dim] * (num_layers - 1)self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))self.sigmoid_output = sigmoid_outputdef forward(self, x):for i, layer in enumerate(self.layers):x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)if self.sigmoid_output:x = F.sigmoid(x)return x
-