【图像分割】mask2former:通用的图像分割模型详解

最近看到几个项目都用mask2former做图像分割,虽然是1年前的论文,但是其attention的设计还是很有借鉴意义,同时,mask2former参考了detr的query设计,实现了语义和实例分割任务的统一。

1.背景

1.1 detr简介

detr算是第一个尝试用transformer实现目标检测的框架,其设计思路也很简单,就是定义object queries,用来查询是否存在目标以及目标位置的,类似cnn检测中的rpn,产生候选框。在detr中,object queries为(100,b,256)的可学习的参数,其中每个256维的向量代表了检测的box信息,这个信息是由类别和空间信息(box坐标)组成,其中类别信息用于区别类别,而空间信息则描述了目标在图像中的位置。

通过设置query,则不需要像传统cnn检测时预设anchor,最后通过匈牙利匹配算法将query到的目标和gt进行匹配,计算loss。

decoder过程中,query object先初始化为0,然后经过self attention,再和encoder的输出进行cross attention。

1.2 Deformable-DETR简介

Deformable-Detr是在detr的基础上了主要做了2个改进,Deformable attention(可变形注意力)和多尺度特征,通过可变性注意力降低了显存,多尺度特征对小目标检测效果比较好。

(1)Deformable attention(可变形注意力)

这个设计参考了可变性卷积(DCN),后续很多设计都参考了这个。先看下DCN,就是在标准卷积(a)的3 * 3的卷积核上,每个点上增加一个偏移量(dx,dy),让卷积核不规则,可以适应目标的形状和尺度。

对于一般的attention,query与key的每个值都要计算注意力,这样的问题就是耗显存;另外,对图像来说,假设其中有一个目标,一般只有离图像比较近的像素才有用,离比较远的像素,对目标的贡献很少,甚至还有负向的干扰。

Defromable attention的设计思路就是query不与全局的key进行计算,而是至于其周围的key进行计算。至于这个周围要选哪几个位置,就类似DCN,让模型自己去学。

  • 单尺度的可变性注意力机制

DeformAttn的公式如下:

  • 多尺度的可变性注意力机制

多尺度即类似fpn,提取不同尺度的特征,但由于特征的尺寸不一样,需要将不同尺度的特征连接起来。

可变性注意力机制公式如下:

相比单尺度的,多尺度多了一个l,代表第几个尺度,一般取4个层级。

对于一个query,在其参考点(reference point)对应的所有层都采用K个点,然后将每层的K个点特征融合(相加)。

整个deformable atten的流程如下:

2.mask2former

mask2former的设计上使用了deformable detr的可变形注意力。

主要计算过程用下图表示:

2.1 模型改进

(1)masked attention

一般计算过程中,计算atten时只用前景部分计算,减少显存占用。

(2) 多分辨率特征

如上图,图像经过backbone得到4层特征,然后经过Pixel Decoder得到O1,O2,O3,O4,注意O1,O2,O3经过Linear+Deform atten Layer,O4只通过Linear+卷积得到,具体可以区别看上图。

(3) decoder优化

在transformer decoder(这个过程用的是标准attention)计算过程中,query刚开始都是随机初始化的,没有图像特征,如果按常规直接self attention可能学不到充分的信息,所以将ca和sa两个模块反过来,先和pixdecoder得到的图像O1,O2,O3计算ca,再继续计算sa。

2.2 类别和mask分开预测

class和mask预测独立开来,mask只预测是背景还是前景,class负责预测类别,这部分保留了maskformer的设计。

如上图,class通过query加上Linear直接将维度转到(n,k+1),其中k为类别数目。

mask通过decoder和最后一层的mask做外积运算,得到(k,h,w)的tensor,每个k代表一个前景。

采用这种query的方式,既可以做instance也可以做语义分割,query的数量N和类别K数量无关。

2.3 loss优化

mask decoder过程中,主要用最后一层的输出计算loss;同时为了辅助训练,默认开启了auxiliary loss(辅助loss),其他层的输出也去计算loss。

还有一个trick,mask计算loss时,不是mask上的所有点都去计算,而是随机采样一定数目的点去计算loss。默认设置= 12544, i.e., 112 × 112 points,这样可以节省显存。

3.扩展

3.1 DAT:另一个Deform atten设计

另一篇deform atten的论文DAT,和deform attention思路类似,也是学习offset。只不过在偏移量设计上有区别,如下图所示,DAT在当前特征图F上学习offset时,进行了上采样2倍,在得到offset后需要插值回F的尺寸,增加了相对位置的bias。

