1. Core Code Structure of Knowledge Distillation
Knowledge distillation consists of three main parts:

- Teacher model: usually a large model that has already been trained; its parameters are kept frozen.
- Student model: a lightweight architecture that learns the teacher's knowledge through distillation (a minimal definition sketch of both models follows this list).
- Distillation loss: combines the student's loss on the ground-truth labels with a loss for imitating the teacher's outputs.
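As a concrete reference for the first two roles, here is a minimal sketch of a frozen teacher and a lightweight student for a 10-class image-classification task; the class names `TinyTeacher` and `TinyStudent` and the layer sizes are illustrative assumptions, not a prescribed architecture. The distillation loss itself is defined in the next section.

```python
import torch
import torch.nn as nn

class TinyTeacher(nn.Module):
    """A 'large' convolutional classifier standing in for an already-trained teacher."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))

class TinyStudent(nn.Module):
    """A much lighter model that will be trained to mimic the teacher."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))

teacher_model = TinyTeacher()
student_model = TinyStudent()

# Freeze the teacher: eval mode, no gradient updates
teacher_model.eval()
for p in teacher_model.parameters():
    p.requires_grad_(False)
```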
2. Basic Distillation Code Example
```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.7):
    """
    Compute the distillation loss, consisting of:
    - soft-label KL divergence (student vs. teacher soft logits)
    - cross-entropy loss on the ground-truth labels

    Args:
        student_logits: student model outputs, before softmax
        teacher_logits: teacher model outputs, before softmax
        labels: ground-truth labels
        temperature: temperature used to soften the probability distributions
        alpha: weighting coefficient; larger alpha puts more emphasis on the distillation term

    Returns:
        the combined loss
    """
    # Softened probability distributions
    student_soft = F.log_softmax(student_logits / temperature, dim=1)
    teacher_soft = F.softmax(teacher_logits / temperature, dim=1)
    # KL-divergence (distillation) loss, scaled by T^2 to keep gradient magnitudes comparable
    kd_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (temperature ** 2)
    # Cross-entropy loss on the hard labels
    ce_loss = F.cross_entropy(student_logits, labels)
    # Weighted sum of the two terms
    loss = alpha * kd_loss + (1 - alpha) * ce_loss
    return loss
```
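A quick sanity check with random tensors; the batch size and class count below are assumptions for illustration.

```python
# Hypothetical shapes: batch of 8 samples, 10 classes
student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

loss = distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.7)
print(loss.item())  # a single scalar combining the KD and CE terms
```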
3. Training Loop Skeleton
```python
def train_epoch(student_model, teacher_model, dataloader, optimizer, device,
                temperature=4.0, alpha=0.7):
    student_model.train()
    teacher_model.eval()  # the teacher's parameters are not updated
    total_loss = 0
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        # Teacher forward pass with gradients disabled
        with torch.no_grad():
            teacher_logits = teacher_model(imgs)
        # Student forward pass
        student_logits = student_model(imgs)
        # Distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
    avg_loss = total_loss / len(dataloader.dataset)
    return avg_loss
```
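A minimal driver loop on top of `train_epoch`, assuming `teacher_model` and `student_model` are already-constructed `nn.Module` instances (for example the sketch in section 1) and that `train_loader` is a `DataLoader` yielding `(image, label)` batches; the optimizer settings and epoch count are placeholders to tune.

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teacher_model = teacher_model.to(device)
student_model = student_model.to(device)

# Only the student's parameters are optimized
optimizer = torch.optim.SGD(student_model.parameters(), lr=0.05,
                            momentum=0.9, weight_decay=5e-4)

for epoch in range(10):
    avg_loss = train_epoch(student_model, teacher_model, train_loader,
                           optimizer, device, temperature=4.0, alpha=0.7)
    print(f"epoch {epoch}: avg distillation loss = {avg_loss:.4f}")
```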
4. Code Examples for Other Distillation Methods
4.1 Feature Distillation
```python
def feature_distillation_loss(student_features, teacher_features):
    """
    L2 distance between student and teacher intermediate features as the
    distillation loss. Both arguments are lists of feature tensors whose
    shapes match pairwise.
    """
    loss = 0
    for sf, tf in zip(student_features, teacher_features):
        loss += F.mse_loss(sf, tf.detach())
    return loss
```
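A quick check with dummy feature lists; the shapes below are assumptions chosen to match pairwise (when the real channel counts differ, project the student features first, as described in section 7).

```python
# Two pairs of intermediate features with matching shapes
student_features = [torch.randn(8, 64, 16, 16), torch.randn(8, 128, 8, 8)]
teacher_features = [torch.randn(8, 64, 16, 16), torch.randn(8, 128, 8, 8)]

loss_feat = feature_distillation_loss(student_features, teacher_features)
print(loss_feat.item())
```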
4.2 Attention Distillation
```python
def attention_map(feature):
    # Attention map: square element-wise, then average over the channel dimension
    return feature.pow(2).mean(dim=1, keepdim=True)

def attention_distillation_loss(student_features, teacher_features):
    loss = 0
    for sf, tf in zip(student_features, teacher_features):
        student_attention = attention_map(sf)
        teacher_attention = attention_map(tf).detach()
        loss += F.mse_loss(student_attention, teacher_attention)
    return loss
```
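Because the channel dimension is averaged away, the student and teacher features only need matching spatial sizes here, not matching channel counts; a dummy check (the shapes are illustrative):

```python
# Different channel counts, same spatial size
student_features = [torch.randn(8, 16, 16, 16)]
teacher_features = [torch.randn(8, 64, 16, 16)]

loss_att = attention_distillation_loss(student_features, teacher_features)
print(loss_att.item())
```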
4.3 Multi-Teacher Distillation (Average Fusion)
```python
def multi_teacher_distillation_loss(student_logits, teacher_logits_list, temperature=4.0):
    # Average the teachers' logits, then distill from the averaged distribution
    teacher_avg = sum(teacher_logits_list) / len(teacher_logits_list)
    student_soft = F.log_softmax(student_logits / temperature, dim=1)
    teacher_soft = F.softmax(teacher_avg / temperature, dim=1)
    kd_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (temperature ** 2)
    return kd_loss
```
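A dummy check with a hypothetical ensemble of two teachers; note that this function returns only the soft-label KL term, so the hard-label cross-entropy is added separately as in section 6.

```python
teacher_logits_list = [torch.randn(8, 10), torch.randn(8, 10)]
student_logits = torch.randn(8, 10)

loss_kd = multi_teacher_distillation_loss(student_logits, teacher_logits_list, temperature=4.0)
print(loss_kd.item())
```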
5. Self-Distillation Example (Intermediate-Layer Feature Self-Distillation)
```python
def self_distillation_loss(student_features):
    # The deepest feature acts as the "teacher" for the shallower layers
    teacher_feat = student_features[-1].detach()
    loss = 0
    for feat in student_features[:-1]:
        loss += F.mse_loss(feat, teacher_feat)
    return loss
```
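Since `F.mse_loss` needs the shallow features to match the shape of the deepest one, in practice the intermediate features are usually projected to a common shape first (see the 1x1-convolution sketch after the list in section 7); a shape-matched dummy check:

```python
# Three intermediate features already projected to a common shape (an assumption)
student_features = [torch.randn(8, 128, 8, 8) for _ in range(3)]

loss_self = self_distillation_loss(student_features)
print(loss_self.item())
```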
6. Integrating Distillation into the Training Loop
Combine the ground-truth label loss with the distillation loss. The `distillation_loss` from section 2 already blends the two terms via `alpha`, so it can be used on its own; when the soft-label KD term is computed separately (for example `multi_teacher_distillation_loss`, which returns only the KL term), the combination is made explicit:

```python
loss_cls = F.cross_entropy(student_logits, labels)
loss_kd = multi_teacher_distillation_loss(student_logits, [teacher_logits], temperature)
loss = (1 - alpha) * loss_cls + alpha * loss_kd
```

Or as a multi-objective sum:

```python
loss = loss_cls + lambda_kd * loss_kd + lambda_feat * loss_feat + ...
```
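As a concrete illustration of the multi-objective form, the sketch below wraps the logit-level and feature-level terms into one function; `combined_distillation_loss` and `lambda_feat` are names introduced here for illustration, and the feature lists are assumed to be shape-matched already.

```python
def combined_distillation_loss(student_logits, teacher_logits, labels,
                               student_features, teacher_features,
                               temperature=4.0, alpha=0.7, lambda_feat=1.0):
    # Logit-level term: distillation_loss already mixes CE and soft-label KD via alpha
    loss_logits = distillation_loss(student_logits, teacher_logits, labels,
                                    temperature, alpha)
    # Feature-level term on shape-matched intermediate features
    loss_feat = feature_distillation_loss(student_features, teacher_features)
    return loss_logits + lambda_feat * loss_feat

# Dummy check with random tensors
student_logits, teacher_logits = torch.randn(8, 10), torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
feats_s = [torch.randn(8, 64, 16, 16)]
feats_t = [torch.randn(8, 64, 16, 16)]
loss = combined_distillation_loss(student_logits, teacher_logits, labels, feats_s, feats_t)
```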
7. Debugging and Practical Tips
- Temperature: values roughly in the 4 to 10 range usually work well; too high makes the soft labels nearly uniform, too low makes them behave like hard labels.
- Loss weights: the distillation weight `alpha` (or the `lambda` coefficients) needs to be tuned for each task.
- Feature matching: student and teacher intermediate features must have matching dimensions; when they do not, project the student features with a 1x1 convolution (see the sketch after this list).
- Freeze the teacher: always keep the teacher model in eval mode during training and never update its parameters.
- Batch size: use a moderate batch size for distillation; very small batches make the batch statistics unstable.
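A minimal sketch of the 1x1-convolution projection mentioned in the feature-matching tip; the class name `FeatureAdapter` and the channel counts are assumptions for illustration.

```python
import torch.nn as nn

class FeatureAdapter(nn.Module):
    """Projects student features to the teacher's channel count with a 1x1 convolution."""
    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        self.proj = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feature):
        return self.proj(student_feature)

# Hypothetical shapes: student feature has 16 channels, teacher feature has 64
adapter = FeatureAdapter(16, 64)
student_feature = torch.randn(8, 16, 16, 16)
teacher_feature = torch.randn(8, 64, 16, 16)

loss_feat = F.mse_loss(adapter(student_feature), teacher_feature.detach())
# The adapter's parameters must be added to the optimizer together with the student's.
```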
8. Summary
| Component | Role |
|---|---|
| Teacher model | Provides soft labels or intermediate features as the source of knowledge |
| Student model | Learns the teacher's knowledge, improving performance while reducing model size |
| Distillation loss | Combines the ground-truth labels with the teacher's soft labels to optimize the student |