class MYGenerator(nn.Module):
    """U-Net-style image-to-image generator with an attention-augmented encoder.

    Four stride-2 encoder stages (Conv -> InstanceNorm -> ReLU -> CBAM ->
    ResidualBlock) reduce the spatial size by 16x. A stack of residual blocks
    refines the bottleneck. Four stride-2 transposed-conv decoder stages then
    restore the input resolution. Encoder features are concatenated onto the
    decoder path (skip connections), which is why decoder input widths are
    doubled. The final Tanh maps outputs into (-1, 1).

    ``CBAM`` and ``ResidualBlock`` are defined elsewhere. The layer wiring
    below requires them to return the same channel count as their input, and
    presumably the same spatial size (needed for the skip concatenations).

    Args:
        n_residual: Number of ``ResidualBlock``s in the bottleneck (default 6).
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the generated image (default 3).
        base_channels: Width of the first encoder stage. Later stages use 2x,
            4x and 8x this value (default 64, i.e. 64/128/256/512).
    """

    # Four stride-2 stages: H and W must be multiples of 2**4 so that every
    # decoder output matches its skip tensor's spatial size.
    _SCALE = 16

    def __init__(self, n_residual=6, *, in_channels=3, out_channels=3,
                 base_channels=64):
        super().__init__()
        c = base_channels
        # Encoder: downsampling + CBAM attention + residual refinement.
        self.down1 = self._encoder_stage(in_channels, c)
        self.down2 = self._encoder_stage(c, c * 2)
        self.down3 = self._encoder_stage(c * 2, c * 4)
        self.down4 = self._encoder_stage(c * 4, c * 8)
        # Bottleneck: n_residual stacked residual blocks at the deepest width.
        self.bottleneck = nn.Sequential(
            *[ResidualBlock(c * 8) for _ in range(n_residual)]
        )
        # Decoder: input widths include the concatenated skip features.
        self.up1 = self._decoder_stage(c * 8, c * 4)
        self.up2 = self._decoder_stage(c * 8, c * 2)  # 4c (up1) + 4c (d3)
        self.up3 = self._decoder_stage(c * 4, c)      # 2c (up2) + 2c (d2)
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(c * 2, out_channels, 4, 2, 1),  # c (up3) + c (d1)
            nn.Tanh(),
        )

    @staticmethod
    def _encoder_stage(in_ch, out_ch):
        """Build one encoder stage: a stride-2 conv (halves H, W) plus CBAM and a residual block."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 4, 2, 1),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(inplace=True),
            CBAM(out_ch),
            ResidualBlock(out_ch),
        )

    @staticmethod
    def _decoder_stage(in_ch, out_ch):
        """Build one decoder stage: a stride-2 transposed conv (doubles H, W), norm and ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Translate a batch of images.

        Args:
            x: Tensor of shape (N, in_channels, H, W). H and W must be
                divisible by 16.

        Returns:
            Tensor of shape (N, out_channels, H, W) with values in (-1, 1).

        Raises:
            ValueError: If H or W is not a multiple of 16. Such inputs would
                otherwise fail later with an opaque size mismatch in
                ``torch.cat``.
        """
        height, width = x.shape[-2], x.shape[-1]
        if height % self._SCALE or width % self._SCALE:
            raise ValueError(
                f"MYGenerator expects H and W divisible by {self._SCALE}, "
                f"got {height}x{width}"
            )
        # Encoder (c = base_channels).
        d1 = self.down1(x)   # (N, c,  H/2,  W/2)
        d2 = self.down2(d1)  # (N, 2c, H/4,  W/4)
        d3 = self.down3(d2)  # (N, 4c, H/8,  W/8)
        d4 = self.down4(d3)  # (N, 8c, H/16, W/16)
        d4 = self.bottleneck(d4)
        # Decoder with skip connections, concatenated along the channel dim.
        u1 = torch.cat([self.up1(d4), d3], dim=1)  # (N, 8c, H/8, W/8)
        u2 = torch.cat([self.up2(u1), d2], dim=1)  # (N, 4c, H/4, W/4)
        u3 = torch.cat([self.up3(u2), d1], dim=1)  # (N, 2c, H/2, W/2)
        return self.up4(u3)                         # (N, out_channels, H, W)
# 该生成器应该如何修改？(How should this generator be modified?)