GlobalAveragePooling中x.mean([2,3])的理解

本文介绍了GlobalAveragePooling2D层的工作原理。在通道优先的数据格式下,该层会对输入数据的第二和第三维度进行平均操作,实现特征的降维。这一过程在卷积神经网络中常用于减少全连接层参数,提高模型泛化能力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

class GlobalAveragePooling(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return x.mean([2, 3])

在GlobalAveragePooling2D有如下代码:

def call(self, inputs):
    if self.data_format == 'channels_last':
      return backend.mean(inputs, axis=[1, 2])
    else:
      return backend.mean(inputs, axis=[2, 3])

可以看到如果输入数据为channels_first,即通道在前面的话,GAP便是对第二第三维度进行求平均。

class SqueezeExcitation(nn.Module): def __init__(self, in_channels, reduction): super(SqueezeExcitation, self).__init__() self.fc1 = nn.Linear(in_channels, in_channels // reduction, bias=False) self.fc2 = nn.Linear(in_channels // reduction, in_channels, bias=False) self.relu = nn.ReLU(inplace=False) self.sigmoid = nn.Sigmoid() def forward(self, x): # Squeeze batch_size, channels, _, _ = x.size() y = x.view(batch_size, channels, -1).mean(dim=2) # Global average pooling # Excitation y = self.fc1(y) y = self.relu(y) y = self.fc2(y) y = self.sigmoid(y).view(batch_size, channels, 1, 1) # Reshape for scaling # Scale return x * y # Scale the input features class HeadMLP(nn.Module): def __init__(self, in_channels=3, out_channels=3): super(HeadMLP, self).__init__() self.conv1 = nn.Conv2d(in_channels, 128, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(128) self.relu1 = nn.ReLU(inplace=False) self.se1 = SqueezeExcitation(128, 8) self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1) self.bn2 = nn.BatchNorm2d(64) self.relu2 = nn.ReLU(inplace=False) self.se2 = SqueezeExcitation(64, 4) self.conv3 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1) self.bn3 = nn.BatchNorm2d(3) self.relu3 = nn.ReLU(inplace=False) #self.pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) #self.pool = nn.AdaptiveAvgPool2d((256, 256)) #self.se1 = SqueezeExcitation(128, 8) #self.se2 = SqueezeExcitation(64, 8) def forward(self, x): x = self.relu1(self.bn1(self.conv1(x))) x = self.se1(x) x = self.relu2(self.bn2(self.conv2(x))) x = self.se2(x) x = self.relu3(self.bn3(self.conv3(x))) #x = self.pool(x) return x如果我想要将这个头部cnn放入移动端,应该怎么修改使其FLOPs降低,但是不影响效果
最新发布
03-14
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

生成滞涨网络~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值