import math
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from models.blocks import CBAMLayer, SPPLayer
import logging
from functools import partial
import torch.nn.functional as F
__all__ = ['convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large',
           'convnext_tiny_cbam', 'convnext_small_cbam', 'convnext_base_cbam', 'convnext_large_cbam']
# Official ImageNet-1k pretrained checkpoints (EMA weights) from the ConvNeXt release.
model_urls = {
    'convnext_tiny': 'https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth',
    'convnext_small': 'https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth',
    'convnext_base': 'https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth',
    'convnext_large': 'https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth',
}
class LayerNorm(nn.Module):
    """LayerNorm that supports both (N, ..., C) and (N, C, H, W) layouts.

    ``channels_last`` delegates to ``F.layer_norm``; ``channels_first``
    normalizes over dim 1 by hand (biased variance, as in the reference
    ConvNeXt implementation).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalize across the channel axis (dim 1).
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
class Block(nn.Module):
    """ConvNeXt block.

    Depthwise 7x7 conv -> LayerNorm -> 1x1 expand (Linear on NHWC) -> GELU
    -> 1x1 project, with optional layer scale (gamma), optional CBAM
    attention, stochastic depth, and a residual connection.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, cbam=None):
        super().__init__()
        # Submodules are created in the same order as the reference code so
        # weight initialization consumes the RNG stream identically.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise expansion
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)  # pointwise projection
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                      requires_grad=True)
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.cbam = CBAMLayer(dim) if cbam else None

    def forward(self, x):
        shortcut = x
        out = self.dwconv(x).permute(0, 2, 3, 1)  # NCHW -> NHWC
        out = self.pwconv2(self.act(self.pwconv1(self.norm(out))))
        if self.gamma is not None:
            out = self.gamma * out
        out = out.permute(0, 3, 1, 2)  # NHWC -> NCHW
        if self.cbam is not None:
            out = self.cbam(out)
        return shortcut + self.drop_path(out)
class DropPath(nn.Module):
    """Stochastic depth: zero whole samples at random, rescale survivors.

    Acts as an identity in eval mode or when ``drop_prob`` is 0.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Identity when dropping is disabled or we are evaluating.
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over every other dim.
        mask_shape = [x.shape[0]] + [1] * (x.ndim - 1)
        mask = torch.floor(keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device))
        # Divide by keep_prob so the expected activation magnitude is unchanged.
        return x / keep_prob * mask
class ConvNeXt(nn.Module):
    """ConvNeXt backbone (Liu et al., "A ConvNet for the 2020s").

    Args:
        in_chans: number of input image channels.
        num_classes: classifier output size; 0 replaces the head with Identity.
        depths: number of Blocks in each of the 4 stages.
        dims: channel width of each stage.
        drop_path_rate: max stochastic-depth rate, increased linearly per block.
        layer_scale_init_value: init value for per-channel layer scale (0 disables).
        cbam: if truthy, insert a CBAM attention layer into every Block.
    """

    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
                 drop_path_rate=0., layer_scale_init_value=1e-6, cbam=None):
        super().__init__()
        # Stem + 3 intermediate downsampling layers (one per stage transition).
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)
        # 4 stages of residual Blocks; drop-path rate grows linearly with depth.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                        cbam=cbam) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm before the head
        self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.feature_dim = dims[-1]
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for conv/linear layers."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Guard against bias-free layers: self.apply() visits every
            # submodule, and attention blocks (e.g. CBAM) commonly use
            # Conv2d(..., bias=False); the unguarded constant_() crashed there.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return (logits, [stage1..stage4 feature maps])."""
        features = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            features.append(x)
        x = x.mean([-2, -1])  # global average pooling over H, W
        x = self.norm(x)
        x = self.head(x)
        return x, features

    def get_features(self):
        """Feature extractor (all 4 stages, no pooling/norm/head), for WS-DAN."""
        return nn.Sequential(
            self.downsample_layers[0],
            self.stages[0],
            self.downsample_layers[1],
            self.stages[1],
            self.downsample_layers[2],
            self.stages[2],
            self.downsample_layers[3],
            self.stages[3],
        )

    def load_state_dict(self, state_dict, strict=True):
        """Load only parameters that match in name and shape; log the rest.

        The merged dict always covers every model key, so ``strict`` has no
        practical effect, but it is forwarded for API fidelity.
        """
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in state_dict.items()
                           if k in model_dict and model_dict[k].size() == v.size()}
        if len(pretrained_dict) == len(state_dict):
            logging.info('%s: All params loaded' % type(self).__name__)
        else:
            logging.info('%s: Some params were not loaded:' % type(self).__name__)
            not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()]
            logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys))
        model_dict.update(pretrained_dict)
        super(ConvNeXt, self).load_state_dict(model_dict, strict=strict)
