生成对抗网络(GAN)代码实战

一、GAN的核心思想与架构

生成对抗网络(Generative Adversarial Network, GAN)由Ian Goodfellow于2014年提出,其核心思想是通过两个神经网络的对抗博弈实现数据生成。在MNIST手写数字生成的案例中,我们构建了包含生成器(Generator)​判别器(Discriminator)​的对抗系统。

1.1 对抗博弈机制

  • 生成器(G)​​:接收随机噪声向量z,通过全连接层将低维噪声映射到与真实数据相同的维度(MNIST中为784维)。其目标是生成逼真的假图像,欺骗判别器。
  • 判别器(D)​​:接收真实图像或生成图像,输出0-1之间的概率值,表示输入为真实数据的置信度。其目标是准确区分真假样本。

两者的对抗过程可形式化为极小极大博弈:

该公式体现了生成器试图最小化判别器的判别能力,而判别器试图最大化其判别能力的动态平衡。

1.2 网络架构设计

在MNIST案例中,我们采用全连接网络结构:

  • 判别器​:输入层(784维)→ 隐藏层(128维,LeakyReLU)→ 输出层(1维,Sigmoid)
  • 生成器​:输入层(64维噪声)→ 隐藏层(256维,LeakyReLU)→ 输出层(784维,Tanh)

选择全连接网络是因为MNIST数据维度较低(28x28),而DCGAN等变体通过卷积层更适合高分辨率图像生成。


二、训练过程与关键技术

2.1 损失函数设计

GAN的损失函数包含两部分:

  1. 判别器损失​:

    通过二元交叉熵(BCELoss)实现,使判别器能区分真假样本。

  2. 生成器损失​:

    生成器通过最大化判别器对假样本的误判概率(即让判别器输出趋近于1)来优化自身参数。

2.2 训练动态分析

在MNIST案例中,训练过程呈现以下特点:

  • 交替训练​:先固定生成器更新判别器,再固定判别器更新生成器
  • 梯度惩罚​:通过retain_graph=True保留计算图,避免反向传播时梯度消失
  • 学习率控制​:采用Adam优化器(lr=3e-4),平衡收敛速度与稳定性

实验发现,当判别器准确率超过95%时,生成器容易陷入模式崩溃(生成重复样本),此时需调整网络容量或引入Wasserstein距离等改进方案。


三、MNIST生成案例的实现细节

3.1 数据预处理

transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为[0,1]范围的张量
    transforms.Normalize((0.1307,), (0.3081,))  # 标准化处理
])

MNIST数据集的像素均值和标准差分别为0.1307和0.3081,标准化可加速模型收敛。

3.2 固定噪声可视化

通过固定噪声向量fixed_noise,可在训练过程中观察生成器输出的变化轨迹:

fixed_noise = torch.randn((32, 64))  # 批量生成固定噪声

该方法有助于评估生成器的稳定性:若多次生成的图像相似但无重复,则说明模型具有良好多样性。


四、完整代码

# 导入必要的库
import torch  # PyTorch核心库
import torch.nn as nn  # 神经网络模块
import torch.optim as optim  # 优化器模块
import torchvision  # 计算机视觉工具包
import torchvision.transforms as transforms  # 数据预处理转换
from torch.utils.data import DataLoader  # 数据加载器
from torch.utils.tensorboard import SummaryWriter  # 可视化工具(TensorBoard)


# ==================== 定义判别器(Discriminator) ====================
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        """
        判别器网络结构初始化
        参数:
            img_dim: 输入图像的维度(展平后的一维长度)
        """
        super(Discriminator, self).__init__()  # 调用父类初始化方法

        # 定义判别器的网络结构(Sequential容器按顺序连接各层)
        self.disc = nn.Sequential(
            # 输入层:接收展平后的图像(维度为img_dim),输出128维特征
            nn.Linear(img_dim, 128),
            # 激活函数:LeakyReLU(带泄露的ReLU,解决死神经元问题),负斜率0.1
            nn.LeakyReLU(0.1),
            # 输出层:将128维特征映射到1维(表示输入为真实图像的概率)
            nn.Linear(128, 1),
            # 激活函数:Sigmoid(将输出压缩到[0,1]区间,表示概率)
            nn.Sigmoid()
        )

    def forward(self, x):
        """前向传播:输入图像,输出判别结果(概率值)"""
        return self.disc(x)  # 将输入x传入网络结构


