PyTorch生成式人工智能——ContraGAN详解与实现

0. 前言

生成对抗网络 (Generative Adversarial Network, GAN) 的发展历程中,如何有效利用条件信息一直是研究的重点。ContraGAN 通过引入对比学习机制,创新性地解决了传统条件 GAN 的条件信息利用不足问题。本节将深入解析 ContraGAN 的技术原理,并使用 PyTorch 从零开始实现 ContraGAN

1. ContraGAN 介绍

1.1 核心原理

传统条件GAN (conditional Generative Adversarial, cGAN) 通过让生成器 G ( z , y ) G(z,y) G(z,y) 按类条件 y y y 生成样本,判别器 D ( x , y ) D(x,y) D(x,y) 判别真假。ContraGAN 的想法是在判别器内部引入对比学习,使判别器不仅区分真伪,还要把相同类(无论是真实还是生成)的样本在特征空间中拉近,而把不同类样本拉远:

  • 判别器的表征更语义化 (class-semantics),有助判别器给出更有信息的梯度给生成器
  • 生成器会试图生成与目标类在判别器特征上贴近的样本(通过对抗损失间接实现),从而提高条件一致性。

1.2 损失函数

ContraGAN 的核心创新是引入条件对比损失 (Conditional Contrastive Loss, 2C loss):

  • 最大化同一类别样本的相似度
  • 最小化不同类别样本的相似度

判别器对抗损失(采用 hinge loss):
L D a d v = E x ∼ p d a t a [ m a x ⁡ ( 0 , 1 − D ( x , y ) ) ] + E x ~ ∼ G [ m a x ⁡ ( 0 , 1 + D ( x ~ , y ) ) ] \mathcal L_D^{adv}=E_{x∼p_{data}}[max⁡(0,1−D(x,y))]+E_{\tilde x∼G}[max⁡(0,1+D(\tilde x,y))] LDadv=Expdata[max(0,1D(x,y))]+Ex~G[max(0,1+D(x~,y))]

生成器对抗损失(采用 hinge loss):
L G a d v = − E x ~ ∼ G [ D ( x ~ , y ) ] \mathcal L_G^{adv}=−E_{\tilde x∼G}[D(\tilde x,y)] LGadv=Ex~G[D(x~,y)]
判别器条件对比损失:
将判别器的一个中间特征映射为向量 f ( x ) f(x) f(x) 并归一化。给定 batch 中的特征矩阵 F F F 和对应标签 y y y (真实样本与生成样本都赋予它们的条件标签),对每个样本 i i i
ℓ i = − l o g ⁡ ∑ j ∈ P ( i ) e x p ⁡ ( f i ⋅ f j τ ) ∑ k ≠ i e x p ⁡ ( f i ⋅ f k τ ) ℓ_i=−log⁡\frac{∑_{j∈P(i)}exp⁡(\frac{fi⋅fj}τ)}{∑_{k≠i}exp⁡(\frac {fi⋅fk}τ)} i=logk=iexp(τfifk)jP(i)exp(τfifj)
其中 P ( i ) P(i) P(i) 是与 i i i 同类且 k ≠ i k≠i k=i 的样本集合, τ τ τ 为温度系数。最终对所有有效 i i i (记为 V V V)取平均。
L D c o n t r a s t = 1 ∣ V ∣ ​ ∑ i ∈ V ​ ℓ i \mathcal L_D^{contrast}=\frac 1{|V|}​\sum_{i∈V}​ℓi LDcontrast=V1iVi
总判别器损失:
L D = L D a d v + λ c L D c o n t r a s t \mathcal L_D=\mathcal L_D^{adv}+λc\mathcal L_D^{contrast} LD=LDadv+λcLDcontrast

2. ContraGAN 架构

ContraGAN 架构如下图所示:
ContraGAN
模型训练流程如下:

  • 更新 D
    • 用当前 G 生成的 x ~ \tilde x x~ (detach(),不把梯度回传给 G),计算 D 的对抗损失 + 对比损失
    • 对比损失的正样本对 (同类别样本) :<真实,真实>、<生成,生成>、<真实,生成>都可以当作正样本对
  • 更新 G
    • 只使用对抗损失 L G a d v \mathcal L_G^{adv} LGadv,为了让 D ( x ~ , y ) D(\tilde x,y) D(x~,y) 上升,G 会优化让 f ( x ~ ) f(\tilde x) f(x~) 靠近“该类簇 + 与 e ( y ) e(y) e(y) 对齐”的方向,其中 e ( y ) ∈ R d e(y)∈\mathbb R^d e(y)Rd 为类嵌入
    • 间接地,G 会靠近判别器的类簇中心

3. 实现 ContraGAN

3.1 模型构建

(1) 首先导入所需库,并定义设备:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# 检查GPU可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

(2) 定义生成器类 Generator, 输入为噪声向量 (z_dim) + 条件标签 (num_classes),输出为生成的图像:

