在深度学习的实际应用中,很少会去从头训练一个网络,尤其是当没有大量数据的时候。即便拥有大量数据,从头训练一个网络也很耗时,因为在大数据集上所构建的网络通常模型参数量很大,训练成本大。所以在构建深度学习应用时,通常会使用预训练模型
需求
训练一个模型来分类蚂蚁ants和蜜蜂bees
步骤
- 加载数据集
- 编写函数(训练并寻找最优模型)
- 编写函数(查看模型效果)
- 使用torchvision微调模型
- 使用tensorboard可视化训练情况
Pytorch保存和加载模型的两种方式
1.完整保存模型、加载模型
torch.save(net, 'mnist.pth')
net = torch.load('mnist.pth', 'map_location="cpu"')
# # Model class must be defined somewhere
load() 默认会将该张量加载到保存时所在的设备上,map_location 可以强制加载到指定的设备上
2. 保存、加载模型的状态字典(模型中的参数)
torch.save(model.state_dict(), PATH)
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
迁移学习全过程
1.加载数据集
ants和bees各有约120张训练图片。
每个类有75张验证图片,从零开始在 如此小的数据集上进行训练通常是很难泛化的。
由于我们使用迁移学习,模型的泛化能力会相当好。
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
cudnn.benchmark = True
plt.ion()
lr_scheduler:学习率调度器,用于在训练过程中动态调整学习率
torch.backends.cudnn.benchmark = True:大部分情况下,设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,从而加速运算
plt.ion():这允许你在一个交互式环境中运行matplotlib
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
t