PyTorch Generative AI: ContraGAN Explained and Implemented
0. Preface
Throughout the development of Generative Adversarial Networks (GANs), how to make effective use of conditional information has been a central research question. ContraGAN addresses the underuse of conditional information in traditional conditional GANs by introducing a contrastive learning mechanism. This section analyzes the technical principles behind ContraGAN in depth and implements it from scratch with PyTorch.
1. Introduction to ContraGAN
1.1 Core Principle
In a traditional conditional GAN (cGAN), the generator $G(z,y)$ produces samples conditioned on a class label $y$, while the discriminator $D(x,y)$ distinguishes real samples from fake ones. ContraGAN's idea is to introduce contrastive learning inside the discriminator: beyond telling real from fake, the discriminator pulls samples of the same class (whether real or generated) closer together in feature space and pushes samples of different classes apart. This has two effects:
- The discriminator's representation becomes more class-semantic, helping it provide more informative gradients to the generator.
- The generator is driven (indirectly, through the adversarial loss) to produce samples whose discriminator features lie close to the target class, which improves conditional consistency.
1.2 Loss Functions
ContraGAN's core innovation is the conditional contrastive loss (2C loss), which:
- maximizes the similarity between samples of the same class
- minimizes the similarity between samples of different classes
Discriminator adversarial loss (hinge loss):
$$\mathcal L_D^{adv}=\mathbb E_{x\sim p_{data}}[\max(0,1-D(x,y))]+\mathbb E_{\tilde x\sim G}[\max(0,1+D(\tilde x,y))]$$
Generator adversarial loss (hinge loss):
$$\mathcal L_G^{adv}=-\mathbb E_{\tilde x\sim G}[D(\tilde x,y)]$$
Discriminator conditional contrastive loss:
Map an intermediate discriminator feature to a vector $f(x)$ and normalize it. Given the feature matrix $F$ of a batch and the corresponding labels $y$ (real and generated samples are both assigned their conditioning labels), for each sample $i$:
$$\ell_i=-\log\frac{\sum_{j\in P(i)}\exp(f_i\cdot f_j/\tau)}{\sum_{k\neq i}\exp(f_i\cdot f_k/\tau)}$$
where $P(i)$ is the set of samples $j\neq i$ with the same class as $i$, and $\tau$ is a temperature coefficient. Finally, average over all valid samples $i$ (those with at least one same-class partner in the batch, denoted $V$):
$$\mathcal L_D^{contrast}=\frac 1{|V|}\sum_{i\in V}\ell_i$$
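A loop-based toy check of this formula (hypothetical random 8-dimensional features, two classes, $\tau=0.5$) makes the positive set $P(i)$ and the denominator over $k\neq i$ explicit:

```python
import torch
import torch.nn.functional as F

# toy batch: 4 normalized features, labels [0, 0, 1, 1], tau = 0.5
torch.manual_seed(0)
f = F.normalize(torch.randn(4, 8), dim=1)
labels = torch.tensor([0, 0, 1, 1])
tau = 0.5
sim = f @ f.T / tau

losses = []
for i in range(4):
    # P(i): same-class samples, excluding i itself
    pos = [j for j in range(4) if j != i and labels[j] == labels[i]]
    num = sum(torch.exp(sim[i, j]) for j in pos)
    den = sum(torch.exp(sim[i, k]) for k in range(4) if k != i)
    losses.append(-torch.log(num / den))
loss = torch.stack(losses).mean()
print(loss.item())
```

Since the numerator sums over a strict subset of the denominator's terms, each $\ell_i$ is strictly positive; the loss shrinks as same-class similarities grow relative to cross-class ones.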
Total discriminator loss:
$$\mathcal L_D=\mathcal L_D^{adv}+\lambda_c\,\mathcal L_D^{contrast}$$
2. ContraGAN Architecture
The ContraGAN architecture is shown in the figure below:
The training procedure is as follows:
- Update $D$:
  - Generate $\tilde x$ with the current $G$ (detached via detach(), so no gradient flows back to $G$), then compute $D$'s adversarial loss plus the contrastive loss.
  - Positive pairs for the contrastive loss (same-class samples): <real, real>, <fake, fake>, and <real, fake> pairs all count as positives.
- Update $G$:
  - Use only the adversarial loss $\mathcal L_G^{adv}$. To make $D(\tilde x,y)$ rise, $G$ is pushed to move $f(\tilde x)$ toward that class's cluster and into alignment with the class embedding $e(y)\in\mathbb R^d$.
  - Indirectly, $G$ moves toward the class cluster centers in the discriminator's feature space.
3. Implementing ContraGAN
3.1 Building the Model
(1) First, import the required libraries and define the device:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
# check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
(2) Define the generator class Generator. Its input is a noise vector (z_dim) plus a conditional label (num_classes); its output is a generated image:
class Generator(nn.Module):
    def __init__(self, z_dim=100, num_classes=10, channels=3, img_size=32):
        super(Generator, self).__init__()
        self.img_size = img_size
        self.z_dim = z_dim
        # label embedding layer
        self.label_emb = nn.Embedding(num_classes, num_classes)
        # initial fully connected layer mapping noise + label to a feature map
        self.init_size = img_size // 4  # initial feature-map size
        self.l1 = nn.Sequential(
            nn.Linear(z_dim + num_classes, 128 * self.init_size**2),
            nn.BatchNorm1d(128 * self.init_size**2),
            nn.LeakyReLU(0.2, inplace=True))
        # transposed-convolution blocks
        self.conv_blocks = nn.Sequential(
            # upsample to 1/2 of the target size
            nn.ConvTranspose2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # upsample to the original size
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # output layer
            nn.Conv2d(64, channels, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # concatenate the noise vector and the conditional label embedding
        label_input = self.label_emb(labels)
        x = torch.cat([z, label_input], dim=1)
        # pass through the fully connected layer and reshape
        out = self.l1(x)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        # pass through the transposed-convolution blocks
        img = self.conv_blocks(out)
        return img
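As a quick standalone sanity check on the layer arithmetic (not part of the model): with img_size=32, init_size is 8, and each (kernel=4, stride=2, padding=1) transposed convolution doubles the spatial size:

```python
import torch
import torch.nn as nn

# verify: two stride-2 transposed convs bring an 8x8 feature map back to 32x32
x = torch.randn(1, 128, 8, 8)
up = nn.Sequential(
    nn.ConvTranspose2d(128, 128, 4, 2, 1),  # 8x8 -> 16x16
    nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 16x16 -> 32x32
)
print(up(x).shape)  # torch.Size([1, 64, 32, 32])
```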
(3) Define the discriminator Discriminator. Its input is an image; its outputs are a probability that the image is real plus a feature vector (used for the contrastive loss):
class Discriminator(nn.Module):
    def __init__(self, num_classes=10, channels=3, img_size=32):
        super(Discriminator, self).__init__()
        self.img_size = img_size
        # shared feature-extraction layers
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # spatial size of the final feature map
        feature_size = img_size // 8
        # real/fake branch
        self.adv_layer = nn.Sequential(
            nn.Linear(256 * feature_size * feature_size, 1),
            nn.Sigmoid()
        )
        # contrastive branch (outputs a feature vector)
        self.contrastive_layer = nn.Sequential(
            nn.Linear(256 * feature_size * feature_size, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128)  # 128-dim feature vector
        )

    def forward(self, img):
        features = self.feature_extractor(img)
        features = features.view(features.size(0), -1)
        # real/fake prediction
        validity = self.adv_layer(features)
        # contrastive feature
        contra_features = self.contrastive_layer(features)
        return validity, contra_features
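A matching standalone sanity check for the discriminator's flattened feature size: three stride-2 convolutions reduce 32x32 to 4x4, so feature_size = 32 // 8 = 4 and the linear layers see 256 * 4 * 4 = 4096 inputs:

```python
import torch
import torch.nn as nn

# verify: three stride-2 convs reduce 32x32 to 4x4 (flattened 256*4*4 = 4096)
conv = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1),     # 32 -> 16
    nn.Conv2d(64, 128, 4, 2, 1),   # 16 -> 8
    nn.Conv2d(128, 256, 4, 2, 1),  # 8 -> 4
)
x = torch.randn(1, 3, 32, 32)
print(conv(x).flatten(1).shape)  # torch.Size([1, 4096])
```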
(4) Implement the contrastive loss function, which computes the contrastive loss between samples within a batch:
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        batch_size = features.size(0)
        # cosine-similarity matrix between L2-normalized features
        features_norm = nn.functional.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features_norm, features_norm.T) / self.temperature
        # positive-pair mask: same label, excluding each sample itself
        labels = labels.view(-1, 1)
        pos_mask = torch.eq(labels, labels.T).float().to(features.device)
        self_mask = torch.eye(batch_size, device=features.device)
        pos_mask = pos_mask - self_mask
        # exponentiate first, then mask: masking the raw similarities before
        # exp() would leak exp(0) = 1 terms into both sums
        exp_sim = torch.exp(similarity_matrix) * (1 - self_mask)
        numerator = (exp_sim * pos_mask).sum(dim=1)
        denominator = exp_sim.sum(dim=1)
        # only samples with at least one same-class partner are valid (the set V)
        valid = pos_mask.sum(dim=1) > 0
        if not valid.any():
            return features.new_zeros(())
        loss = -torch.log(numerator[valid] / (denominator[valid] + 1e-8) + 1e-8)
        return loss.mean()
3.2 Training the Model
With the model built, train ContraGAN on the CIFAR-10 dataset.
(1) Define the hyperparameters and initialize the model, optimizers, and loss functions. Note that for simplicity, this implementation uses the standard BCE adversarial loss with a sigmoid output rather than the hinge loss described in Section 1.2:
z_dim = 100
num_classes = 10
img_size = 32
channels = 3
batch_size = 128
epochs = 100
lr = 0.0002
temperature = 0.5
generator = Generator(z_dim, num_classes, channels, img_size).to(device)
discriminator = Discriminator(num_classes, channels, img_size).to(device)
contrastive_loss = ContrastiveLoss(temperature)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
adversarial_loss = nn.BCELoss()
(2) Load the CIFAR-10 dataset:
transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
)
(3) Define the training loop. Train the ContraGAN model for 100 epochs, saving generated samples every 10 epochs:
# training loop
for epoch in range(epochs):
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
    for i, (imgs, labels) in enumerate(progress_bar):
        batch_size = imgs.size(0)
        real_imgs = imgs.to(device)
        real_labels = labels.to(device)
        # real and fake target labels
        valid = torch.ones(batch_size, 1).to(device)
        fake = torch.zeros(batch_size, 1).to(device)
        # ---- train the discriminator ----
        optimizer_D.zero_grad()
        # loss on real samples
        real_validity, real_contra = discriminator(real_imgs)
        d_real_loss = adversarial_loss(real_validity, valid)
        # generate fake samples
        z = torch.randn(batch_size, z_dim).to(device)
        gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        gen_imgs = generator(z, gen_labels)
        # loss on fake samples (detach so no gradient flows back to G)
        fake_validity, fake_contra = discriminator(gen_imgs.detach())
        d_fake_loss = adversarial_loss(fake_validity, fake)
        # contrastive loss over real and fake features together
        contra_features = torch.cat([real_contra, fake_contra], dim=0)
        contra_labels = torch.cat([real_labels, gen_labels], dim=0)
        contra_loss = contrastive_loss(contra_features, contra_labels)
        # total discriminator loss (contrastive weight lambda_c = 1 here)
        d_loss = d_real_loss + d_fake_loss + contra_loss
        d_loss.backward()
        optimizer_D.step()
        # ---- train the generator ----
        optimizer_G.zero_grad()
        # the generator tries to fool the discriminator
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)
        g_loss.backward()
        optimizer_G.step()
        # update the progress bar
        progress_bar.set_postfix(
            d_loss=d_loss.item(),
            g_loss=g_loss.item(),
            contra_loss=contra_loss.item()
        )
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            # generate test samples
            test_z = torch.randn(25, z_dim).to(device)
            test_labels = torch.arange(0, 10).repeat(3).to(device)[:25]  # labels 0-9
            gen_imgs = generator(test_z, test_labels)
            # save the images
            fig, axs = plt.subplots(5, 5, figsize=(10, 10))
            cnt = 0
            for i in range(5):
                for j in range(5):
                    axs[i, j].imshow(gen_imgs[cnt].cpu().permute(1, 2, 0) * 0.5 + 0.5)
                    axs[i, j].set_title(f"Label: {test_labels[cnt].item()}")
                    axs[i, j].axis('off')
                    cnt += 1
            plt.savefig(f"contragan_epoch_{epoch+1}.png")
            plt.close()
Images generated during training are shown below; as training progresses, the generated images gradually become clearer:
3.3 Visualizing the Results
(1) Use the trained ContraGAN model to generate samples from each class:
with torch.no_grad():
    z = torch.randn(10, z_dim).to(device)
    labels = torch.arange(0, 10).to(device)
    gen_imgs = generator(z, labels)
# visualize the results
fig, axs = plt.subplots(2, 5, figsize=(15, 6))
for i in range(2):
    for j in range(5):
        idx = i * 5 + j
        img = gen_imgs[idx].cpu().permute(1, 2, 0) * 0.5 + 0.5
        axs[i, j].imshow(img)
        axs[i, j].set_title(f"Class: {labels[idx].item()}")
        axs[i, j].axis('off')
plt.tight_layout()
plt.show()
(2) Use the trained ContraGAN model to generate multiple samples of a specified class and visualize them:
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid

def generate_same_class_images(generator, num_samples=10, class_idx=0, z_dim=100, device='cuda'):
    generator.eval()
    # generate samples of a single class
    with torch.no_grad():
        # labels all set to the same class
        labels = torch.full((num_samples,), class_idx, dtype=torch.long).to(device)
        # distinct noise vectors
        z = torch.randn(num_samples, z_dim).to(device)
        # generate the images
        gen_imgs = generator(z, labels).cpu()
    # map the images from [-1, 1] back to [0, 1]
    gen_imgs = gen_imgs * 0.5 + 0.5
    grid = make_grid(gen_imgs, nrow=int(np.sqrt(num_samples)), padding=2, normalize=False)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0))
    plt.title(f'Generated Samples for Class {class_idx}', fontsize=16)
    plt.axis('off')
    plt.show()

def visualize_all_classes(generator, num_classes=10, samples_per_class=5, z_dim=100, img_size=32):
    generator.eval()
    plt.figure(figsize=(15, 15))
    with torch.no_grad():
        for class_idx in range(num_classes):
            # generate samples of the current class
            z = torch.randn(samples_per_class, z_dim).to(device)
            labels = torch.full((samples_per_class,), class_idx, dtype=torch.long).to(device)
            gen_imgs = generator(z, labels).cpu()
            gen_imgs = gen_imgs * 0.5 + 0.5  # denormalize
            # show in this class's row of subplots
            for i in range(samples_per_class):
                ax = plt.subplot(num_classes, samples_per_class,
                                 class_idx * samples_per_class + i + 1)
                ax.imshow(gen_imgs[i].permute(1, 2, 0))
                ax.axis('off')
                # show the class label only on the first image of each row
                if i == 0:
                    ax.set_title(f'Class {class_idx}', fontsize=12, y=0.9)
    plt.tight_layout()
    plt.suptitle('ContraGAN Generated Samples by Class', y=1.02, fontsize=16)
    plt.show()

# visualize samples of a single class
generate_same_class_images(generator, num_samples=9, class_idx=3)
# visualize samples of all classes
visualize_all_classes(generator, num_classes=10, samples_per_class=10)