知识提炼部分方法代码实现

1. 知识蒸馏的核心代码结构

知识蒸馏主要包含三部分:

  • 教师模型(Teacher):通常是已经训练好的大模型,保持参数不更新。

  • 学生模型(Student):结构较轻量,需要通过蒸馏学习教师的知识。

  • 蒸馏损失(Distillation Loss):融合了学生的真实标签损失和模仿教师输出的损失。

2. 基础蒸馏代码示例

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.7):
    """
    计算蒸馏损失,包括:
    - 软标签KL散度(学生 vs 教师的soft logits)
    - 真实标签交叉熵损失

    参数:
    - student_logits: 学生模型输出,未经过softmax
    - teacher_logits: 教师模型输出,未经过softmax
    - labels: 真实标签
    - temperature: 温度系数,用于软化概率分布
    - alpha: 两部分损失权重平衡系数,alpha越大,强调蒸馏损失

    返回:
    - 综合损失
    """

    # 软标签概率分布
    student_soft = F.log_softmax(student_logits / temperature, dim=1)
    teacher_soft = F.softmax(teacher_logits / temperature, dim=1)

    # KL散度损失(蒸馏损失)
    kd_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (temperature ** 2)

    # 真实标签的交叉熵损失
    ce_loss = F.cross_entropy(student_logits, labels)

    # 总损失加权求和
    loss = alpha * kd_loss + (1 - alpha) * ce_loss

    return loss

 3. 训练示例代码框架

def train_epoch(student_model, teacher_model, dataloader, optimizer, device,
                temperature=4.0, alpha=0.7):
    student_model.train()
    teacher_model.eval()  # 教师模型不更新参数

    total_loss = 0
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()

        # 教师模型预测,关闭梯度
        with torch.no_grad():
            teacher_logits = teacher_model(imgs)

        # 学生模型预测
        student_logits = student_model(imgs)

        # 计算蒸馏损失
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)

    avg_loss = total_loss / len(dataloader.dataset)
    return avg_loss

4. 其他蒸馏方法的代码示例

4.1 特征蒸馏(Feature Distillation)
def feature_distillation_loss(student_features, teacher_features):
    """
    计算学生和教师特征之间的L2距离作为蒸馏损失
    """

    loss = 0
    for sf, tf in zip(student_features, teacher_features):
        loss += F.mse_loss(sf, tf.detach())

    return loss
4.2 注意力蒸馏(Attention Distillation)
def attention_map(feature):
    # 计算注意力图,按通道平方后求和
    return feature.pow(2).mean(dim=1, keepdim=True)

def attention_distillation_loss(student_features, teacher_features):
    loss = 0
    for sf, tf in zip(student_features, teacher_features):
        student_attention = attention_map(sf)
        teacher_attention = attention_map(tf).detach()
        loss += F.mse_loss(student_attention, teacher_attention)
    return loss
4.3 多教师蒸馏示例(平均融合)
def multi_teacher_distillation_loss(student_logits, teacher_logits_list, temperature=4.0):
    teacher_avg = sum(teacher_logits_list) / len(teacher_logits_list)

    student_soft = F.log_softmax(student_logits / temperature, dim=1)
    teacher_soft = F.softmax(teacher_avg / temperature, dim=1)

    kd_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (temperature ** 2)

    return kd_loss

5. 自蒸馏示例(中间层特征自蒸馏)

def self_distillation_loss(student_features):
    teacher_feat = student_features[-1].detach()
    loss = 0
    for feat in student_features[:-1]:
        loss += F.mse_loss(feat, teacher_feat)
    return loss

6. 训练过程集成蒸馏

将真实标签损失和蒸馏损失结合:

loss_cls = F.cross_entropy(student_logits, labels)
loss_kd = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)
loss = loss_cls * (1 - alpha) + loss_kd * alpha

或者多目标融合:

loss = loss_cls + lambda_kd * loss_kd + lambda_feat * loss_feat + ...

7. 代码调试和实用建议

  • 温度参数:一般4~10范围效果好,过大会使soft标签过于均匀,过小则接近硬标签。

  • 权重参数:蒸馏损失权重alpha或lambda需要根据任务调试。

  • 特征匹配:学生和教师中间层特征维度要匹配,必要时用1×1卷积做变换。

  • 冻结教师:训练过程中一定要设置教师模型eval且不更新参数。

  • 批量大小:蒸馏时建议批量大小适中,过小会影响统计稳定性。

8. 总结

功能模块作用
教师模型提供软标签或中间特征作为知识源
学生模型学习教师知识,提升性能与压缩模型大小
蒸馏损失将真实标签和教师软标签结合,优化学生

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值