# ==================== 定义生成器(Generator) ====================
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        """
        生成器网络结构初始化
        参数:
            z_dim: 噪声向量的维度(输入潜变量空间的维度)
            img_dim: 输出图像的维度(展平后的一维长度)
        """
        super(Generator, self).__init__()

        # 定义生成器的网络结构
        self.gen = nn.Sequential(
            # 输入层:接收z维度的噪声向量,输出256维特征
            nn.Linear(z_dim, 256),
            # 激活函数:LeakyReLU
            nn.LeakyReLU(0.1),
            # 输出层:将256维特征映射到img_dim维(展平后的图像维度)
            nn.Linear(256, img_dim),
            # 激活函数:Tanh(将输出压缩到[-1,1]区间,与预处理后的图像范围匹配)
            nn.Tanh()
        )

    def forward(self, x):
        """前向传播:输入噪声向量,输出生成的图像(展平后)"""
        return self.gen(x)  # 将输入噪声x传入网络结构


# ==================== 主训练流程 ====================
def main():
    # 超参数设置
    lr = 3e-4  # 学习率(优化器的学习速率)
    z_dim = 64  # 噪声向量的维度(潜空间维度)
    img_dim = 28 * 28  # MNIST图像展平后的维度(28x28=784)
    batch_size = 32  # 每批数据的样本数量
    epochs = 50  # 训练轮数(遍历整个数据集的次数)

    # 初始化判别器和生成器
    disc = Discriminator(img_dim)  # 判别器实例化(输入维度为img_dim)
    gen = Generator(z_dim, img_dim)  # 生成器实例化(噪声维度z_dim,输出维度img_dim)

    # 固定噪声向量(用于训练过程中可视化生成效果)
    # 每次训练都使用相同的噪声,观察生成图像的变化
    fixed_noise = torch.randn((batch_size, z_dim))  # 形状:(32, 64)

    # 数据预处理与加载
    # 定义数据转换:转换为张量,并标准化(MNIST的均值约0.1307,标准差约0.3081)
    trans = transforms.Compose([
        transforms.ToTensor(),  # 转换为PyTorch张量(形状:[C, H, W],范围[0,1])
        transforms.Normalize((0.1307,), (0.3081,))  # 标准化:(x - mean)/std
    ])
    # 加载MNIST训练集(root指定存储路径,train=True表示训练集,download=True自动下载)
    dataset = torchvision.datasets.MNIST(
        root='./dataset',  # 数据集存储路径
        train=True,  # 加载训练集
        transform=trans,  # 应用预处理转换
        download=True  # 若本地无数据则下载
    )
    # 创建数据加载器(批量加载数据,shuffle=True表示打乱顺序)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # 初始化优化器(分别用于更新判别器和生成器的参数)
    # Adam优化器(常用且效果好),学习率lr
    opt_disc = optim.Adam(disc.parameters(), lr=lr)  # 判别器优化器
    opt_gen = optim.Adam(gen.parameters(), lr=lr)  # 生成器优化器

    # 定义损失函数:二元交叉熵(BCELoss)
    # 用于衡量判别器的分类准确性和生成器的欺骗能力
    criterion = nn.BCELoss()

    # 初始化TensorBoard的SummaryWriter(用于可视化训练过程)
    # 'runs/MNIST_real':保存真实图像的日志目录
    # 'runs/MNIST_fake':保存生成图像的日志目录
    writer_real = SummaryWriter('runs/MNIST_real')
    writer_fake = SummaryWriter('runs/MNIST_fake')

    step = 0  # 全局步数计数器(用于TensorBoard记录不同训练阶段的图像)

    # 开始训练循环
    for epoch in range(epochs):  # 遍历每个训练轮次
        # 遍历数据加载器中的每个批次(batch)
        for batch_idx, (real, _) in enumerate(dataloader):
            """
            real:当前批次的真实图像(形状:[batch_size, 1, 28, 28],1通道,28x28像素)
            _:图像的标签(MNIST的数字标签,此处不需要,用_占位)
            """

            # ---------------------- 步骤1:预处理真实图像 ----------------------
            # 将真实图像展平为一维向量(因为全连接层的输入需要一维)
            # 原始形状:[batch_size, 1, 28, 28] → 展平后:[batch_size, 28 * 28=784]
            real = real.view(-1, 28 * 28)

            # ---------------------- 步骤2:训练判别器(D) ----------------------
            # 生成一批噪声向量(用于生成假图像)
            noise = torch.randn((batch_size, z_dim))  # 形状:[32, 64]
            # 生成器根据噪声生成假图像(展平后的一维向量)
            fake = gen(noise)  # 形状:[32, 784]

            # 判别器对真实图像的分类结果(输出概率值)
            # 输入真实图像 → 输出形状:[32, 1](每个样本的概率值)
            disc_real = disc(real).reshape(-1)  # 展平为[32](方便后续计算损失)
            # 判别器对假图像的分类结果(输出概率值)
            disc_fake = disc(fake).reshape(-1)  # 展平为[32]

            # 计算判别器的损失(目标:真实图像被判为1,假图像被判为0)
            # 真实图像的损失:BCELoss(预测概率, 真实标签=1)
            loss_real = criterion(disc_real, torch.ones_like(disc_real))
            # 假图像的损失:BCELoss(预测概率, 真实标签=0)
            loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            # 总损失:真实和假图像损失的平均值(平衡两者的重要性)
            loss_disc = (loss_real + loss_fake) / 2

            # 反向传播并更新判别器参数
            opt_disc.zero_grad()  # 清空判别器优化器的梯度缓存
            loss_disc.backward(retain_graph=True)  # 计算损失对参数的梯度(反向传播)
            opt_disc.step()  # 根据梯度更新判别器的参数

            # ---------------------- 步骤3:训练生成器(G) ----------------------
            # 生成器希望生成的假图像被误判为真实图像(即判别器输出接近1)
            # 再次将假图像输入判别器(此时不计算梯度,避免影响生成器训练)
            output = disc(fake).reshape(-1)  # 判别器对假图像的输出(形状:[32])
            # 生成器的损失:BCELoss(判别器输出, 目标标签=1)(希望判别器认为假图像是真实的)
            loss_gen = criterion(output, torch.ones_like(output))

            # 反向传播并更新生成器参数
            opt_gen.zero_grad()  # 清空生成器优化器的梯度缓存
            loss_gen.backward()  # 计算损失对参数的梯度
            opt_gen.step()  # 根据梯度更新生成器的参数

            # ---------------------- 步骤4:定期可视化训练结果 ----------------------
            # 每训练完第一个batch(batch_idx=0)时,打印损失并保存图像到TensorBoard
            if batch_idx == 0:
                # 打印当前轮次和批次的损失信息
                print(
                    f"Epoch [{epoch}/{epochs}] "  # 当前轮次/总轮次
                    f"Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"  # 判别器和生成器的损失
                )

                # 可视化真实图像(使用固定批次的真实数据)
                with torch.no_grad():  # 关闭梯度计算(节省内存)
                    fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                    data = real.reshape(-1, 1, 28, 28)
                    img_grid_real = torchvision.utils.make_grid(
                        data,
                        normalize=True  # 自动归一化到[0,1](便于显示)
                    )
                    img_grid_fake = torchvision.utils.make_grid(
                        fake,
                        normalize=True
                    )
                    # 将真实图像网格写入TensorBoard(标签'MNIST_real',全局步数step)
                    writer_real.add_image('MNIST_real', img_grid_real, global_step=step)
                    # 将假图像网格写入TensorBoard(标签'MNIST_fake',全局步数step)
                    writer_fake.add_image('MNIST_fake', img_grid_fake, global_step=step)
                    # 全局步数加1(下次记录时使用新的步数)
                    step += 1


if __name__ == "__main__":
    main()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值