import torch import torch.nn as nn import torch.optim as optim from torch import device from torchvision import datasets, transforms 这段代码导入了 PyTorch 相关的核心库以及用于处理图像数据的torchvision
库中的部分模块。torch
是 PyTorch 的基础库,nn
用于构建神经网络模型,optim
用于定义优化器,device
用于指定计算设备(如 CPU 或 GPU),datasets
和transforms
用于加载和预处理图像数据。 # 定义超参数 latent_dim = 100 # 噪声向量维度 image_size = 64 # 图像大小 channels = 3 # 图像通道数(RGB) batch_size = 128 num_epochs = 100 learning_rate = 0.0002 这里定义了一系列超参数,包括生成对抗网络(GAN)中噪声向量的维度、生成图像的大小、图像的通道数(RGB 图像为 3 通道)、每次训练的批量大小、训练的轮数以及学习率。这些超参数将在后续的模型训练和生成过程中起到关键作用。 # 数据预处理与加载 transform = transforms.Compose([ transforms.Resize(image_size), transforms.CenterCrop(image_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) dataset = datasets.ImageFolder('path/to/image/dataset', transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
- 数据预处理:首先定义了一个
transform
操作,它由一系列的图像变换组成。包括将图像调整为指定的大小、进行中心裁剪、转换为张量格式以及对图像的像素值进行归一化处理,使其均值为 0.5,标准差为 0.5。 - 数据加载:使用
ImageFolder
类从指定路径加载图像数据集,并应用前面定义的transform
操作。然后通过DataLoader
将数据集包装成可迭代的数据加载器,以便在训练过程中按批次获取数据,并且设置了随机打乱数据的功能。
# 定义生成器 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.main = nn.Sequential( # 输入层 nn.ConvTranspose2d(latent_dim, 512, kernel_size=4, stride=1, padding=0, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), # 层叠反卷积层 nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, channels, kernel_size=4, stride=2, padding=1, bias=False), nn.Tanh() # 输出层,使用tanh激活函数以确保生成图像像素值在[-1, 1]范围内 ) def forward(self, input): return self.main(input)
- 模型结构:生成器是一个继承自
nn.Module
的神经网络类。在其构造函数__init__
中,定义了一个由多个层组成的main
序列模型。包括一个输入层的转置卷积层(将噪声向量转换为具有一定特征维度的张量),随后跟着几个层叠的转置卷积层用于逐步上采样和生成图像特征,每个转置卷积层后都跟着批量归一化层和 ReLU 激活函数,最后输出层使用tanh
激活函数将生成图像的像素值映射到[-1, 1]
范围内。 - 前向传播:
forward
方法定义了数据在模型中的正向传播路径,即输入噪声向量通过main
序列模型得到生成的图像。
# 定义判别器 class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.main = nn.Sequential( # 输入层 nn.Conv2d(channels, 64, kernel_size=4, stride=2, padding=1, bias=False), nn.LeakyReLU(0.2, inplace=True), # 层叠卷积层 nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True), # 输出层 nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False), nn.Sigmoid() # 输出判别概率 ) def forward(self, input): return self.main(input).view(-1)
- 模型结构:判别器同样是继承自
nn.Module
的类。在构造函数中定义的main
序列模型由多个卷积层组成。从输入层开始,依次通过多个层叠的卷积层对输入图像进行下采样和特征提取,每个卷积层后跟着批量归一化层和 LeakyReLU 激活函数(斜率为 0.2),最后输出层使用sigmoid
函数将判别结果映射到[0, 1]
区间,表示输入图像是真实图像的概率。 - 前向传播:
forward
方法定义了数据在判别器中的正向传播路径,输入图像通过main
序列模型后经过view(-1)
操作将输出张量展平为一维向量,得到判别器对输入图像的判别概率。
# 实例化模型 generator = Generator() discriminator = Discriminator() # 使用Adam优化器 optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999)) optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999)) # 定义损失函数 criterion = nn.BCELoss()
- 模型实例化:分别实例化了生成器和判别器模型。
- 优化器定义:为生成器和判别器分别定义了 Adam 优化器,指定了学习率以及
betas
参数(用于控制优化器的动量和自适应学习率的衰减)。 - 损失函数定义:使用
BCELoss
作为损失函数,用于衡量判别器和生成器的输出与真实标签之间的差异。这种损失函数适用于二分类问题,在这里用于判断图像是真实的还是生成的。
# 训练循环 for epoch in range(num_epochs): for i, (real_images, _) in enumerate(dataloader): real_images = real_images.to(device) # 将数据转移到GPU(如果可用) # 训练判别器 discriminator.zero_grad() real_labels = torch.ones(batch_size, device=device) fake_labels = torch.zeros(batch_size, device=device) # 计算真实图像损失 real_output = discriminator(real_images) d_real_loss = criterion(real_output, real_labels) # 生成假图像并计算损失 noise = torch.randn(batch_size, latent_dim, device=device) fake_images = generator(noise) fake_output = discriminator(fake_images.detach()) # detach()避免反向传播到生成器 d_fake_loss = criterion(fake_output, fake_labels) # 判别器总损失 d_loss = d_real_loss + d_fake_loss d_loss.backward() optimizer_D.step() # 训练生成器 generator.zero_grad() fake_labels.fill_(1) # 期望生成器产生的图像被判别器判断为真 # 重新计算生成图像的判别损失 fake_output = discriminator(fake_images) g_loss = criterion(fake_output, fake_labels) g_loss.backward() optimizer_G.step() if (i + 1) % 100 == 0: print(f"Epoch [{epoch}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], " f"D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")
- 整体训练流程:通过外层的
for
循环遍历训练的轮数,内层的for
循环遍历数据加载器中的每个批次数据。 - 判别器训练:
- 首先将判别器的梯度清零,然后定义真实图像的标签为全 1 向量,生成图像的标签为全 0 向量。
- 计算判别器对真实图像的输出与真实标签之间的损失(
d_real_loss
),以及对生成图像(通过生成器生成并从计算图中分离,避免反向传播到生成器)的输出与假标签之间的损失(d_fake_loss
)。 - 将这两个损失相加得到判别器的总损失
d_loss
,然后进行反向传播更新判别器的参数。
- 生成器训练:
- 先将生成器的梯度清零,然后将生成图像的期望标签设置为全 1 向量,表示希望生成器生成的图像能被判别器判断为真实图像。
- 重新计算生成图像通过判别器后的输出与期望标签之间的损失(
g_loss
),进行反向传播更新生成器的参数。
- 训练信息打印:每隔 100 个批次打印当前轮数、批次数以及判别器和生成器的损失值。
# 生成新图像 fixed_noise = torch.randn(64, latent_dim, device=device) fake_images = generator(fixed_noise).detach().cpu() # 可以保存或显示fake_images,查看生成的图像效果
在训练完成后,通过生成一个固定的噪声向量fixed_noise
,将其输入到训练好的生成器中得到生成的图像fake_images
,然后将其从 GPU(如果在 GPU 上训练)转移到 CPU,并可以进一步保存或显示这些生成的图像来查看生成效果。
总体而言,这段代码实现了一个基本的生成对抗网络(GAN),包括数据预处理、模型定义、训练过程以及生成新图像的功能,用于生成与训练数据相似的新图像。