[医学分割大模型系列] (1) SAM 分割大模型解析

[医学大模型系列] [1] SAM 分割大模型解析

  • 1. 特点
  • 2. 网络结构
    • 2.1 Image encoder
    • 2.2 Prompt encoder
    • 2.3 Mask decoder
  • 3. 数据引擎
  • 4. 讨论

论文地址:Segment Anything

开源地址:https://github.com/facebookresearch/segment-anything

demo地址:Segment Anything | Meta AI

参考:

  1. SAM模型详解
  2. https://www.bilibili.com/video/BV1K94y177Ka/?spm_id_from=333.337.search-card.all.click&vd_source=6bf14836c2866f1a29042ffd4369a079
  3. https://www.bilibili.com/video/BV1Bu411s79u/?spm_id_from=333.337.search-card.all.click&vd_source=6bf14836c2866f1a29042ffd4369a079

1. 特点

在这里插入图片描述

  • 可提示的(Prompt)交互图像分割大模型(Foundation Models)
  • 四种Prompt形式:点,框,Mask,文本
  • 构建数据引擎:使用高效模型来协助数据收集和使用新收集的数据来帮助模型迭代。

2. 网络结构

模型整体上包含三个大模块,image encoder,prompt encoder和mask decoder。

2.1 Image encoder

image encoder旨在映射待分割的图像到图像特征空间。
在这里插入图片描述
这里的ViT结构也并不是十分复杂,这里简单列出输入图像经过ViT的流程,其实整体只有4个步骤:

  • 输入图像进入网络,先经过一个卷积base的patch_embedding:取16*16为一个patch,步长也是16,这样feature map的尺寸就缩小了16倍,同时channel从3映射到768。
  • patch_embed过后加positional_embedding:positional_embedding是个可学习的参数矩阵,初始化是0。
  • 加了positional_embedding后的feature map过16个transformer block,其中12个transformer是基于window partition(就是把特征图分成14*14的windows做局部的attention)的attn模块,和4个全局attn,这4个全局attn是穿插在windowed attention中的。
  • 最后过两层卷积(neck)把channel数降到256(token长度),这就是最终的image embedding的结果。

整体来看,这个部分的计算量是相对来说比较大的,demo体验过程中,只有这个过程的计算是在fb的服务器上做的,prompt encoder和mask decoder体积比较小,都是在浏览器内部或者说用本地的内存跑的,整体速度还比较快。

其使用的预训练模型: Masked Autoencoders Are Scalable Vision Learners - 代码

# build_sam.py
sam_model_registry = {"default": build_sam_vit_h,"vit_h": build_sam_vit_h,"vit_l": build_sam_vit_l,"vit_b": build_sam_vit_b,
}

2.2 Prompt encoder

prompt encoder则是负责映射输入的prompt到prompt的特征空间,这里有一点要提就是作者定义了sparse(包括 点,框,文本)和dense(mask)两种prompt,其中sparse prompt比较好理解,就是指demo中我们可以输入的点,目标框或者是描述目标的text,而dense prompt在目前的线上demo中体验不到,paper中也只说它对应的是mask类型的prompt,从代码里看应该是训练时候用的比较多,一般是上一次迭代预测出的一个粗分割的mask,粗略指出待分割的目标区域。
在这里插入图片描述
映射出的特征的channel和image embedding的channel一致(默认均为256),因为这两个后边要用attention进行融合。

  • sparse prompt:
    • 如果prompt是point,那么它的映射由两个部分相加组成,一个是位置编码,这里的位置编码使用的是Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains的编码方式,用空间坐标乘以高斯分布的向量来描述位置比直接的线性向量描述效果更好,另一个部分是一个描述当前点是前景还是背景(因为demo里可以选择pos点也可以选择neg点)特征的可学习的一维向量。换句话说,如果当前选择的点是positive,那么就在位置编码的2维向量上加一个表示postitive的一维向量,如果是neg,就加一个表示neg的一维向量,对于所有的positive的点,加上去的pos向量都是一样的。
    • 如果prompt是box,那box的映射也是由两个部分相加组成,第一部分是左上和右下两个点的位置编码,第二部分是一组一维向量用来描述这个点是“左上”还是“右下”。也就是说,对于左上的点,他的映射就是位置编码+“左上”这个特征的描述向量,右下的点,就是位置编码+“右下”这个特征的描述向量。
  • dense prompt:
    • 对于mask这类的dense prompt,他的映射就比较简单粗暴。在输入prompt encoder之前,先要把mask降采样到4x,再过两个2x2,stride=2的卷积,这样尺寸又降了4x,就和降了16x的图像特征图尺寸一致了,再过一个1*1的卷积,把channel也升到256。如果没有提供mask,也就是我们实际inference时候的场景,这个结构会直接返回一个描述“没有mask”特征的特征图。

