Hung-yi Lee Machine Learning 2023 - HW13 - Network Compression

Task

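Use network compression to do image classification on the same 11-class food dataset as HW3. The task is to design a small student model with fewer than 60k parameters and train it to approach the teacher model's accuracy (test accuracy ≈ 0.902).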
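The 60k limit is easy to check before training. The helper below is a minimal sketch; `count_parameters` and `toy_student` are hypothetical names used only for illustration and are not part of the homework's sample code.

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters in `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Hypothetical toy student, used only to exercise the helper.
toy_student = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # 3*16*9 weights + 16 biases = 448
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 11),                           # 16*11 weights + 11 biases = 187
)

n_params = count_parameters(toy_student)
print(n_params)
assert n_params < 60_000, "student model exceeds the parameter budget"
```

Running the same check on the real student model after every architecture change catches a budget overrun early.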

Baseline

Simple Baseline

Just run the sample code

Medium Baseline

The loss function is defined with KL divergence as follows:
$$Loss = \alpha T^2 \times KL(p \,\|\, q) + (1-\alpha)(\text{Original Cross Entropy Loss}), \quad \text{where } p = softmax\left(\frac{\text{student's logits}}{T}\right), \text{ and } q = softmax\left(\frac{\text{teacher's logits}}{T}\right)$$
The number of epochs can also be raised to 50, and the other places marked #medium need to be modified as well.

# medium
import torch
import torch.nn as nn
import torch.nn.functional as F

def loss_fn_kd(student_logits, labels, teacher_logits, alpha=0.5, temperature=5.0):
    # ------------TODO-------------
    # Refer to the above formula and finish the loss function for knowledge distillation using KL divergence loss and CE loss.
    # If you have no idea, please take a look at the provided useful link above.
    # Soften both outputs with the temperature; the student side is kept as log-probabilities.
    student_log_prob = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_prob = F.softmax(teacher_logits / temperature, dim=-1)
    # F.kl_div takes log-probabilities as input and probabilities as target;
    # "batchmean" sums over classes and averages over the batch.
    KL_loss = F.kl_div(student_log_prob, teacher_prob, reduction="batchmean")
    CE_loss = F.cross_entropy(student_logits, labels)
    loss = alpha * temperature**2 * KL_loss + (1 - alpha) * CE_loss
    return loss
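To see what the temperature does to the soft targets, here is a small check on made-up logits. It is only a sketch: `soft_kl` is a hypothetical helper written for this example, and the logit values are arbitrary.

```python
import torch
import torch.nn.functional as F

def soft_kl(student_logits, teacher_logits, temperature):
    """KL term between temperature-softened teacher and student outputs."""
    student_log_prob = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_prob = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(student_log_prob, teacher_prob, reduction="batchmean")

# Arbitrary logits for one sample with three classes.
teacher_logits = torch.tensor([[4.0, 1.0, 0.0]])
student_logits = torch.tensor([[1.0, 2.0, 0.5]])

# A higher temperature flattens the teacher's distribution.
sharp = F.softmax(teacher_logits / 1.0, dim=-1)
soft = F.softmax(teacher_logits / 5.0, dim=-1)
print(sharp.max().item(), soft.max().item())

# The KL term vanishes when the student matches the teacher and is positive otherwise.
kl_same = soft_kl(teacher_logits, teacher_logits, temperature=5.0)
kl_diff = soft_kl(student_logits, teacher_logits, temperature=5.0)
print(kl_same.item(), kl_diff.item())
```

The softened distribution exposes the teacher's relative ranking of the wrong classes, which is the extra signal the student learns from. The T^2 factor in the loss compensates for the gradients of the soft term shrinking roughly as 1/T^2.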

Strong Baseline

Rebuild the model architecture with depthwise and pointwise convolutions, and train for more epochs.

Intermediate-layer feature learning can also be applied; I did not do that here.

# strong
def dwpw_conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, groups=in_channels), #depthwise convolution
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels, out_channels, 1), # pointwise convolution
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )
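The reason this block helps with the parameter limit can be checked by counting weights. The comparison below is a sketch with arbitrarily chosen channel sizes (64 in, 128 out, 3x3 kernel); BatchNorm layers are left out to keep the arithmetic simple.

```python
import torch
import torch.nn as nn

def count_parameters(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

in_ch, out_ch, k = 64, 128, 3  # arbitrary sizes for illustration

# Standard convolution: every output channel looks at every input channel.
standard = nn.Conv2d(in_ch, out_ch, k, padding=1)

# Depthwise (one k x k filter per input channel) followed by pointwise (1x1 mixing).
separable = nn.Sequential(
    nn.Conv2d(in_ch, in_ch, k, padding=1, groups=in_ch),
    nn.Conv2d(in_ch, out_ch, 1),
)

standard_params = count_parameters(standard)    # 64*128*9 + 128
separable_params = count_parameters(separable)  # (64*9 + 64) + (64*128 + 128)
print(standard_params, separable_params)

# Both variants map the same input to the same output shape.
x = torch.randn(1, in_ch, 32, 32)
assert standard(x).shape == separable(x).shape
```

With these sizes the separable version needs roughly an eighth of the parameters, which leaves room for more layers or wider channels within the same limit.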

Boss Baseline

Other advanced knowledge distillation methods (FitNet/RKD/DM) + more epochs + depthwise and pointwise conv layers (depthwise separable convolution)

Intermediate-layer feature learning can of course be applied as well.
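Intermediate-layer feature learning needs access to hidden activations. One common way to get them in PyTorch is a forward hook. The snippet below is a minimal sketch on a hypothetical toy model; the layer to hook on the real teacher and student depends on their architectures.

```python
import torch
import torch.nn as nn

features = {}

def save_output(name):
    """Build a hook that stores the hooked layer's output under `name`."""
    def hook(module, inputs, output):
        features[name] = output
    return hook

# Hypothetical toy model, for illustration only.
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 11),
)

# Capture the activation after the ReLU.
handle = toy_model[1].register_forward_hook(save_output("hidden"))
logits = toy_model(torch.randn(2, 3, 32, 32))
handle.remove()  # detach the hook once it is no longer needed

print(features["hidden"].shape, logits.shape)
```

A single forward pass then yields both the logits for the classification loss and the hidden feature map for a feature-based distillation loss.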

FitNet Knowledge Distillation

FitNet focuses on transferring knowledge from intermediate feature representations (hidden layers) instead of just using the output logits. The student model is trained to mimic the feature maps from certain layers of the teacher model.

# boss
def loss_fn_fitnet(teacher_feature, student_feature, student_logits, labels, alpha=0.5):
    """
    FitNet Knowledge Distillation Loss Function.

    Args:
    - teacher_feature: The feature maps from a hidden layer of the teacher model.
    - student_feature: The feature maps from the corresponding hidden layer of the student model
      (must have the same shape as teacher_feature).
    - student_logits: The student's classification logits, shape (batch, num_classes).
    - labels: Ground truth labels for the task.
    - alpha: Weighting factor for the hard-label loss; (1 - alpha) weights the feature distillation loss.

    Returns:
    - loss: Combined loss with cross-entropy and feature map alignment.
    """
    # Mean squared error loss to align feature maps of teacher and student
    feature_loss = F.mse_loss(student_feature, teacher_feature)

    # Hard label cross-entropy loss on the student's logits;
    # cross-entropy expects class logits, not hidden feature maps
    hard_loss = F.cross_entropy(student_logits, labels)

    # Combine both losses
    loss = alpha * hard_loss + (1 - alpha) * feature_loss
    return loss
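F.mse_loss needs both feature maps to have the same shape, which a student under 60k parameters rarely shares with its teacher. FitNet-style training usually adds a small learned regressor on the student side to bridge the gap. The module below is a sketch under assumptions: `FeatureRegressor` is a hypothetical name, the 1x1 convolution matches channels, and the bilinear resize for mismatched spatial sizes is a choice made here, not something the homework prescribes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureRegressor(nn.Module):
    """Project student feature maps to the teacher's shape and compare them."""

    def __init__(self, student_channels: int, teacher_channels: int):
        super().__init__()
        # 1x1 convolution that maps student channels to teacher channels.
        self.proj = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feature, teacher_feature):
        projected = self.proj(student_feature)
        # Assumption: resize when the spatial sizes differ.
        if projected.shape[-2:] != teacher_feature.shape[-2:]:
            projected = F.interpolate(
                projected, size=teacher_feature.shape[-2:],
                mode="bilinear", align_corners=False,
            )
        # The teacher is only a target, so no gradient flows into it.
        return F.mse_loss(projected, teacher_feature.detach())

# Arbitrary shapes for illustration.
student_feature = torch.randn(4, 16, 8, 8)
teacher_feature = torch.randn(4, 64, 16, 16)
regressor = FeatureRegressor(student_channels=16, teacher_channels=64)
hint_loss = regressor(student_feature, teacher_feature)
print(hint_loss.item())
```

The regressor's parameters are trained together with the student and discarded afterwards, so they would not count toward the student's parameter budget at inference time.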

Relational Knowledge Distillation (RKD)

Relational Knowledge Distillation focuses on transferring the relationships (distances and angles) between data samples as learned by the teacher. The student model is trained to match these relationships instead of just focusing on output probabilities.

# boss
def pairwise_distance(x):
    """Calculate pairwise distance between batch samples."""
    return torch.cdist(x, x, p=2)

def angle_between_pairs(x):
    """Calculate angles between all pairs of points in batch."""
    diff = x.unsqueeze(1) - x.unsqueeze(0)
    norm = diff.norm(dim=-1, p=2, keepdim=True)
    normalized_diff = diff / (norm + 1e-8)
    angles = torch.bmm(normalized_diff, normalized_diff.transpose(1, 2))
    return angles

def loss_fn_rkd(teacher_feature, student_feature, student_logits, labels, alpha=0.5):
    """
    Relational Knowledge Distillation Loss Function.

    Args:
    - teacher_feature: Teacher model feature embeddings, shape (batch, teacher_dim).
    - student_feature: Student model feature embeddings, shape (batch, student_dim).
    - student_logits: The student's classification logits, shape (batch, num_classes).
    - labels: Ground truth labels.
    - alpha: Weighting factor for the hard-label loss; (1 - alpha) weights the relational distillation loss.

    Returns:
    - loss: Combined relational knowledge and hard label loss.
    """

    # Pairwise distances between samples, computed separately in the teacher and student feature spaces
    teacher_dist = pairwise_distance(teacher_feature)
    student_dist = pairwise_distance(student_feature)

    # Distance loss: MSE between the two pairwise distance matrices
    distance_loss = F.mse_loss(student_dist, teacher_dist)

    # Angle loss: MSE between the angles formed by sample triplets in each feature space
    teacher_angle = angle_between_pairs(teacher_feature)
    student_angle = angle_between_pairs(student_feature)
    angle_loss = F.mse_loss(student_angle, teacher_angle)

    # Hard label cross-entropy loss on the student's logits;
    # cross-entropy expects class logits, not feature embeddings
    hard_loss = F.cross_entropy(student_logits, labels)

    # Combine the losses
    loss = alpha * hard_loss + (1 - alpha) * (distance_loss + angle_loss)
    return loss
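Raw distances in the teacher's and student's feature spaces can live on very different scales, which makes a plain MSE between them hard to balance. A common variant divides each distance matrix by its mean and uses a smooth L1 (Huber) loss. The sketch below shows that variant for the distance term only; the function names are hypothetical, and it assumes a batch of at least two samples.

```python
import torch
import torch.nn.functional as F

def normalized_pairwise_distance(x):
    """Pairwise Euclidean distances divided by their mean over distinct pairs."""
    dist = torch.cdist(x, x, p=2)
    off_diagonal = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
    return dist / (dist[off_diagonal].mean() + 1e-8)

def rkd_distance_loss(student_embedding, teacher_embedding):
    """Smooth L1 loss between scale-normalized distance matrices."""
    with torch.no_grad():  # the teacher only provides targets
        teacher_dist = normalized_pairwise_distance(teacher_embedding)
    student_dist = normalized_pairwise_distance(student_embedding)
    return F.smooth_l1_loss(student_dist, teacher_dist)

torch.manual_seed(0)
teacher_embedding = torch.randn(8, 512)  # arbitrary sizes for illustration
student_embedding = torch.randn(8, 64)   # embedding widths may differ

loss = rkd_distance_loss(student_embedding, teacher_embedding)

# A rescaled copy of the teacher has the same relational structure, so its loss is about 0.
scaled_copy_loss = rkd_distance_loss(3.0 * teacher_embedding, teacher_embedding)
print(loss.item(), scaled_copy_loss.item())
```

Because only the batch-by-batch distance matrices are compared, the teacher and student embeddings do not need the same width, which suits a very small student.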

Distance Metric (DM) Knowledge Distillation

Distance Metric distillation focuses on transferring the distance metric (such as Euclidean distance or cosine similarity) between instances in the teacher’s feature space to the student model.

def loss_fn_dm(teacher_feature, student_feature, student_logits, labels, alpha=0.5):
    """
    Distance Metric (DM) Knowledge Distillation Loss Function.

    Args:
    - teacher_feature: The feature representations from the teacher model.
    - student_feature: The feature representations from the student model.
    - student_logits: The student's classification logits, shape (batch, num_classes).
    - labels: Ground truth labels for the task.
    - alpha: Weighting factor for the hard-label loss; (1 - alpha) weights the distance metric loss.

    Returns:
    - loss: Combined distance metric loss and cross-entropy loss.
    """
    # Pairwise distances between samples, computed separately for teacher and student embeddings
    teacher_dist = pairwise_distance(teacher_feature)
    student_dist = pairwise_distance(student_feature)

    # Distance metric loss using Mean Squared Error (MSE) loss
    dist_loss = F.mse_loss(student_dist, teacher_dist)

    # Hard label cross-entropy loss on the student's logits;
    # cross-entropy expects class logits, not feature embeddings
    hard_loss = F.cross_entropy(student_logits, labels)

    # Combine the losses
    loss = alpha * hard_loss + (1 - alpha) * dist_loss
    return loss
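The description above also mentions cosine similarity as a possible metric. A sketch of that variant follows; `cosine_similarity_matrix` and `cosine_relation_loss` are hypothetical names, and the embedding sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

def cosine_similarity_matrix(x):
    """Cosine similarity between every pair of samples in the batch."""
    x = F.normalize(x, p=2, dim=-1)  # unit-length rows
    return x @ x.t()                 # (batch, batch) similarity matrix

def cosine_relation_loss(student_embedding, teacher_embedding):
    """MSE between the student's and teacher's similarity matrices."""
    student_sim = cosine_similarity_matrix(student_embedding)
    teacher_sim = cosine_similarity_matrix(teacher_embedding)
    return F.mse_loss(student_sim, teacher_sim.detach())

torch.manual_seed(0)
teacher_embedding = torch.randn(8, 512)
student_embedding = torch.randn(8, 64)

sim = cosine_similarity_matrix(teacher_embedding)
loss = cosine_relation_loss(student_embedding, teacher_embedding)

# Cosine similarity ignores vector length, so a rescaled copy gives a loss of about 0.
scaled_copy_loss = cosine_relation_loss(2.0 * teacher_embedding, teacher_embedding)
print(loss.item(), scaled_copy_loss.item())
```

Cosine similarity is bounded in [-1, 1], so this term needs no extra normalization to stay on a scale comparable with the cross-entropy loss.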

Code Link

kaggle
The full code is on GitHub