对比几种查询的注意力结果,vit是全查,swin固定窗口大小,有可能限制查到的key,DCN为可变性卷积,DAT学到的key更好。

模型设计上,参考swin-transformer,只将最后2层替换Deformable attention,效果最好。

3.2 视频实例分割跟踪

mask2former用于视频分割,结构如下

模型结构上和图像的分割基本一致。

修改主要在transformer decoder,包含以下3个地方:

(1)增加时间编码t

主要在Transformer decoder过程,图像的位置编码为(x,y),对于视频,由于考虑了多帧数据,增加时间t进行编码,位置编码为(x,y,t)。

       # b, t, c, h, wassert x.dim() == 5, f"{x.shape} should be a 5-dimensional Tensor, got {x.dim()}-dimensional Tensor instead"if mask is None:mask = torch.zeros((x.size(0), x.size(1), x.size(3), x.size(4)), device=x.device, dtype=torch.bool)not_mask = ~maskz_embed = not_mask.cumsum(1, dtype=torch.float32)  # not_mask【bath,t,h,w】1代表时间列的索引,cumsum累加计算,得到位置idy_embed = not_mask.cumsum(2, dtype=torch.float32)  # hx_embed = not_mask.cumsum(3, dtype=torch.float32)  # wif self.normalize:eps = 1e-6z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scaley_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scalex_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scaledim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)dim_t_z = torch.arange((self.num_pos_feats * 2), dtype=torch.float32, device=x.device)dim_t_z = self.temperature ** (2 * (dim_t_z // 2) / (self.num_pos_feats * 2))pos_x = x_embed[:, :, :, :, None] / dim_t  # [b,t,h,w]->[b,t,h,w,d] xy编码的d长度是位置编码向量长度的一半pos_y = y_embed[:, :, :, :, None] / dim_tpos_z = z_embed[:, :, :, :, None] / dim_t_z # z用编码向量长度,然后和xy编码相加pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4)pos = (torch.cat((pos_y, pos_x), dim=4) + pos_z).permute(0, 1, 4, 2, 3)  # b, t, c, h, w

(2) query和多帧数据进行atten计算

        for i in range(self.num_feature_levels):size_list.append(x[i].shape[-2:])pos.append(self.pe_layer(x[i].view(bs, t, -1, size_list[-1][0], size_list[-1][1]), None).flatten(3))src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])  #level_embed size [level_num,d],level embed和输入相加# NTxCxHW => NxTxCxHW => (TxHW)xNxC  # 多帧数据融合_, c, hw = src[-1].shapepos[-1] = pos[-1].view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1)# 其中src是Pixel decoder的输出src[-1] = src[-1].view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1)

(3)query和mask计算优化

如代码所示,query和mask 外积计算,从q外积mask得到mask的shape为[b,q,t,h,w],也就是得到(b,q,t)个instance mask,然后query的instance mask和每帧的gt计算loss。

    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):decoder_output = self.decoder_norm(output)decoder_output = decoder_output.transpose(0, 1)outputs_class = self.class_embed(decoder_output)mask_embed = self.mask_embed(decoder_output)# query和mask 外积计算,从q外积mask得到[b,q,t,h,w]个maskoutputs_mask = torch.einsum("bqc,btchw->bqthw", mask_embed, mask_features)b, q, t, _, _ = outputs_mask.shape# NOTE: prediction is of higher-resolution# [B, Q, T, H, W] -> [B, Q, T*H*W] -> [B, h, Q, T*H*W] -> [B*h, Q, T*HW]attn_mask = F.interpolate(outputs_mask.flatten(0, 1), size=attn_mask_target_size, mode="bilinear", align_corners=False).view(b, q, t, attn_mask_target_size[0], attn_mask_target_size[1])# must use bool type# If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()attn_mask = attn_mask.detach()return outputs_class, outputs_mask, attn_mask

训练时是以instance作为一个基础单元,假设有t帧图像,有n个instance(实例),instance和frame的关系如下图表示:

instance在每帧上都可能存在或者不存在。对于每个instance,初始化t个mask,初始化为0,所以instace的shape是[b,n,t,h,w],如果这个instance在某帧上存在,即赋真值mask,用于匹配计算loss;不存在,即为0。