预训练: CLIP

# prompt_encoder.pydef forward(self,points: Optional[Tuple[torch.Tensor, torch.Tensor]],boxes: Optional[torch.Tensor],masks: Optional[torch.Tensor],) -> Tuple[torch.Tensor, torch.Tensor]:"""Embeds different types of prompts, returning both sparse and denseembeddings.Arguments:points (tuple(torch.Tensor, torch.Tensor) or none): point coordinatesand labels to embed.boxes (torch.Tensor or none): boxes to embedmasks (torch.Tensor or none): masks to embedReturns:torch.Tensor: sparse embeddings for the points and boxes, with shapeBxNx(embed_dim), where N is determined by the number of input pointsand boxes.torch.Tensor: dense embeddings for the masks, in the shapeBx(embed_dim)x(embed_H)x(embed_W)"""bs = self._get_batch_size(points, boxes, masks)sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())if points is not None:coords, labels = pointspoint_embeddings = self._embed_points(coords, labels, pad=(boxes is None))sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)if boxes is not None:box_embeddings = self._embed_boxes(boxes)sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)if masks is not None:dense_embeddings = self._embed_masks(masks)else:dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])return sparse_embeddings, dense_embeddings

2.3 Mask decoder

在这里插入图片描述

  • detr结构回顾

    • cross att:
      在这里插入图片描述
      100个queries token对图片特征进行遍历,确定每个token要对应寻找的物体特征

    • self att:
      在这里插入图片描述
      避免token之间寻找物体特征重复,queries token进行自注意力机制。

  • 结构描述

    • (1) self-attention on the tokens,
    • (2) cross-attention from tokens (as queries) to the image embedding,
    • (3) a point-wise MLP updates each token,
    • (4) cross-attention from the image embedding (as queries) to tokens.
    • (5) Additionally, the entire original prompt tokens (including their positional encodings) are re-added to the updated tokens whenever they participate in an attention layer. 提示加了两遍
      在这里插入图片描述
  • 结构分析

    • 在prompt embedding进入decoder之前,先在它上面concat了一组可学习的output tokens (相当于dert的queries token,100个tokens),output tokens由两个部分构成:
      • 一个是iou token,它会在后面被分离出来用于预测iou的可靠性(对应结构图右侧的IoU output token),它受到模型计算出的iou与模型计算出的mask与GT实际的iou之间的MSE loss监督;
      • 另一个是mask token,它也会在后面被分离出来参与预测最终的mask(对应结构图右侧的output token per mask),mask受到focal loss和dice loss 20:1的加权组合监督。
      • 这两个token的意义我感觉比较抽象,因为理论来说进入decoder的变量应该是由模型的输入,也就是prompt和image的映射构成,但这两个token的定义与prompt和image完全没有关系,而是凭空出现的。从结果反推原因,只能把它们理解成对模型的额外约束,因为它们两个参与构成了模型的两个输出并且有loss对他们进行监督。
      • 最终prompt embedding(这一步改名叫prompt token)和刚才提到这两个token concat到一起统称为tokens进入decoder。
    • image embedding在进入decoder之前也要进行一步操作:dense prompt由于包含密集的空间信息,与image embedding所在的特征空间一致性更高,所以直接与image embedding相加融合。因为后面要与prompt做cross attention融合,这里还要先算一下image embedding的位置编码。
    • 接下来{image embedding,image embedding的位置编码,tokens}进入一个两层transformer结构的decoder做融合。值得注意的是,在transformer结构中,为了保持位置信息始终不丢失,每做一次attention运算,不管是self-attention还是cross-attention,tokens都叠加一次初始的tokens,image embedding都叠加一次它自己的位置编码,并且每个attention后边都接一个layer_norm。
      • tokens先过一个self-attention。
      • tokens作为q,对image embedding做cross attention,更新tokens。
      • tokens再过两层的mlp做特征变换。
      • image embedding作为q,对tokens做cross attention,更新image embedding。
    • 更新后的tokens作为q,再对更新后的image embedding做cross attention,产生最终的tokens。
    • 更新后的image embedding过两层kernel_size=2, stride=2的转置卷积,升采样到4x大小(依然是4x降采样原图的大小),产生最终的image embedding。
    • 接下来兵分两路:
      • mask token被从tokens中分离出来(因为他一开始就是concat上去的,可以直接按维度摘出来),过一个三层的mlp调整channel数与最终的image embedding一致,并且他们两个做矩阵乘法生成mask的预测。最终的image embedding大小为[token长度(channel),图像长,图像宽],mask token的大小为[token长度(channel),token个数(最终生成的mask个数)]。两者点乘后的大小为[token个数(最终生成的mask个数),图像长,图像宽]。也就是说有几个token个数生成几个mask。
      • iou token被从tokens中分离出来,也过一个三层的mlp生成最终的iou预测。在反向传播时,排序mask,参与计算的只有loss最小的mask相关的参数。
    • 最后,如前文所述,分别对mask的预测和iou预测进行监督,反向传播,更新参数。每一个mask,会随机产生11种prompt与之配对。

