MobileSAM论文笔记

本文介绍了一种针对移动设备的优化版本MobileSAM,通过解耦蒸馏法将原始SAM的图像编码器知识传授给轻量级模型,使得模型在保持高性能的同时,参数量减少60倍,适合移动设备使用,推理速度显著提升。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

摘要

自Meta研究团队发布SAM(Segment Anything Model)项目依赖,因其令人惊艳的零样本迁移特性和与其他视觉应用兼容的高通用性,引起了极大的关注。由于大多数类似的应用都需要运行在资源限制的边缘设备,如手机,因此,本文的目标是通过使用轻量化的encoder替换原始计算量大的encoder使其称为移动友好型模型。一个简单的思路是按照SAM原文训练一个新的轻量化的SAM,但是效果不理想,尤其是在有限训练资源的情况下。作者发现效果不理想是由于图像编码器和mask解码器的耦合优化(coupled optimization)引起的,因此,提出了解耦蒸馏法。具体而言,就是将原始SAM的图像编码器的知识蒸馏给一个轻量化的encoder,它可以与原始SAM中的mask解码器自动兼容。MoibleSAM的训练过程可以在1天之内单张GPU上完成,相同性能的情况下,参数量少了60多倍。在推理速度上,MobileSAM平均10ms一张图像,特征提取8ms和mask解码2ms。凭借卓越的性能和更高的通用性,MobileSAM比同期的FastSAM小7倍,快5倍,更适合移动端使用。此外,MobileSAM可以在CPU上更加流畅的推理运行。
太长不看版:
作者发现由于图像编码器和mask解码器的耦合优化,使得单纯重新训练一个轻量化的SAM效果不佳,提出了解耦蒸馏法。具体而言,就是将原始SAM的图像编码器的知识蒸馏给一个轻量化的encoder,它可以与原始SAM中的mask解码器自动兼容。
MoibleSAM的训练过程可以在1天之内单张GPU上完成,相同性能的情况下,参数量少了60多倍。在推理速度上,MobileSAM平均10ms一张图像,特征提取8ms和mask解码2ms

模型结构

SAM

在介绍MoibleSAM之前,先了解一下SAM模型。如下图所示,SAM包含两个部分:图像编码器和mask解码器,相比于mask解码器的4M参数量,图像编码器的参数量为632M,这使得部署SAM在移动端非常困难。因此,使用轻量化的模型替换原始的图像编码器成为优化手段。
image.png

解耦蒸馏

如下图所示,左边为coupled distillation,指的是采用知识蒸馏的方法将ViT-H的知识蒸馏给一个更小的图像编码器。然而,这种直接替换在重新训练的难度主要在于图像编码器和mask解码的耦合优化。基于分而治之的思路,作者提出将知识蒸馏的任务分为两个子任务:图像编码器蒸馏和mask编码器微调。如右图所示,将mask解码器冻结之后进行知识蒸馏,这种蒸馏法称为seim-coupled distillation。
image.png
然而,根据经验发现这种优化仍然存在挑战性,因为prompt的选择是随机的,使得mask解码器可变,从而增加了优化难度。因此,提出直接把ViT-H蒸馏到小的图像编码器中,该方法称为decoupled distillation,如下图所示。在图像embedding上直接进行蒸馏的好处是可以使用简单的MSE损失,而不需要使用focal loss和dice loss。
image.png

微调mask解码器的必要性

使用decoupled蒸馏法后,发现由于student图像编码器生成的embedding接近teacher,使得微调mask解码器变成可选项。如果需要微调mask解码器,可以冻结图像编码器或者一起微调可能进一步提升性能。

效果对比

将原始SAM生成的mask作为GT,输入相同点比较coupled和的coupled蒸馏法的效果差异。在coupled蒸馏中,使用ViT-B作为图像编码器,在SA-1B数据集上训练180个迭代。相反,在decoupled蒸馏中,使用SA-1B的1%数据训练55个迭代。结果对比如下表所示,decoupled蒸馏法取得0.75的mIoU。
image.png

实验结果

轻量化图像编码器

本文的目标是通过使用一个轻量化模型替换原始SAM的ViT-H来得到一个高效的SAM。作为一个ViT-based基础网络,ViT-Tiny与Deit-Tiny有相似的参数量但效果更好。例如,在ImageNet-1K数据集上,Deit-Tiny取得72.2%精度但ViT-Tiny是79.1%。因此采用ViT-Tiny作为MobileSAM的图像编码器,验证本文提出的解耦蒸馏法有效性。
image.png

