论文链接:arxiv.org/abs/2304.02…
代码链接:github.com/lllyasviel/…
Demo链接:segment-anything.com/demo
SAM从使命、模型、数据三部分展开写作,和模型的立异比较起来,使命界说和数据的作业愈加出彩,官网也给出了demo,能直观感受SAM的作用,这篇blog也会环绕这几部分展开。
demo
demo中有敞开point, box, everything三种办法。由于text prompt作用不太稳定,demo和代码中都没有该部分。
-
鼠标悬停: 显示的是悬停位置的切割成果,例如下图中将鼠标放到手的位置.
-
点击: 切割包括该点的物体,会按最小切割的成果展现出来,假如想切割的物体大于展现的成果,能够在物体的其他部分也点击下。
-
box: 框定一个box,切割box中的物体
-
everything: 将图片中全部物体的切割都展现出来
使命
使命的规划灵感来自于NLP范畴,例如NLP中能够通过预测next token作为预练习使命,而在下流使命中能够运用prompt engineering做运用。因而,为了建立切割的根底模型,使命的规划方针是也需求具有相似的才能。 这儿作者扩展了下NLP里prompt在图画切割里的用法, prompt能够是以下几种类型:
- point
- box
- mask
- 恣意格式的文本
为了支撑以下的几种输入prompt格式,要求模型能够区分具有混杂含义的prompt,例如下图中,一个point的prompt可能有多种切割办法.这多种切割办法关于模型来说都是有用的。
预练习: 将上面说到的多种sequence的prompt告知模型,练习方针是让模型输出对应promt的切割成果,而且希望模型输出的成果和GT尽可能共同。区别于之前的交互式切割算法,SAM基本能治通过一次交互就能得到很合理的切割成果。要到达这个意图,需求规划十分共同的模型结构和loss。
zero-shot transfer:需求模型对任何prompt,得到适宜的切割成果。例如,假如要做实例切割,能够把检测得到的box作为prompt,SAM就能去做实例切割
related tasks: 切割里有许多子使命,例如边际切割,语义切割等,SAM能完结全部已知的切割使命和还没有作为一个方向的切割使命。之前现已有相似的能够做多种切割的模型(solo), 可是这些模型有多个子子输出,然后做排列组合能够得到多种切割成果。而SAM通过prompt将多个切割使命合并在一起。
总而言之,作者是希望SAM能够切割全部,而且能相CLIP相同,能运用到最开端没有想到的范畴。
模型
模型的结构如上图所示. prompt会通过prompt encoder
, 图画会通过image encoder
。然后将两部分embedding通过一个轻量化的mask decoder
得到融合后的特征。encoder部分运用的都是已有模型,decoder运用transformer。这部分论文中介绍的相比照较少,下面会结合代码一起整理下:
- image encoder: 运用的是用ViT走位backbone的MAE模型。在交互式切割的展现中,image encoder只会运行一次。在试验中,别离有用到ViT-H, ViT-L, ViT-B三种巨细的模型作为image encoder。代码如下,build_sam#L47
sam_model_registry = {
"default": build_sam_vit_h,
"vit_h": build_sam_vit_h,
"vit_l": build_sam_vit_l,
"vit_b": build_sam_vit_b,
}
- prompt encoder: prompt一共有point,box, mask, text四种,会将其分为三类。pint和box能够作为一类运用position encodings, text能够运用CLIP作为encoder, 而mask是一种密布型的prompt,能够运用卷积作为encoder.prompt_encoder.py#LL128C5-L128C5 prompt_encoder的代码如下所示,其间用position embedding别离完结了point和box query两种稀疏embedding,用卷积完结了mask query密布embedding.,
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
and labels to embed.
boxes (torch.Tensor or none): boxes to embed
masks (torch.Tensor or none): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) # position embedding
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes) # position embedding
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
dense_embeddings = self._embed_masks(masks) # conv embedding
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
return sparse_embeddings, dense_embeddings
- mask decoder:: 运用一个transformer将image embedding和prompt embedding做双向的cross-attention;而且也有prompt embedding的self-attention。也有MLP和linear classifier分类切割区域。mask decoder, transformer.py#L151这儿的queries是query embedding,keys是image embedding,query_pe和queries相同,key_pe是需求加到image embedding上的位置编码。query embedding会通过self attention。query embedding和image embedding会做双向的cross-attention, 具体完结办法是如上代码所示,image embedding会作为query,query embedding会作为key和value;相同的,query embedding会作为query,image embedding会作为key和value。
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
-
处理混杂的输入: 关于一个prompt,模型会输出3个mask,实际上也能够输出更多的切割成果,3个能够看作一个物体的整体、部分、子部分,基本能满足大多数状况。运用IOU的办法,排序mask。在反向传达时,参与核算的只要loss最小的mask相关的参数.
-
高效: 这儿首要指的是prompt encoder和mask decoder。在web浏览器上,CPU核算只用约50ms
-
loss和练习细节: 首要运用的是focal loss和dice loss。每一个mask,会随机发生11种prompt与之配对。
数据
数据引擎
不像CLIP中图画文本对通过互联网容易获取,切割的数据获取本钱巨大。SAM开源了一个10亿张图片的切割数据集。在SAM中规划了一个数据引擎用于获取切割的数据,数据引擎首要分为以下三部分:
-
辅佐标示: 简略来说便是用能够获取到的开源切割数据练习一个初始的SAM模型V0版别,再用V0在没有切割标示的数据上生成预标示,人工check模型的成果并作修改和承认。得到新的数据后,再将新的数据参加到练习集从头练习SAM得到V1版别,再循环标示数据和迭代模型。一共进行6次练习。开端的时分数据集比较少,运用的ViT-B模型,最终会运用ViT-H模型。 这儿面还有一些功率提高的数据,例如跟着模型的迭代,每个mask的标示耗时从34s到14s。SAM模型在每张图片上生成的mask从20到44个。在该阶段数据集最终有12万张图片,430万个mask
-
半主动化标示: 通过第一阶段后,现已有一个不错的SAM模型能生成切割成果。半主动化标示的意图是增加mask的多样性。具体做法是练习一个检测模型,用于检测SAM生成的mask成果是否可信,只保留可信的mask成果,然后将图片给人工标示。人工标示会在可信的mask根底上标示出其他的切割框。通过5次的迭代后,数据集新增了18万张图片,590万mask。
主动标示: 通过前面两个阶段后,SAM有了较好的成果,能切割出图片中的方针,而且关于混杂的prompt也有了较好的输出。这个模型能够主动的对一些图片做标示。主动标示的时分需求有一些挑选策略,模型输出的成果可能还是会呈现一些过错。首要有以下三种办法做挑选
-
SAM模型有一个IOU prediction的模块能输出mask的confidence,如下图所示
-
stable mask的判断,具体的办法是在得到切割成果前对logit加正向和负向的扰动,假如两次扰动生成的切割成果IOU大于0.95,则以为生成的mask是可靠的
-
NMS过滤掉重复的mask
数据质量
图画: 包括11M高分辨率(33004950)的图画,其他的一些开源数据集,例如COCO分辨率较低(480640) Mask: 包括1.1B的mask,99.1%都是模型生成的。作者试验了下,只运用模型生成的mask和即运用模型生成也运用人工标示的mask,模型的作用是适当的。因而发布的数据集里只包括模型生成的mask Mask 质量: 抽取了一部分mask数据做人工的精标,精标前后有94%的mask具有90%以上的IOU。而其他的一些开源数据集只要85-91%的IOU
下面也从mask的数量,每种mask尺寸的占比和mask占外接矩形比例等多方面和其他数据集做了比照
数据来源分布
不同性别,肤色,年龄人群切割作用的差异比照
zero-short Transfer试验
评价的数据集都是SAM模型练习时的不同,而且包括水下,第一视角等没有在SAM中呈现过场景的图片
point mask
这儿比照的是用point作为prompt比照切割的成果, 在绝大部分数据集中都优于RITM(当前的SOTA)
边际检测
SAM在练习的时分便是选用的包括point prompt的办法,作者这儿还比照了一些在练习时没有包括的办法,边界检测便是其间一种。SAM在运用边界检测时,运用办法是在图片上铺上16*16均匀的point prompt,每个prompt发生3个mask,再通过NMS后。通过Sobel filtering得到边际检测的成果。SAM的成果倾向于提取更丰厚的边际,因而在方针上recall和专门做边际检测的模型适当,precision会低些。
方针检测
切割的成果取bbox,就能做方针检测了.整体方针低于ViTDet,可是在中等常见和不太常见的方针上作用优于ViTDet
实例切割
先用一个方针检测算法,用方针检测得到的box作为prompt输入到SAM,就能够做实例切割了。试验的成果分为了定量(用测试集的GT)和定性(人来评判好坏)两种。定量的方针不如BiTDet—H,定性的方针SAM优于ViTDet。作者给出的解说是COCO数据集标示作用一般(在人看来乃至不如SAM和ViTDet模型输出的成果),因而ViTDet在COCO上做练习时拟合到了一些过错的偏差,但过错的偏差和标示相似,因而定量的方针不如ViTDet
Text to Mask
这儿指的是用文本作为prompt,然后切割出文本说到的方针。作者在练习的时分取的是图片中方针尺寸大于100*100的方针,用CLIP提取image embedding(text embedding也行,由于CLIP的image embedding和text embedding是对齐的),作为prompt encoder模块的输出,用于练习SAM模型。这一部分没有和其他办法比照,也由于作用不太稳定,在官方的demo中没有展现
消融试验
有以下的定论: 左面的图,数据来源的影响:
- 参加半主动标示的数据和主动标示的数据功能都有很大的提高
- 只用模型生成的数据与额外加上人工标示的数据差异不大
中心的图, 数据量的影响:
- 数据量从0.1M到1M,模型功能提高很大
- 数据量从1M到11M,模型功能变化不明显,实际运用中1M差不多满足
右边的图, image encoder的影响:
- ViT-B到ViT-L提高很大
- ViT-L到ViT-H提高一般,实际运用ViT-L满足
总结
SAM的热度也十分高,相同作为FB的作业,SAM只是放出来两个月,github上star的数量现已超过了detectron2三年的总和。SAM的希望是能将该模型作为图画范畴的根底模型(foundation model),像CLIP那样能在各个范畴大放反常,或许像GPT相同能一致NLP范畴。SAM也确真实许多场景得到了运用,例如开源的SD中也融合了SAM,能够做许多风趣的运用,例如从假人模特身上用SAM得到衣服的mask,再结合ControlNet,就能够生成不同的人穿戴相同的衣服。
最开端自媒体宣扬的文章也是《CV范畴不存在了》,《CV界的GPT3》相似的标题,SAM确实是在一致上迈出了很大的一步,但实际上CV范畴的一致还有许多挑战。NLP范畴中的Bert用完型填空和GPT预测下一个token的预练习在十分多的使命上表现了很好的泛化性,乃至在一些没有练习过的使命上能取得比一些专家模型更好的作用。
-
使命和数据上的不一致,CV范畴的分类是输出类别,检测输出bbox,切割输出mask。尽管单个使命能够复用,可是整体缺少一个通用的使命。使命上的不一致,数据上也很难做到一致,分类的使命有许多数据,可是检测和切割的数据就要少十分多,而且标示本钱巨大。单纯练习分类作为backbone也很难处理其他使命,检测和切割的算法依然需求做很多的优化
-
CV范畴的使命缺少孕育大模型的土壤,CV使命一直在考虑模型的核算量,显存占用。假如将每个像素看作一个token,一张512*512的图片就有26万个token。假如transformer最开端呈现在CV范畴,面对的问题是显存和核算量都比resnet差,而且作用也远不如resnet。假如没有transformer极大的促进了NLP范畴的开展,CV范畴可能也不会从头思考transfomer能增大感受野,能有更好的泛化才能。
-
还没有找到CV范畴【高维】的使命。NLP范畴的完形填空和对话确实是一种很高维的使命。模型能完结这些使命,一些NER或许RE之类的底层使命也能很好的被处理。现在CV范畴有一些尝试做foundation model,例如比照学习或许像SAM,在一些使命上表现了不错的泛化性,可能是这些办法能一致其他使命,但现在的开展还不太够,也可能是其他一些还没呈现/开展起来的使命。但这种【高维】的使命一定能通过一些办法降维处理现在简直全部的CV根底使命。