# ConvNeXt variant constructors
def convnext_tiny(pretrained=False, **kwargs):
    """ConvNeXt-Tiny (depths 3/3/9/3, dims 96..768), optionally pretrained."""
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['convnext_tiny']))
    return model
def convnext_small(pretrained=False, **kwargs):
    """ConvNeXt-Small (depths 3/3/27/3, dims 96..768), optionally pretrained."""
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['convnext_small']))
    return model
def convnext_base(pretrained=False, **kwargs):
    """ConvNeXt-Base (depths 3/3/27/3, dims 128..1024), optionally pretrained."""
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['convnext_base']))
    return model
def convnext_large(pretrained=False, **kwargs):
    """ConvNeXt-Large (depths 3/3/27/3, dims 192..1536), optionally pretrained."""
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['convnext_large']))
    return model
# ConvNeXt variants with CBAM attention
def convnext_tiny_cbam(pretrained=False, **kwargs):
    """ConvNeXt-Tiny with CBAM attention enabled in every block."""
    return convnext_tiny(pretrained=pretrained, cbam=True, **kwargs)
def convnext_small_cbam(pretrained=False, **kwargs):
    """ConvNeXt-Small with CBAM attention enabled in every block."""
    return convnext_small(pretrained=pretrained, cbam=True, **kwargs)
def convnext_base_cbam(pretrained=False, **kwargs):
    """ConvNeXt-Base with CBAM attention enabled in every block."""
    return convnext_base(pretrained=pretrained, cbam=True, **kwargs)
def convnext_large_cbam(pretrained=False, **kwargs):
    """ConvNeXt-Large with CBAM attention enabled in every block."""
    # NOTE(review): the original line had stray non-code text fused onto it
    # (a pasted chat prompt), which made the file a syntax error; removed.
    return convnext_large(pretrained=pretrained, cbam=True, **kwargs)


