Pytorch深度学习框架实战教程09:模型的保存和加载

相关文章:

《Pytorch深度学习框架实战教程01:简介

1.前言

在本节中,我们将探讨如何通过保存、加载和运行模型预测来持久化模型状态。相关依赖包如下所示:

import torch
import torchvision.models as models

Pytorch模型保存和加载有两种方式,如下图所示:

2.保存和加载模型权重:仅权重

2.1 保存模型权重

PyTorch 模型将其学习到的参数存储在一个内部状态字典中,称为 state_dict。这些参数可以通过 torch.save 方法进行持久化。

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

2.2 加载模型权重

要加载模型权重,首先需要创建相同模型的实例,然后使用 load_state_dict() 方法加载参数。

在下面的代码中,我们将 weights_only 设置为 True,以限制反序列化过程中执行的函数,仅保留加载权重所需的函数。加载权重时,使用 weights_only=True 被认为是最佳实践。

model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()

3.保存和加载模型结构和权重:结构 + 权重

3.1 模型保存-结构+权重

加载模型权重时,我们需要先实例化模型类,因为类定义了网络的结构。我们可能希望将这个类的结构与模型一起保存,在这种情况下,我们可以将 model(而不是 model.state_dict())传递给保存函数。

torch.save(model, 'model.pth')

3.2 模型加载-结构+权重

模型保存之后,我们可以按如下所示加载模型。

如 Saving and loading torch.nn.Modules 中所述,保存 state_dict 被认为是最佳实践。然而,下面我们使用 weights_only=False 是因为这涉及加载整个模型,这是 torch.save 的一个遗留用例。

model = torch.load('model.pth', weights_only=False),

特别说明:执行导入之前,必须声明/导入 待导入的模型类(vgg16),否则会报异常。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

数据饕餮

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值