class Generator(nn.Module):
    def __init__(self, z_dim=100, num_classes=10, channels=3, img_size=32):
        super(Generator, self).__init__()
        self.img_size = img_size
        self.z_dim = z_dim
        
        # 标签嵌入层
        self.label_emb = nn.Embedding(num_classes, num_classes)
        
        # 初始全连接层,将噪声和标签映射到特征图
        self.init_size = img_size // 4  # 初始特征图大小
        self.l1 = nn.Sequential(
            nn.Linear(z_dim + num_classes, 128 * self.init_size**2),
            nn.BatchNorm1d(128 * self.init_size**2),
            nn.LeakyReLU(0.2, inplace=True))
        
        # 转置卷积块
        self.conv_blocks = nn.Sequential(
            # 上采样到 1/2
            nn.ConvTranspose2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 上采样到原始尺寸
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 输出层
            nn.Conv2d(64, channels, 3, 1, 1),
            nn.Tanh()
        )
    
    def forward(self, z, labels):
        # 将噪声向量和条件标签拼接
        label_input = self.label_emb(labels)
        x = torch.cat([z, label_input], dim=1)
        
        # 通过全连接层并reshape
        out = self.l1(x)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        
        # 通过转置卷积块
        img = self.conv_blocks(out)
        return img

(3) 定义判别器 Discriminator,输入为图像,输出为图像真实性概率 + 特征向量(用于对比损失):

class Discriminator(nn.Module):
    def __init__(self, num_classes=10, channels=3, img_size=32):
        super(Discriminator, self).__init__()
        self.img_size = img_size
        
        # 共享的特征提取层
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # 计算特征图大小
        feature_size = img_size // 8
        
        # 真实性判别分支
        self.adv_layer = nn.Sequential(
            nn.Linear(256 * feature_size * feature_size, 1),
            nn.Sigmoid()
        )
        
        # 对比学习分支(输出特征向量)
        self.contrastive_layer = nn.Sequential(
            nn.Linear(256 * feature_size * feature_size, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128)  # 输出128维特征向量
        )
    
    def forward(self, img):
        features = self.feature_extractor(img)
        features = features.view(features.size(0), -1)
        
        # 真实性判断
        validity = self.adv_layer(features)
        
        # 对比学习特征
        contra_features = self.contrastive_layer(features)
        return validity, contra_features

(4) 实现对比损失函数,用于计算同一批次内样本间的对比损失:

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
    
    def forward(self, features, labels):
        batch_size = features.size(0)
        
        # 计算特征间的相似度矩阵
        features_norm = nn.functional.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features_norm, features_norm.T) / self.temperature
        
        # 创建正样本掩码
        labels = labels.view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)
        
        # 创建负样本掩码
        negative_mask = 1 - mask
        
        # 计算正样本相似度
        positives = similarity_matrix * mask
        positives = positives - torch.eye(batch_size).to(device) * 1e12  # 排除自身
        
        # 计算负样本相似度
        negatives = similarity_matrix * negative_mask
        
        # 计算对比损失
        numerator = torch.exp(positives).sum(dim=1)
        denominator = torch.exp(negatives).sum(dim=1) + numerator
        loss = -torch.log(numerator / (denominator + 1e-8))
        
        return loss.mean()

3.2 模型训练

模型构建完成后,使用 CIFAR-10 数据集训练 ContraGAN
(1) 定义超参数,并初始化模型、优化器和损失函数:

z_dim = 100
num_classes = 10
img_size = 32
channels = 3
batch_size = 128
epochs = 100
lr = 0.0002
temperature = 0.5

generator = Generator(z_dim, num_classes, channels, img_size).to(device)
discriminator = Discriminator(num_classes, channels, img_size).to(device)
contrastive_loss = ContrastiveLoss(temperature)

optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

adversarial_loss = nn.BCELoss()

(2) 加载 CIFAR-10 数据集:

transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
)

(3) 定义训练循环,训练 ContraGAN 模型 100epoch,且每隔 10epoch 保存生成的伪造样本:

