【即插即用】SimAM无参数注意力机制(附源码)

论文介绍了一种无需增加参数的注意力模块,通过优化能量函数确定神经元重要性,提供快速实现方法,提升ConvNets在视觉任务中的性能。作者在Pytorch-SimAM开源代码中展示了该模块的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

论文地址:https://proceedings.mlr.press/v139/yang21o#/
代码地址:https://github.com/ZjjConan/SimAM
摘要:

在这篇论文里,我们提出了一种概念简单但非常有效的注意力模块,用于卷积神经网络(ConvNets)。与现有的通道式和空间式注意力模块不同,我们的模块在不增加原网络参数的情况下,为每一层的特征图推导出3D注意力权重。具体来说,我们基于一些广为人知的神经科学理论,提出优化一个能量函数来找出每个神经元的重要性。我们还为这个能量函数推导出了一个快速的封闭形式解,这个解只需要不到十行代码就能实现。这个模块的另一个优点是,大多数操作符都是基于定义的能量函数的解来选择的,这样就避免了过多的结构调整工作。在各种视觉任务上的定量评估表明,我们提出的模块既灵活又有效,能够提升许多ConvNets的表示能力。

简单来说,我们发明了一种新的注意力模块,可以让卷积神经网络更好地关注图像中重要的部分。与现有的方法不同,我们这个方法不需要给网络增加额外的参数,只需要优化一个能量函数来找出每个神经元的重要性。我们还找到了一个简单快速的解法,用几行代码就能实现。这个模块不仅灵活,而且在各种视觉任务中都表现出色,能够提升网络的性能。如果你对这个模块感兴趣,可以在Pytorch-SimAM上找到我们的代码。

原理图
Pytorch源码:
 
import torch
import torch.nn as nn

class SimAM(nn.Module):
    def __init__(self, lamda=1e-5):
        super().__init__()
        self.lamda = lamda
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 获取输入张量的形状信息
        b, c, h, w = x.shape
        # 计算像素点数量
        n = h * w - 1
        # 计算输入张量在通道维度上的均值
        mean = torch.mean(x, dim=[-2,-1], keepdim=True)
        # 计算输入张量在通道维度上的方差
        var = torch.sum(torch.pow((x - mean), 2), dim=[-2, -1], keepdim=True) / n
        # 计算特征图的激活值
        e_t = torch.pow((x - mean), 2) / (4 * (var + self.lamda)) + 0.5
        # 使用 Sigmoid 函数进行归一化
        out = self.sigmoid(e_t) * x
        return out

# 测试模块
if __name__ == "__main__":
    # 创建 SimAM 实例并进行前向传播测试
    layer = SimAM(lamda=1e-5)
    x = torch.randn((2, 3, 224, 224))
    output = layer(x)
    print("Output shape:", output.shape)

### 如何在 YOLOv8 中集成 SimAM 注意力机制 要在 YOLOv8 中集成 SimAM 注意力机制,可以通过修改其配置文件或直接调整源码来完成。以下是具体的实现方法和代码示例。 #### 修改 YOLOv8 配置文件 YOLOv8 提供了灵活的 YAML 文件用于定义模型架构。可以在该文件中指定使用 SimAM 注意力机制的位置。通常情况下,在 backbone 和 neck 层次中加入注意力机制可以显著提升性能[^3]。 假设我们正在使用的默认配置文件为 `yolov8n.yaml`,需要对其进行如下修改: ```yaml # yolov8n_simam.yaml nc: 80 # 类别数 depth_multiple: 0.33 # 深度倍增器 width_multiple: 0.25 # 宽度倍增器 backbone: # [from, number, module, args] [[-1, 1, Conv, [64, 3, 2]], # 0-P1/2 [-1, 1, SimAM, []], # 添加 SimAM 到 P1/2 ... ] neck: [[-1, 1, C3, [256, 3, False]], [-1, 1, SimAM, []], # 添加 SimAM 到 neck ... ] ``` 通过这种方式,SimAM 被嵌入到骨干网络和特征金字塔网络 (FPN) 的特定层中。 --- #### 自定义 Python 实现 如果希望更深入地控制 SimAM 的行为,则可以直接编写自定义模块并将其注入到 YOLOv8 的训练流程中。以下是一个完整的代码示例: ```python import torch.nn as nn import torch class SimAM(nn.Module): """ SimAM注意力机制实现。 参考官方代码 https://github.com/ZjjConan/SimAM """ def __init__(self, e_lambda=1e-4): super(SimAM, self).__init__() self.e_lambda = e_lambda def forward(self, x): b, c, h, w = x.size() n = w * h - 1 x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2) y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5 return x * torch.sigmoid(y) def add_simam_to_yolov8(): """将 SimAM 注入到 YOLOv8 模型""" from ultralytics import YOLO class CustomYOLO(YOLO): def build_model(self, cfg="yolov8n.yaml", verbose=True): model = super().build_model(cfg, verbose) # 替换部分卷积层后的激活函数为 SimAM for name, layer in model.named_modules(): if isinstance(layer, nn.Conv2d): # 找到所有卷积层 simam_layer = SimAM(e_lambda=1e-4) setattr(model, f"{name}_simam", simam_layer) return model return CustomYOLO if __name__ == "__main__": custom_yolo = add_simam_to_yolov8()("yolov8n.yaml") # 创建带有 SimAM 的 YOLOv8 模型 results = custom_yolo.train( data="coco128.yaml", epochs=10, imgsz=640, device=0, batch=16 ) ``` 此代码片段展示了如何创建一个继承自 `YOLO` 的子类,并在其内部动态插入 SimAM 模块。注意这里的 `add_simam_to_yolov8()` 函数会遍历整个模型结构,找到合适的卷积层位置SimAM。 --- #### 训练过程中的注意事项 当集成了 SimAM 后,可能需要注意以下几个方面: 1. **超参数调节**:由于引入了额外的计算开销,建议适当减少批量大小 (`batch`) 并增加梯度累积次数以保持 GPU 显存占用平衡[^2]。 2. **优化器选择**:推荐继续沿用 SGD 或 AdamW 等经典优化策略,同时监控学习率变化曲线。 3. **硬件支持**:考虑到 SimAM 对于某些低功耗设备可能存在效率瓶颈,可探索量化感知训练或其他压缩技术来缓解这一问题[^1]。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值