对于一个输出,如果给出一个模糊的提示,该模型将平均多个有效的掩码。为了解决这个问题,我们修改了模型,以预测单个提示的多个输出掩码(比如说提示在衣服上,会分出来衣服的mask和人的mask)。我们发现3个掩模(output/mask token的个数为3)输出足以解决大多数常见的情况(嵌套掩模通常最多有三个深度:整体、部分和子部分)。在训练期间,我们只支持mask上的最小损失[匈牙利损失]。为了对掩码进行排名,该模型预测了每个掩码的置信度分数(即估计的IoU)

在这里插入图片描述

# transformer.pydef forward(self,image_embedding: Tensor,image_pe: Tensor,point_embedding: Tensor,) -> Tuple[Tensor, Tensor]:"""Args:image_embedding (torch.Tensor): image to attend to. Should be shapeB x embedding_dim x h x w for any h and w.image_pe (torch.Tensor): the positional encoding to add to the image. Musthave the same shape as image_embedding.point_embedding (torch.Tensor): the embedding to add to the query points.Must have shape B x N_points x embedding_dim for any N_points.Returns:torch.Tensor: the processed point_embeddingtorch.Tensor: the processed image_embedding"""# BxCxHxW -> BxHWxC == B x N_image_tokens x Cbs, c, h, w = image_embedding.shapeimage_embedding = image_embedding.flatten(2).permute(0, 2, 1)image_pe = image_pe.flatten(2).permute(0, 2, 1)# Prepare queriesqueries = point_embeddingkeys = image_embedding# Apply transformer blocks and final layernormfor layer in self.layers:queries, keys = layer(queries=queries,keys=keys,query_pe=point_embedding,key_pe=image_pe,)# Apply the final attention layer from the points to the imageq = queries + point_embeddingk = keys + image_peattn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)queries = queries + attn_outqueries = self.norm_final_attn(queries)return queries, keys

3. 数据引擎

在这里插入图片描述
阶段一,手动阶段,模型越来越好,数据越来越多

  1. 拿现有已经表好的数据集训练一个粗糙版的SAM
  2. 粗糙版的SAM分割新数据,人工修正后,重新训练模型
  3. 上两部反复迭代

