【简历修改、职业规划、大厂实习请私信联系~】
先预览一下 AE 和 VAE 模型的效果(第一行是输入原图,第二行是 AE 重构结果,第三行是 VAE 重构结果),几乎完美重构!
本文介绍 Diffusion 模型推理加速的一种常见方式:用AE(AutoEncoder) 和 VAE(Variational AutoEncoder) 进行图片压缩/反压缩。理论部分学完之后立即用代码进行实践,彻底掌握 AE/VAE。
AE 基础知识
自编码器(AutoEncoder,AE)是一种无监督学习的神经网络模型,主要用于数据压缩和特征学习。它的核心结构包括两个主要部分:编码器和解码器。编码器负责将输入数据压缩到一个低维的潜在空间,这个过程可以看作是提取输入数据的关键特征。解码器则尝试从这个压缩的表示重构原始输入,目标是使重构的输出尽可能接近原始输入。
AE 通过最小化重构误差来训练,这促使网络学习输入数据的最重要特征。训练完成后,编码器可以用于降维、特征提取或数据压缩,而完整的 AE 可以用于去噪或异常检测等任务。
AE 的优点包括结构简单、训练相对快速,以及可以学习紧凑的特征表示。然而,它也存在一些局限性,如生成能力有限,难以生成新的、有意义的样本。
VAE 基础知识
变分自编码器(VAE)是自编码器的一种概率变体,它结合了变分推断和神经网络,用于生成模型和表示学习。VAE 的核心思想是将输入数据编码为概率分布,而不是固定的向量。
VAE 的结构包括编码器、采样层和解码器。编码器将输入映射到潜在空间的均值和方差,采样层从这个分布中采样,解码器则从采样的潜在向量重构输入。
VAE 的训练目标包括两部分:重构损失和 KL 散度。重构损失确保模型能够准确重建输入,而 KL 散度则作为正则化项,使潜在空间的分布接近标准正态分布。
具体操作
编码器:将输入数据x编码为隐变量z的均值μ和标准差σ。
采样:从标准正态分布中采样一个ε,通过μ和σ计算z = μ + ε* σ。
解码器:将z解码为生成样本x'。
计算重构误差(如均方误差MSE)和KL散度,并通过优化算法调整模型参数,以最小化两者的和。
相比传统的自编码器,VAE 具有更强的生成能力,可以生成新的、合理的样本。它的潜在空间是连续的,便于插值,并且具有一定的正则化效果,有助于减少过拟合。
VAE 广泛应用于图像生成、异常检测、数据增强等领域。然而,它的训练过程可能较为复杂和不稳定,且 KL 散度项可能导致模型忽视部分输入信息。
对比 AE 和 VAE
代码实战
import torch,os
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# 设置随机种子以确保结果可复现
torch.manual_seed(42)
# 准备MNIST数据集
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建数据加载器
# batch_size: 每批处理的样本数
# shuffle: 是否在每个epoch打乱数据
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
class AE(nn.Module):
def __init__(self):
super(AE, self).__init__()
# 编码器:将28x28的输入压缩到3维潜在空间
self.encoder = nn.Sequential(
nn.Linear(28*28, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 12),
nn.ReLU(),
nn.Linear(12, 3)
)
# 解码器:将3维潜在空间重构为28x28的输出
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.ReLU(),
nn.Linear(12, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 28*28),
nn.Sigmoid() # 使用Sigmoid确保输出在[0,1]范围内
)
def forward(self, x):
"""
前向传播函数
参数:
x (torch.Tensor): 输入图像张量,形状为 (batch_size, 1, 28, 28)
返回:
torch.Tensor: 重构后的图像张量,形状为 (batch_size, 1, 28, 28)
"""
x = x.view(-1, 28*28) # 将输入展平
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded.view(-1, 1, 28, 28) # 重塑为原始图像形状
# 初始化AE模型
ae_model = AE()
ae_optimizer = optim.Adam(
ae_model.parameters(),
lr=0.0001
)
ae_criterion = nn.MSELoss() # 使用均方误差作为重构损失
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Linear(28*28, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU()
)
# 均值和对数方差的全连接层
self.fc_mu = nn.Linear(64, 3)
self.fc_logvar = nn.Linear(64, 3)
# 解码器
self.decoder = nn.Sequential(
nn.Linear(3, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 28*28),
nn.Sigmoid()
)
def encode(self, x):
"""
编码函数
参数:
x (torch.Tensor): 输入图像张量,形状为 (batch_size, 784)
返回:
tuple(torch.Tensor, torch.Tensor): 均值和对数方差,每个的形状为 (batch_size, 3)
"""
h = self.encoder(x)
return self.fc_mu(h), self.fc_logvar(h)
def reparameterize(self, mu, logvar):
"""
重参数化技巧
参数:
mu (torch.Tensor): 均值,形状为 (batch_size, 3)
logvar (torch.Tensor): 对数方差,形状为 (batch_size, 3)
返回:
torch.Tensor: 采样得到的潜在变量,形状为 (batch_size, 3)
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
"""
解码函数
参数:
z (torch.Tensor): 潜在变量,形状为 (batch_size, 3)
返回:
torch.Tensor: 重构的图像,形状为 (batch_size, 784)
"""
return self.decoder(z)
def forward(self, x):
"""
前向传播函数
参数:
x (torch.Tensor): 输入图像张量,形状为 (batch_size, 1, 28, 28)
返回:
tuple: (重构图像, 均值, 对数方差)
- 重构图像 (torch.Tensor): 形状为 (batch_size, 1, 28, 28)
- 均值 (torch.Tensor): 形状为 (batch_size, 3)
- 对数方差 (torch.Tensor): 形状为 (batch_size, 3)
"""
x = x.view(-1, 28*28)
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
return self.decode(z).view(-1, 1, 28, 28), mu, logvar
# 初始化VAE模型
vae_model = VAE()
vae_optimizer = optim.Adam(
vae_model.parameters(),
lr=0.0001
)
def vae_loss(recon_x, x, mu, logvar):
"""
VAE损失函数:重构损失 + KL散度
参数:
recon_x (torch.Tensor): 重构的图像,形状为 (batch_size, 784)
x (torch.Tensor): 原始图像,形状为 (batch_size, 784)
mu (torch.Tensor): 均值,形状为 (batch_size, 3)
logvar (torch.Tensor): 对数方差,形状为 (batch_size, 3)
返回:
torch.Tensor: 标量,表示总损失
"""
BCE = nn.functional.binary_cross_entropy(recon_x.view(-1, 28*28), x.view(-1, 28*28), reduction='sum')
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
def train(epoch, model, optimizer, criterion, is_vae=False):
"""
训练函数
参数:
epoch (int): 当前训练的轮数
model (nn.Module): 要训练的模型(AE或VAE)
optimizer (torch.optim.Optimizer): 优化器
criterion (callable): 损失函数(仅用于AE)
is_vae (bool): 是否为VAE模型
返回:
None
"""
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
optimizer.zero_grad()
if is_vae:
recon_batch, mu, logvar = model(data)
loss = vae_loss(recon_batch, data, mu, logvar)
else:
recon_batch = model(data)
loss = criterion(recon_batch, data)
loss.backward()
train_loss += loss.item()
optimizer.step()
print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')
# 训练AE和VAE模型
num_epochs = 200
def compare_reconstructions(ae_model, vae_model):
"""
比较AE和VAE模型的重构效果
参数:
ae_model (nn.Module): 训练好的AE模型
vae_model (nn.Module): 训练好的VAE模型
返回:
None (显示图像)
"""
ae_model.eval()
vae_model.eval()
with torch.no_grad():
data = next(iter(test_loader))[0][:8] # 获取8个测试样本
ae_recon = ae_model(data)
vae_recon, _, _ = vae_model(data)
# 将原始图像、AE重构和VAE重构拼接在一起
comparison = torch.cat([data, ae_recon, vae_recon])
plt.figure(figsize=(12, 4))
for i in range(24):
plt.subplot(3, 8, i+1)
plt.imshow(comparison[i].squeeze().numpy(), cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.show()
# 创建保存路径
save_dir = "/root/autodl-tmp/projects/diffuser/handwritten_algos/ldm/res_images"
os.makedirs(save_dir, exist_ok=True)
# 保存图像
save_path = os.path.join(save_dir, f"{epoch}.png")
plt.savefig(save_path)
print(f"=> saved to {save_path}")
plt.close() # 关闭图像,防止内存泄漏
for epoch in range(1, num_epochs + 1):
train(epoch, ae_model, ae_optimizer, ae_criterion)
train(epoch, vae_model, vae_optimizer, None, is_vae=True)
# 比较重构效果
compare_reconstructions(ae_model, vae_model)