Segment Anything
Segment Anything(SA)项目:一个用于图像分割的新任务、新模型和新数据集

通过FM(基础模型)+prompt解决了CV中难度较大的分割任务,给计算机视觉实现基础模型+提示学习+指令学习提供了一种思路
关键:加大模型容量(构造海量的训练数据,或者构造合适的自监督任务来预训练)
Segment Anything Task
SAM的一部分灵感是来源于NLP中的基座模型(Foundation Model),Foundation Model是OpenAI提出的一个概念,它指的是在超大量数据集上预训练过的大模型(如GPT系列、BERT),这些模型具有非常强大的 zero-shot 和 few-shot能力,结合prompt engineering和fine tuning等技术可以将基座模型应用在各种下游任务中并实现惊人的效果。
SAM就是想构建一个这样的图像分割基座模型,即使是一个未见过的数据集,模型也能自动或半自动(基于prompt)地完成下游的分割任务。为了实现这个目标,SAM定义了一种可提示化的分割任务(promptable segmentation task),这个提示可以是点、框、掩码、文本(代码中未实现)等形式,基于这个提示模型就能分割出提示处所在物体的masks。同时这种提示可以是模糊的,比如以下图剪刀握手那的黄色部分点为提示,分割掩码可以是下图最右边三种情况中任意一种,从上到下分别代表whole, part, subpart三种层级的分割,这也是SAM兼容的。

要达到这种效果就需要足够的高质量分割数据,SAM团队用他们提出的Data Engine策略成功使用人工加模型自动标注的方式制作除了一个有10亿个masks的分割数据集SA-1B,这也是他们核心的贡献之一,本文尾部会介绍相关流程。模型架构来说相对比较常规,主要是借鉴了ViT和DETR,本身创新不大。
模型结构
SAM模型架构主要包括image encoder,prompt encoder和mask decoder三部分:
- image encoder,使用了用MAE预训练的ViT模型将图像编码得到image embedding
- prompt encoder,将point、box、mask、txt等提示信息进行编码,后续会和image embedding一起用于生成masks
- mask decoder,将上述两个模块得到的embeddings整合,然后结合两个可学习的tokens生成不同层级的masks和对应的置信度值
值得一提的是,prompt encoder和mask decoder都是非常轻量的,主要的计算开销都在image encoder上,这点从模型权重上也能看出来,以ViT_B为基础的SAM权重是375M,其中prompt encoder只有32.8k,mask decoder是16.3M(4.35%),剩余则是image encoder,可想而知图像编码这块是非常耗时的。因此在实际推理中,一般单张图的image embedding只计算一次,然后将结果缓存起来,需要的时候直接调用。在image embedding已经计算好的情况下,论文中说给定一个prompt,生成mask时prompt encoder和mask decoder在浏览器中的计算耗时也仅需50ms。下面会具体介绍下各模块的输入输出和流程,均只考虑batch size为1的情况
Image encoder
输入:默认是1024x1024的图像,如尺寸不一致会将原图按最长边resize
输出:单张图的1x256x64x64的image embedding,即编码后的图像特征
流程:

上图是ViT论文中的结构图,image encoder整体流程和ViT是一样的,区别在于不需要[class]token做分类,只输出最终的图像编码张量
- 输入1024的图,拆分成64x64的768维patchs
- 经过attention block(window和global的MSA,相对位置编码)和MLP得到同样大小64x64x768embbeding特征
- 再经过neck得到1x256x64x64的图片embedding
Prompt encoder
输入:
point、box、mask、txt等prompt,格式一般如下,B为batch size
- point需要包含点的 \(x,y\) 坐标 \(B\times N\times2\) 和label(0为前景,1位背景)\(B\times N\times1\)
- box包含框的左上和右下两个点,\(B\times N\times4\),对于某个gt即单个mask,只会有1个box;如果输入的是N个box最终会生成N个masks
- mask一般和SAM最终输出mask的\(h\times w\)(256x256)大小相同,Bx1xHxW
- txt: 使用clip进行编码
对于points以及bbox编码原理很简单就是用点或者坐标直接计算他们的傅里叶特征,比如说points的伪代码
输出两个:
- sparse_embeddings 点和框的稀疏嵌入,形状为BxNx(embed_dim),其中N由输入点和框的数量确定,如果两者同时有则N的计算方式为(点的个数+2x框的个数)
- dense_embeddings 掩码的密集嵌入,形状为Bx(embed_dim)x(embed_H)x(embed_W),默认大小为Bx256x64x64,没有提示时会返回一个网络学习到的no mask默认嵌入
流程:
网络已自动学会了针对不通过类型提示的编码信息,输入的point、box、mask等提示加上位置编码后,再加上网络学会的综合编码信息,最终对point、box这种稀疏的提示会返回sparse embedding, 对mask会返回dense embeddings(没有mask提示时是网络学习到的embeddings)。这部分就相当于把各种提示转换为decoder能理解的格式。
Mask decoder
输入:
- image encoder得到的image_embeddings和图像的positional encoding
-
prompt encoder得到的prompt embeddings(sparse和dense两种)
输出: -
masks,如果指定了"multimask_output"参数则会输出3个层级的mask(whole, part, and subpart),否则只输出1个mask
-
IoU scores,可以理解为每个mask的置信度,由网络中的iou token得到
流程 -
首先image_embeddings会混入dense embeddings的信息(两者直接相加),sparse embeddings则会与mask token和IoU token拼在一起成为一个新的token,mask token后续会用于生成mask,IoU token用于衡量每个mask的好坏
- 然后这个新的token和image_embeddings经过一个TwoWayTransformer模块(下图黄色框部分),先做token的self attention,然后做token(作为query)到图像的cross attention,经过MLP更新token,最后做图像(作为query)到token的attention,目的是不断更新图像和token中的信息,会重复两次
- 更新后token再做一次token(作为query)到图像的cross attention后,又拆出来之前的两个部分mask token和IoU token,后者就代表每个mask的置信度;