instance在每帧上都是同一个物体(形态可能变化,但是instance id是相同的),所以预测instance的类别时,每个instance只需要预测一个类别即可,所以类别的shape为[b,n]

3.3 思考

sam(segment anything model)可以通过prompt进行分割,但是缺乏类别信息,可以参考mask2former的思想,mask和类别是独立的,可以添加分类的query,接一个分类的分支,然后在coco等数据集上单独训练这个分支,让sam分割后增加类别信息。

4.参考资料

  • mask2former论文
  • mask2former代码


附赠

【一】上千篇CVPR、ICCV顶会论文
【二】动手学习深度学习、花书、西瓜书等AI必读书籍
【三】机器学习算法+深度学习神经网络基础教程
【四】OpenCV、Pytorch、YOLO等主流框架算法实战教程

➤ 在助理处自取:

➤ 还可咨询论文辅导❤【毕业论文、SCI、CCF、中文核心、El会议】评职称、研博升学、本升海外学府!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/39775.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

香橙派AIpro实测:YOLOv8便捷检测,算法速度与运行速度结合

香橙派AIpro实测&#xff1a;YOLOv8便捷检测&#xff0c;算法速度与运行速度结合 文章目录 香橙派AIpro实测&#xff1a;YOLOv8便捷检测&#xff0c;算法速度与运行速度结合一、引言二、香橙派AIpro简介三、YOLOv8检测效果3.1 目标检测算法介绍3.1.1 YOLO家族3.1.2 YOLOv8算法理…

上海计算机考研炸了,这所学校慎报!上海大学计算机考研考情分析!

上海大学&#xff08;Shanghai University&#xff09;&#xff0c;简称“上大”&#xff0c;是上海市属、国家“211工程”重点建设的综合性大学&#xff0c;教育部与上海市人民政府共建高校&#xff0c;国防科技工业局与上海市人民政府共建高校&#xff0c;国家“双一流”世界…

【微信小程序开发】微信小程序界面弹窗,数据存储相关操作代码逻辑实现

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

how to use Xcode

Xcode IDE概览 Xcode 页面主要分为以下四个部分&#xff1a; 工具栏&#xff08;ToolBar area&#xff09;&#xff1a;主要负责程序运行调试&#xff0c;编辑器功能区域的显示 / 隐藏&#xff1b;编辑区&#xff08;Editor area&#xff09;&#xff1a;代码编写区域&#xf…

vue table表格 ( parseTime-格式化时间)

<el-table-column label"发布时间" width"420px" prop"bidPublishDatetime"><template slot-scope"scope"><span>{{ parseTime(scope.row.bidPublishDatetime, {y}-{m}-{d}) }}</span></template></…

Richtek立锜科技车规级器件选型

芯片按照应用场景&#xff0c;通常可以分为消费级、工业级、车规级和军工级四个等级&#xff0c;其要求依次为军工>车规>工业>消费。 所谓“车规级元器件”--即通过AEC-Q认证 汽车不同于消费级产品&#xff0c;会运行在户外、高温、高寒、潮湿等苛刻的环境&#xff0c…

澳蓝荣耀时刻,6款产品入选2024年第一批《福州市名优产品目录》

近日&#xff0c;福州市工业和信息化局公布2024年第一批《福州市名优产品目录》&#xff0c;澳蓝自主研发生产的直接蒸发冷却空调、直接蒸发冷却组合式空调机组、间接蒸发冷水机组、高效间接蒸发冷却空调机、热泵式热回收型溶液调湿新风机组、防火湿帘6款产品成功入选。 以上新…

飞利浦的台灯值得入手吗?书客、松下多维度横评大分享!

随着生活品质的持续提升&#xff0c;人们对于健康的追求日益趋向精致与高端化。在这一潮流的推动下&#xff0c;护眼台灯以其卓越的护眼功效与便捷的操作体验&#xff0c;迅速在家电领域崭露头角&#xff0c;更成为了众多家庭书房中不可或缺的视力守护者。这些台灯以其精心设计…

(vue)eslint-plugin-vue版本问题 安装axios时npm ERR! code ERESOLVE

(vue)eslint-plugin-vue版本问题 安装axios时npm ERR! code ERESOLVE 解决方法&#xff1a;在命令后面加上 -legacy-peer-deps结果&#xff1a; 解决参考&#xff1a;https://blog.csdn.net/qq_43799531/article/details/131403987

【C语言】指针剖析(完结)