阶段二,半自动阶段,默认准确率够高,召回率不够好(漏标)

  1. 用训练好的SAM模型分割图像
  2. 把分割不出来的给人来标
  3. 重新训练模型

阶段三,全自动阶段

  1. SAM模型在图像上分割,对得分较高,准确率较高,符合设定规则的结果进行保留

4. 讨论

在这里插入图片描述
SAM的一个我个人认为比较新颖的点子是它从interactive segmentation引申出了一个新的任务类型,叫做promptable segmentation。从他的模型中也能看出,输入的prompt是模型在输出最终mask的关键指导信息,这也是为什么我发现目前的SAM模型在处理一些专业领域图像(比如我自己从事的医学图像分割)时,直接使用他的segment everything功能,也就是无prompt进行分割时效果不好的原因。

另一个要搞清楚的问题是在进行有prompt的分割时,实际上实现的是一个二分类的分割任务,模型要解决的问题是根据我们选择的点的特征,从图像(背景)中分割出这个点所在的目标物体(前景),它本质上并不关心这个目标物体是个什么东西。滑稽一点来说,整个过程实际上有点类似photoshop里魔棒的功能,adobe倒是可以考虑把这个模型整合进ps里提升一些性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/764364.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C#,图片分层(Layer Bitmap)绘制,反色、高斯模糊及凹凸贴图等处理的高速算法与源程序

1 图像反色Invert 对图像处理的过程中会遇到一些场景需要将图片反色,反色就是取像素的互补色,比如当前像素是0X00FFFF,对其取反色就是0XFFFFFF – 0X00FFFF = 0XFF0000,依次对图像中的每个像素这样做,最后得到的就是原始2 图像的反色。 2 高斯模糊(Gauss Blur)算法 …

cesium知识点:坐标系

