改进的分布匹配蒸馏以快速图像合成
paper是MIT 发表在NIPS 2024的工作
paper title:Improved Distribution Matching Distillation for Fast Image Synthesis
Code:;链接
Abstract
最近的研究表明,通过蒸馏昂贵的扩散模型,可以生成高效的单步生成器。其中,分布匹配蒸馏(Distribution Matching Distillation,DMD)能够生成在分布上与教师模型匹配的单步生成器,即蒸馏过程并不强制与教师模型的采样轨迹一一对应。然而,为了确保实际训练的稳定性,DMD 需要额外的回归损失,该损失是通过教师模型使用确定性采样器进行多步采样所生成的大量噪声-图像对计算得到的。这不仅对大规模文本到图像合成计算代价高昂,而且还限制了学生模型的质量,使其过于依赖教师模型的原始采样路径。我们提出 DMD2,这是一套能够消除这些限制并改进 DMD 训练的技术。首先,我们去除了回归损失,从而无需昂贵的数据集构建。我们发现,训练过程中的不稳定性源于“假”评估器无法准确估计生成样本的分布,因此我们提出了一种双时间尺度更新规则来解决该问题。其次,我们在蒸馏过程中引入了 GAN 损失,以区分生成样本和真实图像。这使得学生模型能够在真实数据上进行训练,从而缓解教师模型提供的“真实”分数估计的不完美问题,并提升生成质量。第三,我们引入了一种新的训练方法,使学生模型能够进行多步采样,并通过在训练过程中模拟推理时的生成器输入样本来解决训练-推理输入不匹配的问题。综合来看,我们的改进在单步图像生成上设立了新的基准,在 ImageNet-64×64 上的 FID 分数达到 1.28,在零样本 COCO 2014 上的 FID 分数达到 8.35,超越了原始教师模型,同时推理成本减少了 500 倍。此外,我们的方法能够通过蒸馏 SDXL 生成百万像素级图像,在少步推理方法中展现了卓越的视觉质量,并超越了教师模型和预训练模型。
1 Introduction
扩散模型在视觉生成任务中取得了前所未有的质量 [1–8]。然而,它们的采样过程通常需要进行多次迭代去噪,每一步都需要通过神经网络进行前向传播。这使得高分辨率文本到图像的合成变得缓慢且昂贵。为了解决这一问题,已经开发了许多蒸馏方法,将教师扩散模型转换为高效的、少步的学生生成器 [9–20]。然而,这些方法通常会导致质量下降,因为学生模型通常通过损失函数学习教师模型的成对噪声到图像的映射,但难以完美地模仿其行为。
图1:1024×1024个样品由我们的4步中的生成器从SDXL蒸馏出来。请放大以获取详细信息
然而,需要注意的是,旨在匹配分布的损失函数(如 GAN [21] 或 DMD [22] 损失)并不需要精确学习从噪声到图像的特定路径,因为它们的目标是使学生模型在分布层面上与教师模型对齐——通常通过最小化 Jensen-Shannon (JS) 或近似的 Kullback-Leibler (KL) 散度来匹配学生和教师的输出分布。
具体而言,DMD [22] 在蒸馏 Stable Diffusion 1.5 方面已经取得了最先进的结果,但相比 GAN 方法 [23–29] 仍然研究较少。一个可能的原因是,DMD 仍然需要额外的回归损失以确保训练稳定性。这反过来需要通过运行教师模型的完整采样步骤来创建数百万个噪声-图像对,对于文本到图像合成而言,这种方式成本极高。此外,回归损失也削弱了 DMD 无监督分布匹配目标的关键优势,因为它导致学生模型的质量被教师模型所限制。
在本论文中,我们展示了一种去除 DMD 回归损失的方法,同时不影响训练稳定性。然后,我们通过将 GAN 框架整合到 DMD 中来突破分布匹配的极限,并通过一种新颖的训练方法(我们称之为“反向模拟”)实现了少步采样。总体而言,我们的贡献使得生成模型达到了最先进的水平,并且使用最少 4 步采样即可超越其教师模型。我们的方法被称为 DMD2。
在单步图像生成方面,我们的方法取得了最先进的结果,在 ImageNet-64x64 上的 FID 分数达到 1.28,在零样本 COCO 2014 上达到 8.35,树立了新的基准。我们进一步展示了方法的可扩展性,通过从 SDXL 蒸馏高质量的百万像素级图像,在少步方法中建立了新的标准。
简而言之,我们的贡献如下:
-
我们提出了一种新的分布匹配蒸馏技术,该技术无需回归损失即可实现稳定训练,从而消除了昂贵数据收集的需求,并允许更灵活和可扩展的训练。
-
我们表明,在 DMD [22] 中,去除回归损失会导致训练不稳定的原因在于训练不足的 伪扩散判别器(fake diffusion critic),并提出了一种双时间尺度更新规则来解决该问题。
-
我们在 DMD 框架中整合了 GAN 目标,其中判别器被训练用于区分学生生成器样本与真实图像。这种额外的监督在 分布(distribution) 级别上运行,更符合 DMD 的分布匹配理念,相较于原始回归损失方法能更有效地减少教师扩散模型的近似误差,从而提高生成图像质量。
-
原始 DMD 仅支持单步学生模型,而我们引入了一种新技术,使其支持多步生成器。与先前的多步蒸馏方法不同,我们通过在训练过程中模拟推理时间生成器输入,避免了训练和推理之间的 域不匹配(domain mismatch),从而提高了整体性能。
图2:1024×1024个样品由我们的4步中的生成器从SDXL蒸馏出来。请放大以获取详细信息。
2 Related Work
扩散蒸馏(Diffusion Distillation)。 近年来,扩散加速技术主要聚焦于通过蒸馏来加速生成过程 [9, 10, 13–20, 22, 23, 30]。通常,它们训练一个生成器来近似教师模型的常微分方程(ODE)采样轨迹,从而以更少的采样步数生成图像。值得注意的是,Luhman 等人 [16] 预先计算了一个包含噪声和图像对的数据集,该数据集由教师模型使用 ODE 采样器生成,并用于训练学生模型,使其在单次网络评估中学习映射关系。后续工作,如 Progressive Distillation [10, 13],则消除了离线预计算这一数据集的需求,而是通过迭代训练一系列学生模型,使每个模型的采样步数减半。另一种互补技术,Instaflow [11] 通过拉直 ODE 轨迹,使得其更易于用单步学生模型进行近似。Consistency Distillation [9, 12, 19, 26, 31, 32] 和 TRACT [33] 训练学生模型,使其在 ODE 轨迹上的任何时间步均保持自洽,并与教师模型的输出一致。
对抗生成网络(GANs)。 另一种研究方向是使用对抗训练,使学生模型在更广泛的分布级别上对齐教师模型。在 ADD [23] 中,生成器从扩散模型初始化权重,并使用图像空间分类器 [34] 进行训练。在此基础上,LADD [24] 利用一个预训练的扩散模型作为判别器,并在潜在空间中运行,从而提升可扩展性,并支持更高分辨率的合成。受 DiffusionGAN [28, 29] 启发,UFOGen [25] 在判别器执行 真实 vs. 伪造 分类之前引入噪声注入,以平滑输出分布,从而稳定训练动态。一些最近的研究结合了对抗目标和蒸馏损失,以保持原始采样轨迹。例如,SDXL-Lightning [27] 结合了 DiffusionGAN 损失 [25] 和 Progressive Distillation 目标 [10, 13],而 Consistency Trajectory Model [26] 则将 GAN [35] 与改进版的一致性蒸馏 [9] 相结合。
分数蒸馏(Score Distillation)。 该技术最初在文本到 3D(text-to-3D)合成任务中被提出 [36–39],其中使用预训练的文本到图像扩散模型作为分布匹配损失。这些方法通过对齐渲染视图与文本条件分布,优化 3D 对象,利用的是预训练扩散模型的分数信息。最近的研究扩展了分数蒸馏方法 [36, 37, 40–42],将其应用于扩散蒸馏 [22, 43–45]。值得注意的是,DMD [22] 通过最小化近似 KL 散度进行蒸馏,其梯度表示为两个分数函数的差值:一个固定且预训练的目标分布,另一个动态训练的生成器输出分布。DMD 使用扩散模型对两个分数函数进行参数化。该训练目标比基于 GAN 的方法更稳定,并在单步图像合成方面表现出色。然而,一个重要的注意事项是,DMD 需要一个回归损失来确保训练稳定性,该损失是通过预计算的噪声-图像对进行计算的,类似于 Luhman 等人 [16] 的方法。我们的工作消除了这一需求。我们提出了一系列技术,使 DMD 训练过程在不使用回归正则项的情况下仍然保持稳定,从而显著降低了因数据配对预计算而产生的计算成本。此外,我们扩展了 DMD,使其支持多步生成,并结合了 GAN 和分布匹配方法的优势 [22, 44, 45],最终在文本到图像合成任务上实现了最先进的结果。
3 Background: Diffusion and Distribution Matching Distillation
本节简要介绍了扩散模型(Diffusion Models)和分布匹配蒸馏(Distribution Matching Distillation, DMD)。
扩散模型通过迭代去噪生成图像。在前向扩散过程中,噪声被逐步添加到从数据分布 p real p_{\text{real}} preal 采样的数据 x x x 上,使其受到高斯噪声的污染,经过预定的 T T T 个步骤后,在每个时间步 t t t,扩散样本遵循条件分布 p real , t ( x t ) = ∫ p real ( x ) q t ( x t ∣ x ) d x p_{\text{real}, t}(x_t) = \int p_{\text{real}}(x) q_t(x_t | x) dx preal,t(xt)=∫preal(x)qt(xt∣x)dx,其中 q t ( x t ∣ x ) ∼ N ( α t x , σ t 2 I ) q_t(x_t | x) \sim \mathcal{N}(\alpha_t x, \sigma_t^2 I) qt(xt∣x)∼N(αtx,σt2I), α t , σ t > 0 \alpha_t, \sigma_t > 0 αt,σt>0 由噪声调度确定 [46, 47]。扩散模型通过预测一个去噪估计值 μ ( x t , t ) \mu(x_t, t) μ(xt,t) 来反转这个破坏过程,该预测值由当前的噪声样本 x t x_t xt 和时间步 t t t 进行条件计算,最终引导生成来自目标分布 p real p_{\text{real}} preal 的图像。训练后,生成的密度关系到数据对数似然函数的梯度,即扩散分布的分数函数 [47]:
s real ( x t , t ) = ∇ x log p real , t ( x t ) = − x t − α t μ real ( x t , t ) σ t 2 s_{\text{real}}(x_t, t) = \nabla_x \log p_{\text{real}, t}(x_t) = -\frac{x_t - \alpha_t \mu_{\text{real}}(x_t, t)}{\sigma_t^2} sreal(xt,t)=∇xlogpreal,t(xt)=−σt2xt−αtμreal(xt,t)
通常,生成一张图像需要几十到上百次去噪步骤 [48–51]。
分布匹配蒸馏(DMD)
DMD 通过最小化目标分布 p real , t p_{\text{real}, t} preal,t 和扩散生成器的输出分布 p fake , t p_{\text{fake}, t} pfake,t 之间的近似 Kullback-Leibler(KL)散度的期望值 [22],将多步扩散模型蒸馏为单步生成器 G G G。由于 DMD 通过梯度下降训练 G G G,其损失的梯度可以计算为两个分数函数的差异:
∇ DMD = E t ( ∇ θ KL ( p fake , t ∣ ∣ p real , t ) ) = − E t ( ∫ ( s real ( F ( G θ ( z ) , t ) , t ) − s fake ( F ( G θ ( z ) , t ) , t ) ) d G θ ( z ) d θ ) \nabla_{\text{DMD}} = \mathbb{E}_t \left( \nabla_\theta \text{KL}(p_{\text{fake}, t} || p_{\text{real}, t}) \right) = - \mathbb{E}_t \left( \int \left( s_{\text{real}}(F(G_\theta(z), t), t) - s_{\text{fake}}(F(G_\theta(z), t), t) \right) \frac{dG_\theta(z)}{d\theta} \right) ∇DMD=Et(∇θKL(pfake,t∣∣preal,t))=−Et(∫(sreal(F(Gθ(z),t),t)−sfake(F(Gθ(z),t),t))dθdGθ(z))
其中 z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) z∼N(0,I) 是随机高斯噪声输入, θ \theta θ 是生成器参数, F F F 是前向扩散过程(即噪声注入), s real s_{\text{real}} sreal 和 s fake s_{\text{fake}} sfake 是分别基于真实数据分布和假数据分布的分数函数,分别由扩散模型 μ real \mu_{\text{real}} μreal 和 μ fake \mu_{\text{fake}} μfake 计算(见公式 (1))。DMD 使用冻结的预训练扩散模型作为教师 μ real \mu_{\text{real}} μreal,并在训练 G G G 的同时动态更新 μ fake \mu_{\text{fake}} μfake,使用从单步生成器采样的数据进行去噪分数匹配 [22, 46]。
Yin 等人 [22] 发现,为了减少分布匹配梯度(公式 (2))的方差并提高单步模型的质量,还需要一个额外的回归项 [16]。他们通过收集大量的噪声-图像对 ( z , y ) (z, y) (z,y) 进行训练,其中 y y y 是由教师扩散模型生成的图像,而 z z z 是通过确定性采样 [48, 49, 52] 从噪声映射到 y y y 的噪声输入。给定相同的输入噪声 z z z,回归损失可以迫使生成器的输出与教师预测值对齐:
L reg = E ( z , y ) ∼ D d ( G θ ( z ) , y ) \mathcal{L}_{\text{reg}} = \mathbb{E}_{(z,y) \sim D} d(G_\theta(z), y) Lreg=E(z,y)∼Dd(Gθ(z),y)
其中 d d d 是一个距离函数,例如 LPIPS [53]。尽管对于小数据集(如 CIFAR-10),收集此类数据集的成本可以忽略不计,但对于大规模文本到图像任务或需要提示调节的模型而言,这种成本可能成为训练瓶颈 [54–56]。例如,在 SDXL [57] 的实验中,生成一个噪声-图像对大约需要 5 秒,约 700 个 A100 GPU 天来构建 LAION 6.0 数据集 [58]。如 Yin 等人 [22] 所述,这种数据集构造成本已经比总训练计算量高出 4 倍以上(详情见附录 F)。该正则化目标虽然有助于 DMD 让学生模型在分布上接近教师模型,但它也会鼓励生成器在教师的采样路径上产生偏差。
4 Improved Distribution Matching Distillation
我们在DMD算法[22]中重新审视多个设计选择,并确定重大改进。
图 3:我们的方法将一个昂贵的扩散模型(灰色,右侧)蒸馏为一个单步或多步生成器(红色,左侧)。我们的训练过程交替进行两个步骤:
- 使用隐式分布匹配目标的梯度(红色箭头)和 GAN 损失(绿色)优化生成器。
- 训练一个分数函数(蓝色)来建模由生成器产生的“假”样本的分布,同时训练一个 GAN 鉴别器(绿色)来区分假样本和真实图像。
学生生成器可以是单步或多步模型,如图所示,其中包含一个中间步骤输入。
4.1 Removing the regression loss: true distribution matching and easier large-scale training
DMD [22] 中使用的回归损失 [16] 旨在确保模式覆盖性和训练稳定性,但正如我们在第 3 节讨论的那样,它使大规模蒸馏变得繁琐,并且与分布匹配的理念相悖,因此本质上限制了蒸馏生成器的性能,使其只能达到教师模型的水平。我们的第一个改进是去除这一损失。
4.2 Stabilizing pure distribution matching with a Two Time-scale Update Rule
直接去除 DMD 中的回归目标(如方程 (3) 所示)会导致训练不稳定,并显著降低质量(表 3)。例如,我们观察到生成样本的平均亮度及其他统计量波动显著,且无法收敛到稳定点(详见附录 C)。我们将这种不稳定性归因于伪扩散模型 μ fake \mu_{\text{fake}} μfake 中的近似误差,该模型无法准确跟踪伪分数,因为它是针对生成器的非平稳输出分布动态优化的。这导致了近似误差和偏差的生成器梯度([30] 也讨论了这一点)。我们采用 Heusel 等人 [59] 提出的两时间尺度更新规则来解决这一问题。具体来说,我们以不同的频率训练 μ fake \mu_{\text{fake}} μfake 和生成器 G G G,以确保 μ fake \mu_{\text{fake}} μfake 能够准确跟踪生成器的输出分布。我们发现,在每次生成器更新时使用 5 次伪分数更新,而不使用回归损失,能够提供良好的稳定性,并在 ImageNet 上的质量(表 3)与原始 DMD 相匹配,同时大幅加快收敛速度。进一步的分析包含在附录 C 中。
4.3 Surpassing the teacher model using a GAN loss and real data
我们的模型在无需昂贵数据集构建(表 3)的情况下,达到了与 DMD [22] 相当的训练稳定性和性能。然而,蒸馏生成器与教师扩散模型之间仍然存在性能差距。我们推测,这一差距可能归因于 DMD 中真实分数函数 μ real \mu_{\text{real}} μreal 的近似误差,这些误差会传播到生成器并导致次优结果。由于 DMD 的蒸馏模型从未在真实数据上训练,因此它无法从这些误差中恢复。
我们通过在训练管道中引入额外的 GAN 目标来解决这一问题,其中判别器被训练用于区分真实图像和生成器生成的图像。GAN 分类器使用真实数据进行训练,因此不受教师模型的限制,这使得学生生成器有可能在样本质量上超越教师模型。我们对 DMD 进行 GAN 分类器的集成采用了极简设计:我们在假扩散去噪器(见图 3)的瓶颈处添加了一个分类分支。
分类分支和 UNet 中的上游编码器特征通过最大化标准的非饱和 GAN 目标进行训练:
L GAN = E x ∼ p real , t ∼ [ 0 , T ] [ log D ( F ( x , t ) ) ] + E z ∼ p noise , t ∼ [ 0 , T ] [ − log ( D ( F ( G θ ( z ) , t ) ) ) ] , \mathcal{L}_{\text{GAN}} = \mathbb{E}_{x \sim p_{\text{real}}, t \sim [0, T]} [\log D(F(x, t))] + \mathbb{E}_{z \sim p_{\text{noise}}, t \sim [0, T]} [-\log (D(F(G_{\theta}(z), t)))], LGAN=Ex∼preal,t∼[0,T][logD(F(x,t))]+Ez∼pnoise,t∼[0,T][−log(D(F(Gθ(z),t)))],
(4)
其中, D D D 是判别器, F F F 是前向扩散过程(即第 3 节定义的噪声注入),其中噪声级别对应时间步 t t t。生成器 G G G 通过最小化该目标进行训练。我们的设计受到以往利用扩散模型作为判别器的相关工作的启发 [24, 25, 27]。需要注意的是,该 GAN 目标与分布匹配哲学更为一致,因为它不需要配对数据,并且独立于教师模型的采样轨迹。
4.4 Multi-step generator
通过所提出的改进,我们能够在 ImageNet 和 COCO 上匹配教师扩散模型的性能(见表 1 和表 5)。然而,我们发现,像 SDXL [57] 这样的大规模模型仍然难以蒸馏为一步生成器,这是因为模型容量有限,且优化空间复杂,难以直接学习从噪声到高度多样和精细图像的映射。这促使我们扩展 DMD 以支持多步采样。
我们在训练和推理期间固定一个包含 N N N 个时间步的预设调度 { t 1 , t 2 , … , t N } \{t_1, t_2, \dots, t_N\} {t1,t2,…,tN},确保训练和推理过程一致。在推理时,每个时间步交替进行去噪和噪声注入步骤,遵循一致性模型 [9] 以提高样本质量。具体而言,从高斯噪声 z 0 ∼ N ( 0 , I ) z_0 \sim \mathcal{N}(0, I) z0∼N(0,I) 开始,我们交替进行去噪更新 x ^ t i = G θ ( x t i , t i ) \hat{x}_{t_i} = G_{\theta}(x_{t_i}, t_i) x^ti=Gθ(xti,ti),以及前向扩散步骤 x t i + 1 = α t i + 1 x ^ t i + σ t i + 1 ϵ x_{t_{i+1}} = \alpha_{t_{i+1}} \hat{x}_{t_i} + \sigma_{t_{i+1}} \epsilon xti+1=αti+1x^ti+σti+1ϵ,其中 ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵ∼N(0,I),直至获得最终图像 x ^ t N \hat{x}_{t_N} x^tN。我们的 4 步模型采用如下调度:999, 749, 499, 249,对应于训练 1000 步的教师模型。
4.5 Multi-step generator simulation to avoid training/inference mismatch
以往的多步生成器通常被训练用于去噪真实的带噪图像 [23,24,27]。然而,在推理过程中,除了第一步从纯噪声开始外,生成器的输入来自先前生成器采样步骤 x ^ t i \hat{x}_{t_i} x^ti。这导致了训练-推理不匹配问题,从而影响生成质量(见图 4)。我们通过在训练过程中用当前学生生成器运行数步后生成的合成带噪图像 x t i x_{t_i} xti 替换真实带噪图像来解决这一问题,类似于我们的推理流程(§ 4.4)。这种方法是可行的,因为与教师扩散模型不同,我们的生成器只运行少量步数。随后,生成器对这些模拟图像进行去噪,并使用提出的损失函数对输出进行监督。使用合成带噪图像可以避免训练-推理不匹配,并提升整体性能(见 Sec. 5.3)。
图 4:大多数多步蒸馏方法在训练过程中(左)使用前向扩散来模拟中间步骤。这会导致训练时模型所见输入与推理时的输入不匹配。我们提出的解决方案(右)通过在训练过程中模拟推理时的反向过程来缓解这一问题。
一项同时进行的研究 Imagine Flash [60] 提出了一种类似的技术。他们的反向蒸馏算法与我们的动机相似,旨在通过在训练时使用学生生成的图像作为后续采样步骤的输入来减少训练和测试之间的差距。然而,他们并未完全解决不匹配问题,因为回归损失的教师模型仍然受到训练-测试差距的影响——它从未使用合成图像进行训练。这种误差会沿着采样路径累积。相比之下,我们的分布匹配损失不依赖于输入到学生模型的数据,从而缓解了这一问题。
4.6 Putting everything together
总的来说,我们的蒸馏方法取消了 DMD [22] 对预计算噪声-图像对的严格要求。此外,它进一步整合了 GAN 的优势,并支持多步生成器。如图 3 所示,从预训练的扩散模型开始,我们交替优化生成器 Gθ,以最小化原始分布匹配目标以及 GAN 目标,同时优化假分数估计器 μfake,使用假数据上的去噪分数匹配目标和 GAN 分类损失。为了确保假分数估计的准确性和稳定性,尽管其是在在线优化的,我们仍以比生成器更高的频率进行更新(5 次 vs. 1 次)。