PyTorch —— GAN(手写数字)

本文详细介绍了如何使用PyTorch实现GAN,包括生成器和判别器的设计,数据预处理,以及模型的训练过程。通过MNIST数据集展示GAN如何生成逼真的手写数字,展示了模型从噪声到生成图像的逐步提升。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


PyTorch 学习 —— GAN(生成对抗网络)

“ 凡我不能创造的,我就无法理解 ”

Ⅰ 简介

GAN 分为两部分,生成网络对抗网络

  1. 生成网络学习数据真实分布的模式,并尝试生成看起来像该真实数据分布中的样本的新样本。生成模型是无监督学习的分支,因为其通过成生成样本来学习潜在的特征。生成网络活跃在多个研究领域,例如提高图像分辨率,为图像填充缺失片段等。
  2. 对抗网络即对抗生成网络的网络,其任务是对生成网络的生成样本提供反馈,充当“监督者”的身份。

  在GAN 中,生成网络的目标是尽可能生成能够成功欺骗对抗网络的生成模型;对抗网络的目标是尽可能识别出生成网络所生成的模型。两者不断学习,不断优化自己的参数以便在这场较量当中占据上风。最终的结果就是生成网络能够生成与真实数据分布极为相似的样本,而对抗网络则有着极高的敏感度,一旦发现任何蛛丝马迹都能够很好将其识别出。

  生成器从随机分布中获取样本,然后由网络将样本转换为输出。判别器仅负责通知生成器其生成的输出看起来像不像真实图像,以便生成器更改其生成图像的方式以使判别器确信它是真的。


接下来搭建一个简单的GAN。

  这是一个利用手写数据集进行训练得到的GAN,生成器接收随机噪声作为输入,然后输出一张手写数字图像;判别器的输入则是两幅图像,分别是真的手写数字图像和生成器生成的假图像,然后输出对这两幅图像的判别结果。
在这里插入图片描述

Ⅱ 准备数据

BATCH_SIZE = 100

def mnist_data():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    return datasets.MNIST(root='./data', train=True, transform=transform, download=True)

data = mnist_data()

data_loader = torch.utils.data.DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)

Ⅲ 判别器网络

  可以想象,如果我们只给判别器输入生成器的生成图像,判别器会学会将所有图像不管是否真实都判定为Fake。这显然不是我们想要的结果,所以我们会在训练判别器时,分别给它一张真实的图像和一张假的图像(如上图所示),其需要找出哪个是真哪个是假,并以此计算损失。

  所以判别器的任务可以视为分类任务,其需要将输入图像分类为原始图像或生成图像,这是二元分类。

  我们搭建一个判别器网络,其共有四层,前三层每层都由一个线性层、激活函数LeakyReLU 和一个dropout 层组成,最后一层由一个线性层和Sigmoid 激活函数组成。

class DiscriminatorNet(nn.Module):
    """
    A three hidden-layer discriminative neural network
    """
    def __init__(self):
        super().__init__()
        n_features = 784
        n_out = 1
        
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.out = nn.Sequential(
            nn.Linear(256, n_out),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

  因为不会将零以下的数全部置为零,所以使用LeakyReLU 激活函数相比使用ReLU 能够更好地使梯度流过网络。

在这里插入图片描述
  判别器总是从真实源数据集或生成器中获取图像,因此其输入为28×28=784 维度的数据,输出为分类结果。

Ⅳ 生成器网络

  生成器的结构与判别器类似,但其最后一层是tanh 层,而不是sigmoid 层。此更改是为了与对MNIST 数据进行的归一化同步,以将其值转换到[-1, 1] 中,以便判别器始终获取数据点处于相同值域的数据集。

class GeneratorNet(nn.Module):
    """
    A three hidden-layer generative neural network
    """
    def __init__(self):
        super().__init__()
        n_features = 100
        n_out = 784
        
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LeakyReLU(0.2)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )
        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

Ⅴ 模型训练

  我们需要同时训练两个模型,分别是判别器模型与生成器模型。其中判别器模型的训练过程中需要输入两幅图像,然后根据其判别结果分别计算损失,最后将损失相加再更新参数;而生成器模型的训练则较为简单,只需要根据判别器的判别结果进行参数更新即可。

  模型的训练过程可以遵循该顺序:1)梯度清零;2)获取输出;3)计算损失;4)反向传播;5)参数优化

  因为理想的判别器能够将真实图像标记为1,将生成图像标记为0。所以在分别对真实图像与生成图像的判别结果计算损失时,使其标签分别为同等尺寸的全1张量与全0张量。