一,地理坐标系 1.经纬度坐标系 对象:没有实际的对象 说明:cesium默认使用WGS84坐标系作为空间参考,坐标原点在椭球的质心。 2.弧度坐标系(Cartographic) 对象:new Cesium.Cartographic(longitude, latitude, heigh…

easyExcel大数据量导出oom

easyExcel大数据量导出 异常信息 com.alibaba.excel.exception.ExcelGenerateException: java.lang.OutOfMemoryError: GC overhead limit exceededat com.alibaba.excel.write.ExcelBuilderImpl.fill(ExcelBuilderImpl.java:84)at com.alibaba.excel.ExcelWriter.fill(Excel…

AI智能分析网关V4养老院视频智能监控方案

随着科技的快速发展,智能监控技术已经广泛应用于各个领域,尤其在养老院这一特定场景中,智能监控方案更是发挥着不可或缺的作用。尤其是伴随着社会老龄化趋势的加剧,养老院的安全管理问题也日益凸显。为了确保老人的生活安全&#…

yarn安装包时报错error Error: certificate has expired

安装教程: 配置镜像地址: npm config set registry https://registry.npmmirror.com//镜像:https://developer.aliyun.com/mirror/NPM 安装yarn: npm install --global yarn查看版本: yarn --version卸载&#xff…

每日五道java面试题之springboot篇(一)

目录: 第一题. 什么是 Spring Boot?第二题. Spring Boot 有哪些优点?第三题. Spring Boot 的核心注解是哪个?它主要由哪几个注解组成的?第四题. 什么是 JavaConfig?第五题. Spring Boot 自动配置原理是什么…

ChatGPTGPT4科研应用、数据分析与机器学习、论文高效写作、AI绘图技术教程

原文链接:ChatGPTGPT4科研应用、数据分析与机器学习、论文高效写作、AI绘图技术教程https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247598506&idx2&sn14f96667bfbeba5f51366a1f019e3d64&chksmfa82004dcdf5895bba2784ba10f6715f6f5e4c59c9b1…

【MySQL】3.2MySQL事务和存储引擎

MySQL事务 一、MySQL事物的概念 事务是一种机制,包含了一件事的完整的一个过程 ●事务是一种机制、一个操作序列,包含了一组数据库操作命令,并且把所有的命令作为一个整体一起向系统提交或撤销操作请求,即这一组数据库命令要么…

后端项目中构建前端模块问题记录

后端项目中在登陆页面使用jsp,后端项目会通过接口返回给前端几个js的路径,这几个js呢,是由后端先构建好,然后返回给前端路径的,前端通过这个路径访问js执行。。。 总之,很奇怪的项目。。 1、首先要安装no…

JSqlParser的使用

简介 JSqlParse是一款很精简的sql解析工具,它可以将常用的sql文本解析成具有层级结构的语法树,我们可以针对解析后的节点进行处理(增加、移除、修改等操作),从而生成符合我们业务要求的sql,比如添加过滤条件等等 JSqlParse采用访问者模式 项…

全智能深度演进,一键成片让视频创作颠覆式提效

全智能一键成片,让内容创作的「边际成本」逼近于零。 大模型和AIGC技术的发展,可以用“日新月异”来形容,其迭代速度史无前例,涌现出的各类垂直应用模型,也使得音视频行业的应用场景更加广泛和多样化。 然而&#xff…

如何从零开始拆解uni-app开发的vue项目(三)

前言:前两篇文章我们讲解了如何拆解uni-app开发的项目结构、实现前台数据的动态加载,今天讲一篇如何实现动态加载功能列表,以及美化界面。话不多说,直接先看源码: 在用户成功登录后,会跳转到menu.vue菜单, 再次点击点检功能时,会进入点检的具体功能跳转菜单,我们的点…

在Linux/Debian/Ubuntu上通过 Azure Data Studio 管理 SQL Server 2019

Microsoft 提供 Azure Data Studio,这是一种可在 Linux、macOS 和 Windows 上运行的跨平台数据库工具。 它提供与 SSMS 类似的功能,包括查询、脚本编写和可视化数据。 要在 Ubuntu 上安装 Azure Data Studio,可以按照以下步骤操作&#xff1…

Sphinx + Readthedocs 避坑速通指南

博主在学习使用 Sphinx 和 Read the docs 的过程中, 碰到了许多奇葩的 bug, 使得很简单的任务花费了很长的时间才解决,现在在这里做一个分享,帮助大家用更少的时间高效上线文档的内容。 总的来说, 任务分为两个部分: …

UE5制作推箱子动作时获取物体与角色朝向的角度及跨蓝图修改变量

就是脑残死磕,你们如果有更好的方法一定要留言啊~~独乐乐不如众乐乐。 做推箱子的时候需要考虑脸是不是面对着箱子,不是必须90度,可以有一个-45~45度的范围。 摸索了一下,有几种做法和几个小白坑,这里列出来。 一、准…

mysql 索引原理为什么用b+树而不用二叉树

在数据库中,索引是一种数据结构,它能够快速定位到存储在数据库表中特定行的数据。MySQL等数据库管理系统通常使用B树作为索引的数据结构,而不使用二叉树,主要基于以下几个原因: 高度平衡:B树是一种多路搜索…

软件推荐 篇三十七:开源免费无广告的在线音乐免费播放 | MusicFree纯净无广告体验-小众冷门推荐

引言 自从QQ音乐没了杰伦、某云开始收费,除了各种广告弹窗导致电脑卡的要死,打工人就靠这点音乐背景熬夜了,木有办法,得有个开源免费的听歌软件吧,一搜github,软件一大堆,作为一个打工仔&#…

【前端寻宝之路】学习和总结HTML表格的实现和合并

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法|MySQL| ​💫个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-IWDj0gWiFt6IMq3x {font-family:"trebuchet ms",verdana,arial,sans-serif;f…

GraphPad Prism 10:一站式数据分析解决方案

GraphPad Prism 10是一款功能强大的数据分析和可视化软件,广泛应用于生命科学研究、医学、生物、化学等多个领域。以下是对其详细功能的介绍: 首先,GraphPad Prism 10具有出色的数据可视化功能。它支持各种类型的图表和图形,包括…