一、技术原理(数学公式+示意图)
1.1 核心数学公式
温度缩放(Temperature Scaling):
软目标概率计算:
[ q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)} ]
其中:
- ( z_i ):类别i的logits输出
- ( T ):温度系数(( T > 1 )时概率分布更平滑)
损失函数:
[ \mathcal{L} = \alpha \cdot \mathcal{L}{KD}(q^T, p^T) + (1-\alpha) \cdot \mathcal{L}{CE}(y, p) ]
- ( \mathcal{L}_{KD} ):KL散度损失(教师vs学生)
- ( \mathcal{L}_{CE} ):交叉熵损失(学生vs真实标签)
- ( \alpha ):蒸馏损失权重(通常0.5-0.9)
案例对比:
在CIFAR-100分类任务中,当T=1时,教师模型对"狗"类别的概率为0.9,其他类别接近0;当T=5时,"狗"概率降为0.4,"猫"和"狼"分别提升到0.3和0.2,保留类别间关联性。
二、实现方法(PyTorch/TensorFlow代码片段)
2.1 PyTorch实现
# 教师模型推理(高温T=5)
teacher_logits = teacher_model(inputs)
soft_targets = torch.nn.functional.softmax(teacher_logits / T, dim=-1)
# 学生模型训练
student_logits = student_model(inputs)
loss_kd = nn.KLDivLoss(reduction='batchmean')