# 训练循环
for epoch in range(epochs):
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
    
    for i, (imgs, labels) in enumerate(progress_bar):
        batch_size = imgs.size(0)
        real_imgs = imgs.to(device)
        real_labels = labels.to(device)
        
        # 创建真实和伪造的标签
        valid = torch.ones(batch_size, 1).to(device)
        fake = torch.zeros(batch_size, 1).to(device)
        
        #  训练判别器
        optimizer_D.zero_grad()
        
        # 真实样本的损失
        real_validity, real_contra = discriminator(real_imgs)
        d_real_loss = adversarial_loss(real_validity, valid)
        
        # 生成伪造样本
        z = torch.randn(batch_size, z_dim).to(device)
        gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        gen_imgs = generator(z, gen_labels)
        
        # 伪造样本的损失
        fake_validity, fake_contra = discriminator(gen_imgs.detach())
        d_fake_loss = adversarial_loss(fake_validity, fake)
        
        # 对比损失
        contra_features = torch.cat([real_contra, fake_contra], dim=0)
        contra_labels = torch.cat([real_labels, gen_labels], dim=0)
        contra_loss = contrastive_loss(contra_features, contra_labels)
        
        # 总判别器损失
        d_loss = d_real_loss + d_fake_loss + contra_loss
        d_loss.backward()
        optimizer_D.step()
        
        #  训练生成器
        optimizer_G.zero_grad()
        
        # 生成器试图欺骗判别器
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)
        
        g_loss.backward()
        optimizer_G.step()
        
        # 更新进度条
        progress_bar.set_postfix(
            d_loss=d_loss.item(), 
            g_loss=g_loss.item(),
            contra_loss=contra_loss.item()
        )

    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            # 生成测试样本
            test_z = torch.randn(25, z_dim).to(device)
            test_labels = torch.arange(0, 10).repeat(3).to(device)[:25]  # 生成0-9的标签
            gen_imgs = generator(test_z, test_labels)
            
            # 保存图像
            fig, axs = plt.subplots(5, 5, figsize=(10, 10))
            cnt = 0
            for i in range(5):
                for j in range(5):
                    axs[i, j].imshow(gen_imgs[cnt].cpu().permute(1, 2, 0) * 0.5 + 0.5)
                    axs[i, j].set_title(f"Label: {test_labels[cnt].item()}")
                    axs[i, j].axis('off')
                    cnt += 1
            plt.savefig(f"contragan_epoch_{epoch+1}.png")
            plt.close()

训练过程生成,生成图像如下所示,可以看到,随着训练的进行模型生成的图像逐渐清晰:

训练过程

3.3 结果可视化

(1) 使用训练完成的 ContraGAN 模型生成不同类别的样本:

with torch.no_grad():
    z = torch.randn(10, z_dim).to(device)
    labels = torch.arange(0, 10).to(device)
    gen_imgs = generator(z, labels)
    
    # 可视化结果
    fig, axs = plt.subplots(2, 5, figsize=(15, 6))
    for i in range(2):
        for j in range(5):
            idx = i * 5 + j
            img = gen_imgs[idx].cpu().permute(1, 2, 0) * 0.5 + 0.5
            axs[i, j].imshow(img)
            axs[i, j].set_title(f"Class: {labels[idx].item()}")
            axs[i, j].axis('off')
    plt.tight_layout()
    plt.show()

生成结果

(2) 使用训练完成的 ContraGAN 模型生成指定类别的多个样本并可视化:

import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid

def generate_same_class_images(generator, num_samples=10, class_idx=0, z_dim=100, device='cuda'):
    generator.eval()
    
    # 生成相同类别的样本
    with torch.no_grad():
        # 创建相同类别的标签
        labels = torch.full((num_samples,), class_idx, dtype=torch.long).to(device)
        
        # 生成不同的噪声向量
        z = torch.randn(num_samples, z_dim).to(device)
        
        # 生成图像
        gen_imgs = generator(z, labels).cpu()
    
    # 将图像从[-1,1]转换到[0,1]范围
    gen_imgs = gen_imgs * 0.5 + 0.5
    grid = make_grid(gen_imgs, nrow=int(np.sqrt(num_samples)), padding=2, normalize=False)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0))
    plt.title(f'Generated Samples for Class {class_idx}', fontsize=16)
    plt.axis('off')
    plt.show()

def visualize_all_classes(generator, num_classes=10, samples_per_class=5, z_dim=100, img_size=32):
    generator.eval()
    plt.figure(figsize=(15, 15))
    
    with torch.no_grad():
        for class_idx in range(num_classes):
            # 生成当前类别的样本
            z = torch.randn(samples_per_class, z_dim).to(device)
            labels = torch.full((samples_per_class,), class_idx, dtype=torch.long).to(device)
            gen_imgs = generator(z, labels).cpu()
            gen_imgs = gen_imgs * 0.5 + 0.5  # 反归一化
            
            # 在当前类别的子图中显示
            for i in range(samples_per_class):
                ax = plt.subplot(num_classes, samples_per_class, 
                                class_idx * samples_per_class + i + 1)
                ax.imshow(gen_imgs[i].permute(1, 2, 0))
                ax.axis('off')
                
                # 只在每行的第一个图像显示类别标签
                if i == 0:
                    ax.set_title(f'Class {class_idx}', fontsize=12, y=0.9)
    
    plt.tight_layout()
    plt.suptitle('ContraGAN Generated Samples by Class', y=1.02, fontsize=16)
    plt.show()

# 可视化单个类别的样本
generate_same_class_images(generator, num_samples=9, class_idx=3)

# 可视化所有类别的样本
visualize_all_classes(generator, num_classes=10, samples_per_class=10)

生成结果

相关链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(11)——神经风格迁移
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现
PyTorch生成式人工智能(18)——循环神经网络详解与实现
PyTorch生成式人工智能(19)——自回归模型详解与实现
PyTorch生成式人工智能(20)——像素卷积神经网络(PixelCNN)
PyTorch生成式人工智能(21)——归一化流模型(Normalizing Flow Model)
PyTorch生成式人工智能(27)——从零开始训练GPT模型
PyTorch生成式人工智能(28)——MuseGAN详解与实现

评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

盼小辉丶

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值