GAN网络的模型坍塌和不稳定的分析

本文深入探讨了生成对抗网络(GAN)的两大难题:模式坍塌和训练不稳定,解析了其背后的数学原理,包括KL和JS散度的作用,以及Wasserstein GAN如何解决梯度消失问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

众所周知,GAN异常强大,同时也非常难以训练。主要有以下亮点原因:

  • 模型坍塌(mode collapse)
  • 难以收敛和训练不稳定(convergence and instability)
    GAN网络的一般表达式可以表示为:
    利用minmax获得公式1
    原始公式
    给定G,求D的最优化2
    损失函数
    针对D进行求导:
    gan3
    获取最优解:
    gan4
    最优解结果:
    D最优解
    原有公式
    KL和JS散度表达式:
    KL和JS散度表达式
    GAN表达式:
    在这里插入图片描述
    增加一项,表示生成网络G的损失函数:
    G
    简化:
    简化
    结合公式(6)和公式(8)可以得出:
    生成模型的结果
    公式(12)中的KL散度使得两个分布尽可能的小,而JS的负号使得两个分布近可能的大。
    两种情况:
    在这里插入图片描述
    第一种情况,生成了不真实的样本,惩罚很大;第二种情况,未能产生真实的样本,惩罚很小。第一种生成的样本不准确,第二种生成的样本不够多样。基于这个原理,G 倾向于生成重复但是安全的样本,而不愿意冒险生成不同但不安全的样本,这会导致模式坍塌(mode collapse)问题。
    当然很有可能出现两个分布不相关的情况3,这种情况就会导致称为一个常数,也就出现了梯度消失的情况。 所以判别器训练得太好,生成器梯度消失,生成器loss降不下去;判别器训练得不好,生成器梯度不准,四处乱跑。只有判别器训练得不好不坏才行,但是这个火候又很难把握。
    对于另外一种G网络的表现形式,可以总结为:D表现越好,G的梯度消失越严重。

  1. Generative Adversarial Networks (GANs): Challenges, Solutions, and Future Directions ↩︎

  2. A Review on Generative Adversarial Networks: Algorithms, Theory, and Applications ↩︎

  3. 令人拍案叫绝的Wasserstein GAN/https://zhuanlan.zhihu.com/p/25071913 ↩︎

### Wasserstein GAN (WGAN) 网络模型介绍及原理 #### 一、背景引入 传统生成对抗网络GANs)在训练过程中面临诸多挑战,诸如模式崩溃不稳定收敛等问题。为了克服这些局限性,研究者们提出了多种改进方案,其中一种便是基于地球移动距离(Earth Mover's Distance, EMD),亦称为Wasserstein距离的新型GAN架构——Wasserstein GAN (WGAN)[^1]。 #### 二、核心概念解析 WGAN的核心在于采用不同的损失函数来替代原始GAN中的JS散度。具体来说,在标准GAN框架下,判别器试图最大化真实样本得分而最小化伪造样本得分;而在WGAN里,则通过优化两个分布间的EMD实现这一目的。这种改变不仅使得训练过程更加稳定,而且有效缓解了模式坍塌现象的发生概率[^2]。 #### 三、数学基础阐述 设\( P_r \)表示来自实际数据集的概率分布,\( P_g \)代表由生成器产生的假样本所遵循的概率分布。那么两者之间的Wasserstein距离可以定义为: \[ W(P_r,P_g)=\inf_{\gamma\in\Pi(P_r,P_g)}\mathbb{E}_{(x,y)\sim\gamma}[d(x,y)] \] 这里 \( d(\cdot,\cdot) \) 是成本函数,默认情况下可取欧几里得距离;\( \Pi(P_r,P_g) \) 表示所有可能联合分布集合,满足边缘条件分别为 \( P_r \),\( P_g \) 。然而直接求解上述表达式较为困难,因此实践中通常借助Kantorovich-Rubinstein对偶定理将其转化为更易于处理的形式: \[ W(P_r,P_g)=\sup_{||f||_L≤1}\left[\mathbb{E}_{x∼Pr}f(x)-\mathbb{E}_{y∼Pg}f(y)\right]\] 此处 \( f \) 称作Critic 函数,并且要求其 Lipschitz 常数不超过1。这也就意味着我们需要找到一个合适的 Critic 来逼近最优解,从而间接估计两者的Wasserstein距离[^3]。 #### 四、算法流程概述 - **初始化**:设定超参数并随机初始化权重; - **迭代更新**: - 对于每一轮epoch, - 训练critic T次,每次采样一批真实图片以及对应的噪声向量作为输入给到critic网络中前向传播得到评分值,接着利用反向传播调整参数以减小loss; - 更新一次generator,同样地先产生一批伪图像再送入已经训练好的critic获取反馈信号指导自身学习方向直至达到预设轮次数目停止循环。 值得注意的是,在此期间还需施加一定的约束措施确保 critic 的 lipschitz 属性得以保持,比如裁剪权值范围或者加入梯度惩罚项等方式均可达成该目标[^4]。 ```python import torch.nn as nn class Generator(nn.Module): def __init__(self): super().__init__() self.main = ... # 定义生成器结构 def forward(self, input): return self.main(input) class Discriminator(nn.Module): def __init__(self): super().__init__() self.main = ... # 定义鉴别器(Critic)结构 def forward(self, input): return self.main(input).view(-1) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值