©作者:末央&#xff06; ©系列:C语言初阶(适合小白入门) ©说明:以凡人之笔墨&#xff0c;书写未来之大梦 目录 回调函数概念回调函数的使用 - qsort函数 sizeof/strlen深度理解概念手脑并用1.sizeof-数组/指针专题2.strlen-数组/指针专题 指针面试题专题 回调函…

谷粒商城-个人笔记(集群部署篇二)

前言 ​学习视频&#xff1a;​Java项目《谷粒商城》架构师级Java项目实战&#xff0c;对标阿里P6-P7&#xff0c;全网最强​学习文档&#xff1a; 谷粒商城-个人笔记(基础篇一)谷粒商城-个人笔记(基础篇二)谷粒商城-个人笔记(基础篇三)谷粒商城-个人笔记(高级篇一)谷粒商城-个…

【数据结构】02.顺序表

一、顺序表的概念与结构 1.1线性表 线性表&#xff08;linear list&#xff09;是n个具有相同特性的数据元素的有限序列。线性表是⼀种在实际中广泛使用的数据结构&#xff0c;常见的线性表&#xff1a;顺序表、链表、栈、队列、字符串… 线性表在逻辑上是线性结构&#xff0…

GEE计算遥感生态指数RSEI

目录 RESI湿度绿度热度干度源代码归一化函数代码解释整体的代码功能解释:导出RSEI计算结果参考文献RESI RSEI = f (Greenness,Wetness,Heat,Dryness)其遥感定义为: RSEI = f (VI,Wet,LST,SI)式中:Greenness 为绿度;Wetness 为湿度;Thermal为热度;Dryness 为干度;VI 为植被指数…

【多媒体】Java实现MP4和MP3音视频播放器【JavaFX】【音视频播放】

在Java中播放音视频可以使用多种方案&#xff0c;最常见的是通过Swing组件JFrame和JLabel来嵌入JMF(Java Media Framework)或Xuggler。不过&#xff0c;JMF已经不再被推荐使用&#xff0c;而Xuggler是基于DirectX的&#xff0c;不适用于跨平台。而且上述方案都需要使用第三方库…

拒绝信息差!一篇文章说清Stable Diffusion 3到底值不值得冲

前言 就在几天前&#xff0c;Stability AI正式开源了Stable Diffusion 3 Medium&#xff08;以下简称SD3M&#xff09;模型和适配CLIP文件。这家身处风雨飘摇中的公司&#xff0c;在最近的一年里一直处于破产边缘&#xff0c;就连创始人兼CEO也顶不住压力提桶跑路。 即便这样&…

[leetcode]minimum-absolute-difference-in-bst 二叉搜索树的最小绝对差

. - 力扣&#xff08;LeetCode&#xff09; /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : val(0), left(nullptr), right(nullptr) {}* TreeNode(int x) : val(x), left(null…

LeetCode 196, 73, 105

目录 196. 删除重复的电子邮箱题目链接表要求知识点思路代码 73. 矩阵置零题目链接标签简单版思路代码 优化版思路代码 105. 从前序与中序遍历序列构造二叉树题目链接标签思路代码 196. 删除重复的电子邮箱 题目链接 196. 删除重复的电子邮箱 表 表Person的字段为id和email…

昇思MindSpore学习总结七——模型训练

1、模型训练 模型训练一般分为四个步骤&#xff1a; 构建数据集。定义神经网络模型。定义超参、损失函数及优化器。输入数据集进行训练与评估。 现在我们有了数据集和模型后&#xff0c;可以进行模型的训练与评估。 2、构建数据集 首先从数据集 Dataset加载代码&#xff0…

在windows上安装objection

安装命令pip install objection -i https://mirrors.aliyun.com/pypi/simple hook指定进程 objection -g 测试 explore 进程名不定是包名&#xff0c;也可能是app名字&#xff0c;如“测试”就是app的名字 若出现如下错误&#xff0c;说明python 缺少setuptools 直接安装setu…

秋招突击——设计模式补充——单例模式、依赖倒转原则、工厂方法模式

文章目录 引言正文依赖倒转原则工厂方法模式工厂模式的实现简单工厂和工厂方法的对比 抽线工厂模式最基本的数据访问程序使用工厂模式实现数据库的访问使用抽象工厂模式的数据访问程序抽象工厂模式的优点和缺点使用反射抽象工厂的数据访问程序使用反射配置文件实现数据访问程序…