knowledge distillation论文阅读之:Learning from a Lightweight Teacher for Efficient Knowledge Distillation

本文介绍论文核心思路,涵盖经典KD、无教师蒸馏(TF - KD)和轻量级知识蒸馏(LW - KD)。LW - KD基于MNIST设计合成数据集,训练轻量级教师网络,结合改进的KD损失函数提升学生网络表现。详细阐述了合成数据集生成、软标签生成、损失函数改进及算法实现。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

论文核心思路:

1. 经典的 KD

传统的经典 KD 方法试图让 student 网络的 soft target 尽可能地接近 teacher 网络产生的 soft target;这是通过在KD学习框架添加另一个损失函数实现的,这个新的损失函数用以补充用标准交叉熵损失测量 groundtruth 和预测值之间产生的差距,公式(1)。
在这里插入图片描述
在这里插入图片描述

  • KL(·) 代表 KL 散度
  • ∣ X t r ∣ |X^{tr}| Xtr 代表的是在数据集 ∣ X t r ∣ |X^{tr}| Xtr 中的实例个数
  • p τ s p^s_τ pτs p τ t p^t_τ pτt 代表的分别是 student 网络和 teacher 网络产生的 soft target

软输出(soft output)决定于 logits 层的情况(logits被认为是输入 softmax 前的数据,是一个未归一化的概率分布)。以 ptt 为例,展示其计算的过程

在这里插入图片描述

  • z k t z_k^t zkt 是 teacher 网络 logits 层的输出,
  • k k k 是 teacher 网络分类种类的数量
  • τ τ τ 被称为是温度系数,是一个超参数,用来控制 logits 的范围(scale)

2. Teacher-free distillation(TF-KD)

在 teacher-free(TF)网络中,新奇的点是: p τ t p^t_τ pτt 不基于表现优秀的 teacher 网络来获得,而是通过人为制作的概率分布而学得(通过人为设计的概率分布来代替 teacher network 产生的 soft target)

3. Lightweight knowledge distillation(LW-KD)

在这里插入图片描述
这个图的结构表示了 LW-KD 的整个思路

  • 首先,LW-KD 为了实验,基于 MNIST 手写数据集设计了一个专用的合成数据集 Synthetic MNIST
  • 通过 synthetic MNIST 数据集来训练一个轻量级(lightweight)的 teacher network ,产生的 teacher soft output 结合改进过的 KD loss function 来提升 student network 的表现
  • 在这种实验的情况下,student network 可以提升其表现

下面的论文中,我们详细地描述了如何制定新的 LW-KD 的 loss function。

  • 我们用 T ( x ; θ t ) T(x;θ^t) T(x;θt) 来表示带有可学习参数 θ t θ^t θt 的 teacher network
  • S ( x ; θ s ) S(x;θ^s) S(x;θs) 来表示有可学习参数 θ s θ^s θs 的 student network

3.1 生成 synthetic MNIST 合成数据集

  • 我们基于 MNIST 数据集合成了一个 synthetic MNIST 用于训练 lightweight 的 teacher network
  • 本文提供了详细地产生 synthetic MNIST 数据集的算法:
    在这里插入图片描述
  • 对于训练student network S S S 所使用的合成数据集,算法只需要知道其总类数 K K K 以及每个类的实例数。
  • 该算法的创新之处在于将不同的基本数字图像结合在一起,合成出对应较大数值的新图像,每个数值都可以表示一个特定的类;例如,如果 07,30,41 这种数字代表的就是 最多 100 类(0-99),而 010,100,379 这种合成数据集可以代表 1000 个类(0-999);
  • 因此,这个算法可以支持不同数量的类

