SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformer *模型解析及滑坡语义分割实战

实时语义分割之BiSeNetv2(2020)结构原理解析及建筑物提取实践
实时语义分割之Deep Dual-resolution Networks(DDRNet2021)原理解析及建筑物提取实践
SegFormer 模型解析:结合 Transformer 的高效分割框架原理及滑坡语义分割实战

用自己的数据训练yolov11目标检测

概要

提示: (1)论文地址:https://arxiv.org/abs/2105.15203 (2)官方代码:https://github.com/NVlabs/SegFormer(3)本人打包好数据可运行代码:SegFormer.zip 链接: https://pan.baidu.com/s/1IhZNe1JjWtVNgZPKHOK9PA?pwd=7rgd 提取码: 7rgd
SegFormer 是一种基于 Transformer 的语义分割模型,通过结合 多尺度特征融合 和 轻量级解码器,在保持高效性的同时实现了高精度分割。其核心思想是利用 Transformer 提取全局上下文信息,并通过多级特征融合提升细节恢复能力。
多种实时语义分割效率精度对比

理论知识

整体架构流程

SegFormer网络架构

class SegFormer(BaseModel):
    def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 2) -> None:
        super().__init__(backbone, num_classes)
        self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 768, num_classes)
        self.apply(self._init_weights)

    def forward(self, x: Tensor) -> Tensor:
        y = self.backbone(x)
        y = self.decode_head(y)   # 4x reduction in image size
        y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False)    # to original image shape
        return y

SegFormer 由两部分组成:

  • Backbone(主干网络):基于 MiT(Mix Transformer)的多尺度特征提取器。
  • Decoder(解码头):轻量级的特征融合与预测模块。

关键模块解析

模块名称功能描述结构在模型中定义
BaseModel父类,基础模型框架class BaseModel
MiT Backbone混合Transformer编码器class MiT
SegFormerHead多尺度特征融合分割头class SegFormerHead

基础模型框架 (BaseModel)

class BaseModel(nn.Module):
    def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19):
        super().__init__()
        backbone, variant = backbone.split('-')
        self.backbone = eval(backbone)(variant)  # 动态加载 MiT 系列

    def _init_weights(self, m: nn.Module):
        # 分层初始化策略
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)  # 截断正态分布
        elif isinstance(m, nn.Conv2d):
            # He 初始化适配卷积层
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))

    def init_pretrained(self, pretrained: str = None):
        # 灵活加载预训练权重
        self.backbone.load_state_dict(torch.load(pretrained), strict=False)


模型骨架的抽象基类,支持预训练模型加载,提供权重初始化标准流程。

_no_grad_trunc_normal_初始化策略对比:

层类型初始化方法标准差目的
Linear截断正态分布0.02防止过大初始值
Conv2dHe 初始化sqrt(2/fan_out)适配 ReLU 激活
LayerNorm权重=1,偏置=0-保留原始分布

Backbone:MiT(Mix Transformer)

代码逻辑(简化示意)

class MiT(nn.Module):
    def __init__(self, variant: str):
        # 根据变体配置参数(B0-B5)
        self.embed_dims = {
            'B0': [32, 64, 160, 256], 
            'B1': [64, 128, 320, 512],
            # ... 其他变体配置
        }
        
        # 构建四阶段特征提取
        self.stages = nn.ModuleList([
            OverlapPatchEmbed(...),  # 重叠块嵌入
            BlockSequence(...),      # Transformer Block 堆叠
            # 重复构建各阶段...
        ])

    def forward(self, x):
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)  # 收集多尺度特征
        return features

说明:

  • 重叠块嵌入:使用 7x7 卷积+滑动步长stride=4 实现块嵌入,保留边缘信息

  • 效率优化:通过 Sequence Reduction Attention 将 K/V 维度压缩到 1/4

  • 层次化设计:每个阶段输出分辨率递减(1/4, 1/8, 1/16, 1/32)

解码头(SegFormerHead)

代码逻辑(简化示意)

