文章目录
一、前言
冻结 Batch Normalization 层(BN 层)是深度学习训练过程中常用的一种技巧,特别是在迁移学习或 微调(fine-tuning) 时。冻结 BN 层意味着在训练过程中不更新 BN 层的参数(如均值、方差、缩放和偏移参数)。
二、BN 层的基本工作原理
Batch Normalization 的作用是对每个 mini-batch 内的数据进行归一化,使其均值接近 0,方差接近 1。它有以下参数:
-
可学习参数:
- γ(缩放参数)
- β(偏移参数)
-
统计参数:
- running_mean(滑动平均均值)
- running_var(滑动平均方差)
在训练过程中:
- BN 层会计算每个 mini-batch 的均值和方差,并更新全局的 running_mean 和 running_var。
- 同时,γ 和 β 也会参与梯度更新。
三、冻结 BN 层的作用
- 稳定训练过程
- 在小批量数据或微调阶段: BN 层对每个 batch 计算统计量(均值、方差),但如果 batch size 很小,这些统计量可能不稳定,导致模型训练波动较大。冻结 BN 层可以防止这种不稳定。
- 保持原有模型的统计信息
- 在迁移学习中: 预训练模型的 BN 层已经在大规模数据集上学习到了较好的 running_mean 和 running_var。冻结它们可以保留这些信息,避免在新的小数据集上被破坏。
- 减少计算开销
- 冻结 BN 层意味着在前向传播时,不需要重新计算新的均值和方差。这在训练时可以提高计算效率。
- 防止过拟合
- 在某些场景中,BN 层的动态调整可能导致模型对当前数据集过于敏感,冻结它们可以减少过拟合风险。
四、冻结 BN 层的方法
在 PyTorch 中,你可以通过以下方法冻结 BN 层:
def freeze_bn(model):
for layer in model.modules():
if isinstance(layer, nn.BatchNorm2d) or isinstance