import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
from torchvision import datasets as dsets
from torchvision import transforms as trf
from torchvision.utils import save_image
#import matplotlib.pyplot as plt
#import numpy as np
def to_var(x):
    """Return ``x`` on the GPU when CUDA is available, otherwise unchanged.

    ``torch.autograd.Variable`` is deprecated: tensors carry autograd state
    themselves, so no wrapper is needed. Callers in this script only pass
    freshly created tensors (data batches, ``torch.ones``/``zeros``/``randn``)
    that do not require grad, so returning the tensor directly is equivalent.
    """
    if torch.cuda.is_available():
        x = x.cuda()
    return x
def denorm(x):
    """Shift/scale ``x`` from [-1, 1] into [0, 1], clipping values outside that range."""
    return torch.clamp(x.add(1).div(2), min=0, max=1)
# Convert PIL images to [0, 1] tensors, then shift/scale them to [-1, 1] so real
# samples share the range of the generator's Tanh output and of denorm().
# Normalize needs one mean/std entry per channel; MNIST has a single channel,
# so 3-channel tuples would fail on a (1, 28, 28) tensor.
transform = trf.Compose([
    trf.ToTensor(),
    trf.Normalize(mean=(0.5,), std=(0.5,)),
])
# MNIST training split (fetched into ./data/ when not already present), served
# as shuffled mini-batches of 100 images.
train_data = dsets.MNIST(root='./data/', train=True, download=True, transform=transform)
data_loader = Data.DataLoader(train_data, batch_size=100, shuffle=True)
# Discriminator: a flattened 28x28 image (784 features) passes through two
# LeakyReLU(0.2) hidden layers (256, then 128 units) and a sigmoid, producing
# one score per sample that is trained toward 1 for real images.
D = nn.Sequential(
    nn.Linear(784, 256), nn.LeakyReLU(0.2),  # hidden layer 1
    nn.Linear(256, 128), nn.LeakyReLU(0.2),  # hidden layer 2
    nn.Linear(128, 1), nn.Sigmoid(),         # score in (0, 1)
)
# Generator: a 64-dimensional noise vector is widened through LeakyReLU(0.2)
# hidden layers of 128 and 256 units to 784 outputs (one per pixel of a 28x28
# image); Tanh bounds each output to [-1, 1], the range denorm() maps to [0, 1].
G = nn.Sequential(
    nn.Linear(64, 128), nn.LeakyReLU(0.2),   # hidden layer 1
    nn.Linear(128, 256), nn.LeakyReLU(0.2),  # hidden layer 2
    nn.Linear(256, 784), nn.Tanh(),          # pixel values in [-1, 1]
)
# Put both networks on the GPU when one exists, matching where to_var() puts
# the input tensors; Module.cuda() moves in place and returns the module itself.
if torch.cuda.is_available():
    D, G = D.cuda(), G.cuda()
# Binary cross-entropy on D's sigmoid score; each network gets its own Adam
# optimizer with learning rate 3e-4.
loss_fn = nn.BCELoss()
d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=3e-4) for net in (D, G)
)
# Adversarial training. Each step first updates D on one real and one generated
# batch, then updates G so that D scores its samples as real. After each epoch a
# grid of generated images is written to ./data/; the models are saved at the end.
num_epochs = 5
latent_size = G[0].in_features  # width of the noise vector G consumes (64)
total_steps = len(data_loader)

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        batch_size = images.size(0)
        images = to_var(images.view(batch_size, -1))
        # D outputs shape (batch_size, 1); BCELoss requires targets of exactly
        # the same shape, so the labels are (batch_size, 1) as well.
        real_labels = to_var(torch.ones(batch_size, 1))
        fake_labels = to_var(torch.zeros(batch_size, 1))

        # ---- Discriminator step: real batch -> 1, generated batch -> 0 ----
        outputs = D(images)
        d_loss_real = loss_fn(outputs, real_labels)
        real_score = outputs

        z = to_var(torch.randn(batch_size, latent_size))
        fake_images = G(z)
        # Score the generated batch (not the real one again). Detach so this
        # backward pass does not compute gradients for G.
        outputs = D(fake_images.detach())
        d_loss_fake = loss_fn(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step: push D's score on fresh samples toward 1 ----
        z = to_var(torch.randn(batch_size, latent_size))
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = loss_fn(outputs, real_labels)
        # g_loss.backward() also fills D's .grad; those are cleared by
        # d_optimizer.zero_grad() before D's next update.
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 300 == 0:
            print('Epoch [%d/%d], Step [%d/%d], d_loss: %.4f, g_loss: %.4f, D(x): %.2f, D(G(z)): %.2f' %
                  (epoch + 1, num_epochs, i + 1, total_steps, d_loss.item(), g_loss.item(),
                   real_score.mean().item(), fake_score.mean().item()))

    # Save the last real batch once (first epoch) and the last generated batch every epoch.
    if (epoch + 1) == 1:
        images = images.view(images.size(0), 1, 28, 28)
        save_image(denorm(images.detach()), './data/real_images.png')
    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images.detach()), './data/fake_images-%d.png' % (epoch + 1))

# Pickle the whole modules (not just state_dicts), as before.
torch.save(G, './generator.pkl')
torch.save(D, './discriminator.pkl')
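# pytorch入门(八)—— GAN对抗生成网络
# (PyTorch introduction, part 8: GAN — generative adversarial networks)
# 最新推荐文章于 2025-06-12 16:53:01 发布
# (Latest recommended article published 2025-06-12 16:53:01)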