pytorch入门(八)—— GAN对抗生成网络

此博客聚焦于PyTorch入门,着重介绍GAN对抗生成网络。GAN是重要的深度学习模型,在图像生成等领域应用广泛。借助PyTorch学习GAN,能让开发者掌握相关技术,为后续实践打下基础。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
from torchvision import datasets as dsets
from torchvision import transforms as trf
from torchvision.utils import save_image 

#import matplotlib.pyplot as plt
#import numpy as np

def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

transform = trf.Compose([
        trf.ToTensor(),
#        trf.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

train_data = dsets.MNIST(
            root='./data/',
            train=True,
            transform=transform,
            download=True
        )

data_loader = Data.DataLoader(dataset=train_data, batch_size=100, shuffle=True)

D = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 128),
        nn.LeakyReLU(0.2),
        nn.Linear(128, 1),
        nn.Sigmoid()
    )

G = nn.Sequential(
        nn.Linear(64, 128),
        nn.LeakyReLU(0.2),
        nn.Linear(128, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 784),
        nn.Tanh()
    )


if torch.cuda.is_available():
    D.cuda()
    G.cuda()

loss_fn = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

for epoch in range(5):
    for i, (images, _) in enumerate(data_loader):
        batch_size = images.size(0)
        images = to_var(images.view(batch_size, -1))
        real_labels = to_var(torch.ones(batch_size))
        fake_labels = to_var(torch.zeros(batch_size))
        outputs = D(images)
        d_loss_real = loss_fn(outputs, real_labels)
        real_score = outputs

        z = to_var(torch.randn(batch_size, 64))
        fake_images = G(z)
        d_loss_fake = loss_fn(outputs, fake_labels)
        fake_score = outputs
        d_loss = d_loss_real + d_loss_fake
        
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        z = to_var(torch.randn(batch_size, 64))
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = loss_fn(outputs, real_labels)

        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i + 1) % 300 == 0:
            print('Epoch [%d/%d], Step [%d/%d], d_loss: %.4f, g_loss: %.4f, D(x): %.2f, D(G(z)): %.2f' %
                    (epoch, 5, i + 1, 600, d_loss.item(), g_loss.item(), 
                        real_score.data.mean(), fake_score.data.mean()))

        if (epoch + 1) == 1:
            images = images.view(images.size(0), 1, 28, 28)
            save_image(denorm(images.data), './data/real_images.png')
        
        fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
        save_image(denorm(fake_images.data), './data/fake_images-%d.png' % (epoch + 1))

torch.save(G, './generator.pkl')
torch.save(D, './discriminator.pkl')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值