class SegFormerHead(nn.Module):
    def __init__(self, dims: list, embed_dim: int = 256):
        # 多层级特征投影
        for i, dim in enumerate(dims):
            self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim))
        
        # 特征融合模块
        self.linear_fuse = ConvModule(
            in_channels=embed_dim*4,  # 四级特征拼接
            out_channels=embed_dim,
            kernel_size=1
        )
        
        # 最终预测层
        self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)

    def forward(self, features):
        # 特征对齐与融合流程
        outs = []
        for i, feat in enumerate(features):
            # 步骤1:MLP维度统一
            proj_feat = getattr(self, f'linear_c{i+1}')(feat)
            
            # 步骤2:上采样对齐分辨率
            if i > 0:
                proj_feat = F.interpolate(proj_feat, 
                    size=features[0].shape[-2:], 
                    mode='bilinear')
            
            outs.append(proj_feat)
        
        # 步骤3:逆序拼接与融合
        fused = self.linear_fuse(torch.cat(outs[::-1], dim=1))
        return self.linear_pred(fused)

说明:

  • 统一维度投影:将不同层级的特征映射到相同维度(256/512)

  • 特征逆序拼接:深层语义特征在前,浅层细节特征在后

  • 轻量融合结构:仅使用 1x1 卷积 + BN + ReLU 实现特征聚合

模型实践

训练数据准备

提示:云盘代码已内置少量山体滑坡数据集

训练数据分为原始影像和标签(二值化,0和255),均位于Sample文件夹内,数据相对路径为:

Sample\Build\train\ IMG_T1
------------------\ IMG_LABEL
-------------\val \ IMG_T1
------------------\IMG_LABEL

本示例中影像尺寸不一,在dp0_train.py文件parser参数项,crop_height 、crop_width设置为256、256,训练数据在CDDataset_Seg定义中进行了resize,模型训练中均为256*256。

模型训练

运行dp0_train.py,模型开始训练,核心参数包括:

parser参数说明
num_epochs训练批次
learning_rate初始学习率
batch_size单次样本数量
dataset数据集名字
crop_height训练时影像重采样尺度

数据结构CDDataset_Seg定义在utils文件夹dataset.py中,注意读取后进行了数据增强(随机翻转),灰度化,尺寸调整,标签归一化、 one-hot 编码,以及维度和数据类型的转换,最终得到适用于 PyTorch 模型训练的张量。

        # 读取训练图片和标签图片
        image_t1 = cv2.imread(image_t1_path,-1)
        #image_t2 = cv2.imread(image_t2_path)
        label = cv2.imread(label_path)
       
        # 随机进行数据增强,为2时不做处理
        if self.data_augment:
            flipCode = random.choice([-1, 0, 1, 2])
            if flipCode != 2:
#                image_t1 = normalized(image_t1, 'tif')
                image_t1 = self.augment(image_t1, flipCode)
                #image_t2 = self.augment(image_t2, flipCode)
                label = self.augment(label, flipCode)        
            
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        
        image_t1 = cv2.resize(image_t1, (self.img_h, self.img_w))
        #image_t2 = cv2.resize(image_t2, (Config.img_h, Config.img_w))

        label = cv2.resize(label, (self.img_h, self.img_w))
        label = label/255
        label = label.astype('uint8')
        label = onehot(label, 2)
        label = label.transpose(2, 0, 1)
        label = torch.FloatTensor(label)

训练过程如下图所示,模型保存至checkpoints数据集同名文件夹内。

影像测试

运行dp0_AllPre.py,核心参数包括:

parser参数说明
Checkpointspath预训练模型位置名称
Dataset批量化预测数据文件夹
Outputpath输出数据文件夹

数据加载方式:

    pre_dataset = CDDataset_Pre(data_path=pre_imgpath1,
                            transform=transforms.Compose([
                                transforms.ToTensor()]))

    pre_dataloader = torch.utils.data.DataLoader(dataset=pre_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=0)

需要注意,预测定义的数据结构CDDataset_Pre与训练时CDDataset_Seg有区别,此处将其resize回原始尺寸,后续使用自己的数据训练时注意调整。

结果示例

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值