整图分割推理(segment everything)

流程
在图片上生成\(32\times32\)的网格,得到1024个采样点,每个采样点都当做1个前景的prompt进入prompt encoder然后和image encoder结果一起生成mask,每次会处理一个batch(默认64)的采样点;每个batch得到的mask都会进行以下几个过滤:
- predicted IoU过滤,mask decoder除了返回masks还会预测对应mask iou值,过滤低置信度(默认阈值0.88)的mask
- stability score过滤,stability score是mask在两个阈值下二值化后的IoU值,可以理解为改变过滤阈值后还能得到同样mask的能力,过滤低于0.95的mask
- mask threshold过滤,直接过滤mask logits值低于mask_threshold(默认0.0)的mask
- boundary过滤,每个mask生成外界矩形,过滤超过图像边界的mask
所有batch过滤后的的masks结果再进行nms过滤(mask对应外接矩形的nms,阈值0.7)就得到最终的分割结果
最终结果
git上也有官方demo可以参考:全图分割的官方demo

数据引擎(data engine)
SAM除了模型外,还公开了一份有10亿个masks的1100万张图的分割数据集SA-1B,基于他们提出的data engine方案得到,这块的贡献也是非常显著,也体现了Data-centric AI的惊人能力. 从论文里总结就是辅助人工标注、半自动标注、全自动标注三步,具体如下:

- 第一步以人工标注为主。初始模型在公开数据集训练后辅助生成masks,再人工精修调整,再用标好的新数据迭代模型。如此重复6次,从12万张图得到430万masks
- 第二步是模型半自动标注高置信度masks,然后人工标注补充剩余未标出的masks。mask的置信度判断是用一个模型对mask进行目标检测,如果能检测出物体则是置信度较高mask无需再人工标注,这个目标检测模型是基于第一步得到的数据训练的。如此迭代5次,从18万张图新增了590万masks
- 第三部是模型全自动标注。基于此前两步的数据得到模型,已有较好的分割能力且能适配模糊提示分割(局部mask或者整体mask),对一张图撒32x32的网格点进行segment everything,后处理会挑选搞IoU和搞稳定性的masks并做NMS得到全图最终的masks。针对所有图片自动分割,最终得到了SA-1B数据集
SAM v2
SAM v2更像是SAM v1在视频邻域的泛化,整个模型结构如下所示:

主要值得关注的是其中的 Memory Attention:将当前帧的特征与过去帧的特征和预测以及任何新的提示联系起来。通过堆叠了 L 个transformer模块,第一个模块将当前帧的图像编码作为输入。每个区块执行self-attention,然后cross-attention(提示/未提示)帧和对象的记忆,这些记忆存储在一个记忆库中,接着是一个 MLP。在self-attention和cross-attention中使用了vanilla注意力操作,从而受益于高效注意力内核的最新发展。
memory encoder通过使用卷积模块对输出掩码进行下采样,并将其与图像编码器的无条件帧嵌入相加,生成记忆,然后使用轻量级卷积层来融合信息。
memory bank通过维护最多N个最近帧的FIFO记忆队列来保留视频中目标对象的过去预测信息,并将提示信息存储在最多M个提示帧的FIFO队列中。例如,在VOS任务中,初始掩码是唯一的提示,内存库始终保留第一帧的记忆以及最多N个最近(非提示)帧的记忆。两组记忆都以空间特征图的形式存储。
除空间存储器外,还根据每个帧的掩码解码器输出标记,将对象指针列表作为轻量级向量存储起来,用于存储要分割对象的高级语义信息。
我们将时间位置信息嵌入到N个最近帧的memory中,允许模型表示短期物体运动,但不包含到提示帧的记忆中,因为提示帧的训练信号更稀疏,并且更难以推广到推理设置中,提示帧可能来自与训练期间看到的时间范围非常不同的时间范围。