Iterative pruning techniques
Here, we’ll look at iterative pruning, which removes a small fraction of the weights at a time over multiple rounds of training rather than all at once. Because each round removes only a few weights, this method reduces the risk of a drastic drop in performance and gives the model more opportunities to recover and adapt between pruning rounds.
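To see how small pruning steps accumulate, the following framework-free sketch (plain Python, no PyTorch) applies magnitude pruning to a hypothetical layer of 100 weights over five rounds. It assumes that each round removes 10% of the weights that are *still unpruned* and picks those with the smallest absolute value. The helper names `prune_smallest` and `sparsity` are illustrative, not from any library.

```python
import random


def prune_smallest(weights, mask, fraction):
    """Return a new mask with `fraction` of the surviving weights (smallest |w|) set to 0."""
    alive = [i for i, keep in enumerate(mask) if keep]
    k = round(fraction * len(alive))
    # Indices of the k surviving weights with the smallest magnitude
    to_prune = sorted(alive, key=lambda i: abs(weights[i]))[:k]
    new_mask = list(mask)
    for i in to_prune:
        new_mask[i] = 0
    return new_mask


def sparsity(mask):
    """Fraction of weights that have been pruned."""
    return 1 - sum(mask) / len(mask)


random.seed(0)
weights = [random.gauss(0.0, 1.0) for _ in range(100)]  # hypothetical layer
mask = [1] * len(weights)  # 1 = keep, 0 = pruned

history = []
for round_idx in range(1, 6):
    mask = prune_smallest(weights, mask, fraction=0.1)
    # In real training, fine-tuning between rounds would update `weights` here
    history.append(sum(mask))
    print(f"round {round_idx}: {sum(mask)} weights left, sparsity {sparsity(mask):.2f}")
```

The surviving counts fall 100, 90, 81, 73, 66, 59: each round removes fewer weights in absolute terms, so after k rounds at fraction p the cumulative sparsity is roughly 1 - (1 - p)^k (about 0.41 here) rather than k * p = 0.5.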
The iterative approach also allows for fine-tuning after each pruning step, enabling the model to “heal” from the weight reduction:
```python
import torch

# Iteratively prune 10% of the model after every 10 epochs
for epoch in range(1, num_epochs + 1):
    train(model, train_loader, optimizer)  # Regular training step
    if epoch % 10 == 0:
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                ...
```
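For reference, here is a self-contained sketch of the full loop under stated assumptions. It uses PyTorch's `torch.nn.utils.prune.l1_unstructured` to mask out, in every `nn.Linear` layer, 10% of the still-unpruned weights ranked by absolute value. The choice of L1 magnitude pruning is an assumption, since the snippet above elides the pruning call. The toy model, the synthetic data, and the helpers `train_one_epoch` and `linear_sparsity` are hypothetical stand-ins for `model`, `train_loader`, and `train`.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Hypothetical toy model and synthetic data standing in for the text's model/train_loader
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))
inputs = torch.randn(256, 20)
targets = torch.randint(0, 2, (256,))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()


def train_one_epoch(model):
    """Hypothetical stand-in for train(): one full-batch gradient step."""
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()


def linear_sparsity(model):
    """Fraction of exactly-zero entries across all nn.Linear weight tensors."""
    zeros = total = 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            zeros += int((module.weight == 0).sum())
            total += module.weight.numel()
    return zeros / total


num_epochs, prune_every, amount = 50, 10, 0.1
for epoch in range(1, num_epochs + 1):
    train_one_epoch(model)  # epochs between pruning rounds fine-tune the surviving weights
    if epoch % prune_every == 0:
        for module in model.modules():
            if isinstance(module, nn.Linear):
                # Mask out 10% of the still-unpruned weights, ranked by L1 magnitude
                prune.l1_unstructured(module, name="weight", amount=amount)
        print(f"epoch {epoch}: sparsity = {linear_sparsity(model):.3f}")

# Make pruning permanent: fold each mask into the weight and drop weight_orig/weight_mask
for module in model.modules():
    if isinstance(module, nn.Linear):
        prune.remove(module, "weight")
```

When the same parameter is pruned repeatedly, PyTorch only considers entries that the existing mask still keeps, so five rounds with `amount=0.1` end near 1 - 0.9^5 ≈ 0.41 sparsity rather than 0.5. The final `prune.remove` call bakes the zeros into `weight` so the model no longer carries the `weight_orig`/`weight_mask` pair or the pruning hook.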