3.2 生成 soft target 软标签

  • 给定合成数据集 ( X s y n , C s y n ) (X^{syn},C^{syn}) (Xsyn,Csyn)
  • 利用这个数据集训练表现良好的 teacher network T T T,teacher network 根据合成数据集产生的软标签为 p τ l t ( k ∣ x s y n ) , k ∈ { 1 , . . . K } p_τ^{lt}(k|x^{syn}), k∈\{1,...K\} pτlt(kxsyn),k{1,...K}, x s y n x^{syn} xsyn 是 合成数据集中的实例
  • LW-KD 的主要目的是实现 teacher network 产生的概率分布(teacher-soft target)的迁移,来实现对 student 网络性能的提升
  • 虽然如此,但是有个重要的问题:teacher network 和 student network 的训练数据集有很大的差异,所以他们所分得类不能够完全对齐(一致)
  • 在这种情况下,我们下意识的反应就是不能实现高效的知识迁移。然而事实并非如此,正如我们在论文中一再强调的那样,vanilla KD 扮演了 label smoothing (标签平滑)的功效,而 LW-KD 的作用就是利用 teacher network 产生的灵活的 类分布(class distribution)来平滑标签。因此,LW-KD 并不需要两个数据集严格的类一致(strict segmantic alignment between classes of two datasets)
  • 唯一对 soft target 进行的改进是 结合 soft target 中 最大的概率预测值目标数据集样本 x t r x^{tr} xtr 的 groundtruth 标签,保持其他的概率预测值不变。
  • 最后,我们使用以下方式来进行 soft target 的改进
    在这里插入图片描述
  • c c c 是 目标数据集实例 x t r x^{tr} xtr 的 groundtruth 标签类
  • m m m 是 teacher network 通过训练之后对于一个实例的预测类; m m m 求得的过程通过公式 m = a r g m a x k ( p τ l t ( k ∣ x s y n ) ) m=argmax_k(p_τ^{lt}(k|x^{syn})) m=argmaxk(pτlt(kxsyn))
  • shift 操作代表的是互换两个部分的值( p τ l t ( c ∣ x s y n ) p_τ^{lt}(c|x^{syn}) pτlt(cxsyn) 和 ( p τ l t ( m ∣ x s y n ) p_τ^{lt}(m|x^{syn}) pτlt(mxsyn)))
  • 通过以上方式,teacher network 生成的 soft target 对于目标实例 x s y n x^{syn} xsyn 来说就有了一定的现实意义。

3.3 改进 KD loss function:enhanced L K D L_{KD} LKD + L G A N L_{GAN} LGAN

3.3.1 enhanced L K D L_{KD} LKD
  • teacher 网络 T T T 已经产生了软标签 soft target
  • 我们可以采用标准的 KD 损失函数,它有(公式1 )所示的 KL 散度组成的损失、交叉熵损失组合而成:

在这里插入图片描述

  • p p p 是 ground_truth 的类分布(class distribution)
  • α α α 是超参数,用来控制这两种损失的比例

所以,通过整个损失函数对学生网络进行训练的过程可以如下理解:

第一部分: ( 1 − α ) H ( p , p s ) (1-α)H(p,p^s) (1α)H(p,ps)

  • p p p 是 ground_truth 的类分布(class distribution), p s p^s ps 代表的是训练过后的 student 网络对于类的预测概率分布, p p p p s p^s ps 用 cross-entropy 来训练,使得 p s p^s ps 可以尽可能地接近 p p p

第二部分: α L K L ( p ^ τ l t , p τ s ) αL_{KL}(\hat{p}^{lt}_τ,p^s_τ) αLKL(p^τlt,pτs)

  • p τ s p^s_τ pτs 代表 student 网络的 soft target 分布
  • p ^ τ l t \hat{p}^{lt}_τ p^τlt 代表 teacher 网络的 soft target 分布(根据 3.2 中的方式对 teacher 产生的软标签优化后的结果)
  • 通过 teacher 和 student 网络 soft target 的 KL 散度来使 student 网络的表现越来越像 teacher network
3.3.2 L G A N → L A D V L_{GAN}→L_{ADV} LGANLADV

因为我们可以把这个过程看做一个标签平滑的过程;LW-KD更进一步,通过有效生成对抗网络(GANs)实现:使student network 生成的 soft class distribution 与 teacher network 生成的 soft class distribution 无区别。

GAN 的核心操作:

  • 一方面,给定一个噪声向量 z z z ,通过 生成器 G G G z z z 映射到所需数据 x x x 的分布 G : z → x G: z→x G:zx
  • 另一方面,鉴别器 D D D 输出一个 x x x 实例是真实数据的概率 x → [ 0 , 1 ] x→[0,1] x[0,1]
  • GAN 的核心损失函数如下:
    在这里插入图片描述

其中,生成器 G G G 是根据鉴别器 D D D 反向传播的策略来调整,优化生成器的过程为:
在这里插入图片描述
对于 GAN 网络损失函数的理解,可以参考博文:GAN网络的损失函数

结合上述公式,我们把 teacher 网络获得的 soft target 的分布认为是真实数据,即 GAN 损失函数中 y y y 的位置,把 student 网络获得的 soft target 分布认为是 fake 的数据,即 GAN 损失函数中 z z z 的位置。采用两层全连通神经网络作为鉴别器 D D D。所以使用 GAN 定义的损失函数如下:
在这里插入图片描述
至此,我们把整个 LW-KD 损失函数定义如下:
在这里插入图片描述

  • β β β 是用来平衡 KD loss 和 GAN loss 的超参数

3.4 实现 LW-KD 的算法

算法2总结了LW-KD的整体学习过程,如第 6 - 13 行所示,将 teacher
network 在合成数据集上学习到的知识迁移到目标数据集上指导 student network 学习。
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

暖仔会飞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值