本文介绍如何用纯PyTorch从零搭建一个扩散模型(DDPM, Denoising Diffusion Probabilistic Model),用于生成MNIST手写数字。我们使用一个简化的U-Net作为核心噪声预测网络,并手动实现加噪和去噪过程。本文适合想深入理解扩散模型原理和实现机制的朋友。
背景知识简述:什么是DDPM?
扩散模型是一种生成模型,它通过逐步向数据中添加高斯噪声,将真实图像“扩散”成纯噪声,然后再训练一个神经网络学会“逆扩散”过程,从噪声中逐步还原出真实图像。
这个过程分为两个阶段:
-
正向扩散(Forward Process):逐步给图像加噪,最终变成纯噪声。
-
反向生成(Reverse Process):训练网络预测噪声,然后从纯噪声中逐步还原图像。
1. DDPM (Denoising Diffusion Probabilistic Models) 原理介绍
DDPM(去噪扩散概率模型) 是一种生成模型,它通过逐步向数据添加噪声,直到数据完全变成噪声,再通过一个神经网络模型反向推理出原始数据。这种方法的特点是可以非常有效地进行生成,并且在图像生成领域取得了显著的成功。
1.1 扩散过程与去噪过程
-
扩散过程(Forward Diffusion Process):
扩散过程是一个逐步向数据添加噪声的过程。在每一步,图像会与一定量的噪声进行混合,最终使图像变成纯噪声。这个过程是固定的,不需要训练,通常是基于一个预设的噪声调度(即每个时间步噪声的强度)来完成的。数学表示:
假设我们有一个图像 x_0(初始图像),扩散过程的目标是生成一个噪声图像 x_t,随着时间步的增加(即 t 从 0 到 T),数据逐渐被加上噪声,直到最终变为纯噪声。扩散过程定义为:
-
-
其中,α_t 是控制噪声比例的参数,ϵ 是标准正态分布的噪声,x_t 是在时间步 t 时的图像。
-
去噪过程(Reverse Diffusion Process):
去噪过程是一个逐步去除噪声的过程。与扩散过程相反,去噪过程的目的是从纯噪声中逐步恢复到真实图像。这个过程通过训练一个神经网络模型来预测每个时间步的噪声,并根据这个预测逐步去除噪声,最终生成清晰的图像。在反向过程中,模型通过学习从 x_t 预测出噪声 ϵ,并通过以下公式逐步更新图像:
其中,β_t 是每个时间步的噪声增加量,ϵ_t 是模型预测的噪声。
1.2 噪声的预测
DDPM的关键在于训练一个神经网络去预测每个时间步的噪声(或称为“残差”)。这个神经网络通常是一个卷积神经网络(如U-Net)。模型输入的是加噪的图像 xtx_txt 和时间步 ttt,输出的是该时间步的噪声预测 ϵt\epsilon_tϵt,然后使用该预测来去除噪声。
1.3 DDPM的训练与生成
DDPM的训练目标是最小化模型输出的噪声预测与真实噪声之间的均方误差(MSE)。通过训练,模型能够学习到如何从每个加噪图像中恢复出真实的噪声,从而使得它能够在生成过程中有效地去除噪声。
2. 基于PyTorch实现DDPM
现在我们将结合代码来详细解析实现过程。以下是基于PyTorch实现DDPM的步骤。
2.1 MNIST数据集加载
在代码中,我们通过自定义 Dataset
类来加载MNIST数据集。MNIST数据集的图像是28x28的灰度图,通过二进制格式存储为IDX文件。我们通过 struct
和 numpy
来解析文件内容并将图像转为PyTorch tensor。
class MNIST_IDX(Dataset):
def __init__(self, image_