相关文章:
《Pytorch深度学习框架实战教程01:简介》
1.前言
在本节中,我们将探讨如何通过保存、加载和运行模型预测来持久化模型状态。相关依赖包如下所示:
import torch import torchvision.models as models
Pytorch模型保存和加载有两种方式,如下图所示:
2.保存和加载模型权重:仅权重
2.1 保存模型权重
PyTorch 模型将其学习到的参数存储在一个内部状态字典中,称为 state_dict
。这些参数可以通过 torch.save
方法进行持久化。
model = models.vgg16(weights='IMAGENET1K_V1') torch.save(model.state_dict(), 'model_weights.pth')
2.2 加载模型权重
要加载模型权重,首先需要创建相同模型的实例,然后使用 load_state_dict()
方法加载参数。
在下面的代码中,我们将 weights_only
设置为 True
,以限制反序列化过程中执行的函数,仅保留加载权重所需的函数。加载权重时,使用 weights_only=True
被认为是最佳实践。
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
3.保存和加载模型结构和权重:结构 + 权重
3.1 模型保存-结构+权重
加载模型权重时,我们需要先实例化模型类,因为类定义了网络的结构。我们可能希望将这个类的结构与模型一起保存,在这种情况下,我们可以将 model
(而不是 model.state_dict()
)传递给保存函数。
torch.save(model, 'model.pth')
3.2 模型加载-结构+权重
模型保存之后,我们可以按如下所示加载模型。
如 Saving and loading torch.nn.Modules 中所述,保存 state_dict
被认为是最佳实践。然而,下面我们使用 weights_only=False
是因为这涉及加载整个模型,这是 torch.save
的一个遗留用例。
model = torch.load('model.pth', weights_only=False),
特别说明:执行导入之前,必须声明/导入 待导入的模型类(vgg16),否则会报异常。