def train_discriminator(optimizer, loss_fn, input_real_data, input_fake_data):
    optimizer.zero_grad()
    
    # 对于真实样本的预测结果
    discriminated_real_data = discriminator(input_real_data)
    loss_real = loss_fn(discriminated_real_data, real_data_target(input_real_data.size(0)))
    loss_real.backward()
    
    # 对于生成样本的预测结果
    discriminated_fake_data = discriminator(input_fake_data)
    loss_fake = loss_fn(discriminated_fake_data, fake_data_target(input_fake_data.size(0)))
    loss_fake.backward()
    
    optimizer.step()
    
    return loss_real+loss_fake, discriminated_real_data, discriminated_fake_data

  理想的生成器能够生成使判别器判别为1(即认为是真实图像)的图像,因此在其计算损失时的标签为同等尺寸的全1张量。

def train_generator(optimizer, loss_fn, input_fake_data):
    optimizer.zero_grad()
    output_discriminated = discriminator(input_fake_data)
    loss = loss_fn(output_discriminated, real_data_target(output_discriminated.size(0)))
    loss.backward()
    optimizer.step()
    return loss

开始训练

discriminator = DiscriminatorNet().to(device)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)

generator = GeneratorNet().to(device)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)

loss_fn = nn.BCELoss()

num_epochs = 100
num_test_samples = 3

  我们首先将真实样本与生成器的生成样本作为判别器的输入,对判别器进行训练。然后再通过判别器的预测结果对生成器进行训练。并间歇输出生成器的生成图像以观察模型训练结果。

for epoch in range(num_epochs):
    for train_idx, (input_real_batch, _) in enumerate(data_loader):
        # training the discriminator
        input_real_data = images2vectors(input_real_batch).to(device)
        generated_fake_data = generator(noise(input_real_data.size(0))).detach()
        d_loss, discriminated_real, discriminated_fake = train_discriminator(d_optimizer, loss_fn, input_real_data, generated_fake_data)
        
        # training the generator
        generated_fake_data = generator(noise(input_real_batch.size(0)))
        g_loss = train_generator(g_optimizer, loss_fn, generated_fake_data)
        # print the info
        if train_idx % 10 == 0:
            print('Epoch:{} Loss:{} =={:.2f}%=='.format\
                 (epoch, g_loss, ((train_idx+epoch*num_batches)/(num_batches*num_epochs))*100))
            # output the images generated by generator
            if train_idx % 30 == 0:
                test_noise = noise(num_test_samples)
                generated = generator(test_noise).detach()
                discriminated = discriminator(generated)
                images = vectors2images(generated)

                fig = plt.figure()
                for i in range(3):
                    plt.subplot(1, 3, i+1)
                    plt.imshow(images[i][0], cmap='gray', interpolation='none')
                    plt.title('pred: {:.3f}'.format(discriminated[i].item()))
                    plt.xticks([])
                    plt.yticks([])
                    plt.show()
                    
print('End Training')

Ⅵ 运行结果

未训练时,可以看到生成的图像完全就是随机噪声,判别器也完全没有作用。
在这里插入图片描述 在这里插入图片描述 在这里插入图片描述
训练了50 epoch 后,图像隐隐约约可以看得出是什么数字.
在这里插入图片描述 在这里插入图片描述 在这里插入图片描述
训练了100 epoch 后。
在这里插入图片描述 在这里插入图片描述 在这里插入图片描述
当训练了200 epoch 后。
在这里插入图片描述 在这里插入图片描述 在这里插入图片描述
训练了300 epoch 后,虽然没有太大进步了,不过可以看到生成的数字更加自然,同时看到判别器给出的预测数值依然不高,说明两者在不断的竞争当中都在不断成长进步。
在这里插入图片描述 在这里插入图片描述 在这里插入图片描述
下面是真实的手写数字样本,通过对比可以看到生成器所生成的手写数字已经非常接近真实样本了。
在这里插入图片描述

Ⅶ 完整代码

SimpleGAN.py
在这里插入图片描述

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

诺亚方包

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值