训练和评估细节

为了蒸馏图像编码器,本文使用1%的SA-1B数据在单张GPU上训练8个epoch。由于在知识蒸馏过程中,教师图像编码器占用大部分的计算资源,因此,为了加速蒸馏过程,首先将数据集的图像的embedding保存下来。至此,在单张GPU下,训练不到一天就可以得到MobileSAM。在更多GPU上训练更长时间可以得到更好的效果。最开始的研究结果可知通过微调mask解码器可以进一步提升MobileSAM的性能,但是本文中为了简单跳过了该步骤。为了评估MoblieSAM的效果,计算原始SAM和MobileSAM预测的mask之间的mIoU。

可视化结果

MobileSAM给出了基于点和目标框的两种预测mask结果,以及分割一切的结果对比,如下图所示。
image.png
image.png
image.png

消融实验

如下图所示,可以发现在相同迭代次数情况下,增大batch size可以提升模型性能。此外,在相同batch size下,通过增加训练epoch可以提升性能。
image.png

与FastSAM对比

如下表所示,FastSAM包含68M参数量,推理一张图片用时40ms。相反,MoibleSAM参数量小于10M,推理时间进行12ms。
image.png
FastSAM推理mask时需要输入多个points,因此将其中1个点设为前景,其余都为背景点。下图为对比结果,可以看到FastSAM的mIoU远小于MobileSAM表明FastSAM预测的mask与原始SAM的偏差较大。
image.png

结论

本文工作的目标是通过使用轻量化的模型替换原始SAM的图像编码器,使其移动端友好。由于SAM中图像编码器和mask解码器的耦合优化,使得直接训练一个轻量化模型效果不佳,因此提出了解耦蒸馏,将图像编码器ViT-H的知识蒸馏给一个轻量化的编码器。MobileSAM的参数量比原始SAM少了60多倍,同时MobileSAM延续原始SAM模型的接口,仅仅替换了图像编码器,因此,可以与SAM模型的使用方法无缝对接。

### MobileSAM部署教程和配置说明 #### 1. 安装依赖库 为了成功部署MobileSAM,需先安装必要的Python包和其他依赖项。建议创建一个新的虚拟环境来管理这些依赖。 ```bash pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 pip install git+https://github.com/facebookresearch/segment-anything.git@main pip install opencv-python matplotlib Pillow scikit-image onnxruntime-gpu tensorrt ``` 上述命令会安装PyTorch及其扩展、Segment Anything Model (SAM),以及用于图像处理的各种工具[^4]。 #### 2. 下载预训练模型权重 访问官方GitHub仓库获取最新的MobileSAM版本,并下载对应的预训练参数文件: ```bash wget https://github.com/ChaoningZhang/MobileSAM/releases/latest/download/mobile_sam.pt ``` 此链接指向最新发布的`mobile_sam.pt`文件,该文件包含了经过优化后的轻量化SAM架构的权重数据。 #### 3. 准备推理脚本 编写一个简单的Python脚本来加载模型并执行预测任务。下面是一个基本的例子: ```python from segment_anything import sam_model_registry, SamAutomaticMaskGenerator import cv2 import numpy as np def load_image(image_path): image = cv2.imread(image_path) rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) return rgb_image device = 'cuda' if torch.cuda.is_available() else 'cpu' model_type = "vit_t" checkpoint = "./mobile_sam.pt" sam = sam_model_registry[model_type](checkpoint=checkpoint).to(device=device) mask_generator = SamAutomaticMaskGenerator(sam) image_path = './example.jpg' input_image = load_image(image_path) masks = mask_generator.generate(input_image) for i, mask_data in enumerate(masks): mask = mask_data['segmentation'] color_mask = np.random.random((1, 3)).tolist()[0] input_image[mask != 0] = input_image[mask != 0] * 0.5 + np.array(color_mask) * 0.5 cv2.imshow('Output', cv2.cvtColor(input_image.astype(np.uint8), cv2.COLOR_RGB2BGR)) cv2.waitKey(0); cv2.destroyAllWindows() ``` 这段代码展示了如何读取输入图片、初始化MobileSAM实例、生成掩码并将结果可视化显示出来。 #### 4. 调整设置以适应特定需求 对于不同的应用场景可能需要调整一些超参或修改部分逻辑。比如当希望实现全自动化的物体分割时,可以将默认使用的交互式模式改为全局扫描方式,即把`generate()`中的调用更改为`generate_everything()`。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值