"""
WS-DAN models
Hu et al.,
"See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification",
arXiv:1901.09891
Created: May 04,2019 - Yuchong Gu
Revised: Dec 03,2019 - Yuchong Gu
Revised: May 26,2025 - added EfficientNet support
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.vgg as vgg
import models.resnet as resnet
import models.efficientnet as efficientnet
import models.convnext as convnext
from models.inception import inception_v3, BasicConv2d
__all__ = ['WSDAN']
# Numerical-stability constant used in sqrt / probability normalization.
EPSILON = 1e-12
# Bilinear Attention Pooling
class BAP(nn.Module):
    """Bilinear Attention Pooling (WS-DAN).

    Pools the feature maps under each of the M attention maps and flattens
    the result into a (B, M * C) feature matrix, followed by signed-sqrt and
    L2 normalization. ``pool='GAP'`` fuses the average pooling into a single
    einsum; ``pool='GMP'`` max-pools each attended feature map.
    """

    def __init__(self, pool='GAP'):
        super(BAP, self).__init__()
        assert pool in ['GAP', 'GMP']
        if pool == 'GAP':
            self.pool = None  # GAP is computed inside the einsum below
        else:
            self.pool = nn.AdaptiveMaxPool2d(1)

    def forward(self, features, attentions):
        B, C, H, W = features.size()
        _, M, AH, AW = attentions.size()
        # Match the attention maps to the feature map size.
        # F.upsample_bilinear is deprecated; interpolate with
        # align_corners=True is its documented equivalent.
        if AH != H or AW != W:
            attentions = F.interpolate(attentions, size=(H, W),
                                       mode='bilinear', align_corners=True)
        # feature_matrix: (B, M, C) -> (B, M * C)
        if self.pool is None:
            # Operands passed positionally: the tuple form of einsum is deprecated.
            feature_matrix = (torch.einsum('imjk,injk->imn', attentions, features) / float(H * W)).view(B, -1)
        else:
            feature_matrix = []
            for i in range(M):
                AiF = self.pool(features * attentions[:, i:i + 1, ...]).view(B, -1)
                feature_matrix.append(AiF)
            feature_matrix = torch.cat(feature_matrix, dim=1)
        # Signed square-root normalization stabilizes feature magnitudes.
        feature_matrix = torch.sign(feature_matrix) * torch.sqrt(torch.abs(feature_matrix) + EPSILON)
        # L2 normalization along the flattened (M * C) dimension.
        feature_matrix = F.normalize(feature_matrix, dim=-1)
        return feature_matrix
# WS-DAN: Weakly Supervised Data Augmentation Network for FGVC
class WSDAN(nn.Module):
    """WS-DAN (Hu et al., arXiv:1901.09891).

    A configurable CNN backbone produces feature maps; a 1x1 conv derives M
    attention maps; BAP combines both into a (B, M * C) feature matrix that
    a bias-free linear layer classifies. forward() also returns attention
    maps used for attention-guided cropping/dropping during training.
    """

    def __init__(self, num_classes, M=32, net='inception_mixed_6e', pretrained=False):
        """
        Args:
            num_classes: number of target classes.
            M: number of attention maps.
            net: backbone identifier; must appear in the config tables below.
            pretrained: whether to load pretrained backbone weights.
        """
        super(WSDAN, self).__init__()
        self.num_classes = num_classes
        self.M = M
        self.net = net
        # EfficientNet configuration: extraction stage index and feature dim.
        self.EFFICIENTNET_CONFIG = {
            'efficientnet_b0': {'stage': 3, 'feat_dim': 1280},
            'efficientnet_b1': {'stage': 3, 'feat_dim': 1280},
            'efficientnet_b2': {'stage': 3, 'feat_dim': 1408},
            'efficientnet_b3': {'stage': 3, 'feat_dim': 1536},
            'efficientnet_b4': {'stage': 3, 'feat_dim': 1792},
            'efficientnet_b5': {'stage': 3, 'feat_dim': 2048},
            'efficientnet_b6': {'stage': 3, 'feat_dim': 2304},
            'efficientnet_b7': {'stage': 3, 'feat_dim': 2560},
            'efficientnet_v2_s': {'stage': 3, 'feat_dim': 1280},
            'efficientnet_v2_m': {'stage': 3, 'feat_dim': 1280},
            'efficientnet_v2_l': {'stage': 3, 'feat_dim': 1280},
        }
        # ConvNeXt configuration: final-stage feature dimension.
        self.CONVNEXT_CONFIG = {
            'convnext_tiny': {'feat_dim': 768},
            'convnext_small': {'feat_dim': 768},
            'convnext_base': {'feat_dim': 1024},
            'convnext_large': {'feat_dim': 1536},
        }
        # Merged configuration table covering every supported backbone family.
        self.BACKBONE_CONFIG = {
            **self.CONVNEXT_CONFIG,
            **self.EFFICIENTNET_CONFIG,
            # keep the remaining network configs unchanged
            'inception_mixed_6e': {'stage': 0, 'feat_dim': 768},
            'inception_mixed_7c': {'stage': 0, 'feat_dim': 2048},
            'vgg16': {'stage': 0, 'feat_dim': 512},
            'resnet50': {'stage': 0, 'feat_dim': 2048},
            # more network configs can be added here
        }
        # Network Initialization
        if 'convnext' in net:  # ConvNeXt backbone support
            assert net in self.CONVNEXT_CONFIG, f"Unsupported ConvNeXt version: {net}"
            cfg = self.CONVNEXT_CONFIG[net]
            model = getattr(convnext, net)(pretrained=pretrained)
            self.features = model.get_features()
            self.num_features = cfg['feat_dim']
        elif 'efficientnet' in net:
            assert net in self.EFFICIENTNET_CONFIG, f"Unsupported EfficientNet version: {net}"
            cfg = self.EFFICIENTNET_CONFIG[net]
            # Build the EfficientNet model.
            model = getattr(efficientnet, net)(pretrained=pretrained)
            # Obtain the feature extractor and its output dimensionality.
            self.features, self.num_features = model.get_features_blocks(index=cfg['stage'])
        elif 'inception' in net:
            if net == 'inception_mixed_6e':
                self.features = inception_v3(pretrained=pretrained).get_features_mixed_6e()
                self.num_features = 768
            elif net == 'inception_mixed_7c':
                self.features = inception_v3(pretrained=pretrained).get_features_mixed_7c()
                self.num_features = 2048
            else:
                raise ValueError('Unsupported net: %s' % net)
        elif 'vgg' in net:
            self.features = getattr(vgg, net)(pretrained=pretrained).get_features()
            self.num_features = 512
        elif 'resnet' in net:
            self.features = getattr(resnet, net)(pretrained=pretrained).get_features()
            self.num_features = 512 * self.features[-1][-1].expansion
        else:
            raise ValueError('Unsupported net: %s' % net)
        # Attention Maps
        self.attentions = BasicConv2d(self.num_features, self.M, kernel_size=1)
        # Bilinear Attention Pooling
        self.bap = BAP(pool='GAP')
        # Classification Layer
        self.fc = nn.Linear(self.M * self.num_features, self.num_classes, bias=False)
        logging.info('WSDAN: using {} as feature extractor, num_classes: {}, num_attentions: {}'.format(net, self.num_classes, self.M))

    def forward(self, x):
        """Return (logits, feature_matrix, attention_map).

        attention_map is (B, 2, H, W) in training (one map for cropping and
        one for dropping) and (B, 1, H, W) in eval (mean for localization).
        """
        batch_size = x.size(0)
        if 'convnext' in self.net:  # ConvNeXt path
            x = self.features(x)
            feature_maps = x  # ConvNeXt feature maps are used directly
        else:
            feature_maps = self.features(x)
        # Feature Maps, Attention Maps and Feature Matrix
        if self.net != 'inception_mixed_7c':
            attention_maps = self.attentions(feature_maps)
        else:
            # mixed_7c reuses the first M feature channels as attention maps.
            attention_maps = feature_maps[:, :self.M, ...]
        feature_matrix = self.bap(feature_maps, attention_maps)
        # Classification (the *100 scaling compensates for L2-normalized features)
        p = self.fc(feature_matrix * 100.)
        # Generate Attention Map
        if self.training:
            # Randomly choose one of attention maps Ak
            attention_map = []
            for i in range(batch_size):
                # Sampling probability proportional to sqrt of each map's mass.
                attention_weights = torch.sqrt(attention_maps[i].sum(dim=(1, 2)).detach() + EPSILON)
                attention_weights = F.normalize(attention_weights, p=1, dim=0)
                k_index = np.random.choice(self.M, 2, p=attention_weights.cpu().numpy())
                attention_map.append(attention_maps[i, k_index, ...])
            attention_map = torch.stack(attention_map)  # (B, 2, H, W) - one for cropping, the other for dropping
        else:
            # Object Localization Am = mean(Ak)
            attention_map = torch.mean(attention_maps, dim=1, keepdim=True)  # (B, 1, H, W)
        # p: (B, self.num_classes)
        # feature_matrix: (B, M * C)
        # attention_map: (B, 2, H, W) in training, (B, 1, H, W) in val/testing
        return p, feature_matrix, attention_map

    def load_state_dict(self, state_dict, strict=True):
        """Load only parameters that match in name and shape; log the rest."""
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in state_dict.items()
                           if k in model_dict and model_dict[k].size() == v.size()}
        if len(pretrained_dict) == len(state_dict):
            logging.info('%s: All params loaded' % type(self).__name__)
        else:
            logging.info('%s: Some params were not loaded:' % type(self).__name__)
            not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()]
            logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys))
        model_dict.update(pretrained_dict)
        super(WSDAN, self).load_state_dict(model_dict)