使用预训练模型构建自己的深度学习模型(迁移学习)

本文介绍了如何在深度学习中使用迁移学习方法,特别是在处理蚂蚁和蜜蜂图像分类任务时,通过加载预训练的ResNet18模型,微调最后的全连接层,并使用TensorBoard可视化训练过程。作者详细描述了数据预处理、模型训练、评估以及优化策略。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在深度学习的实际应用中,很少会去从头训练一个网络,尤其是当没有大量数据的时候。即便拥有大量数据,从头训练一个网络也很耗时,因为在大数据集上所构建的网络通常模型参数量很大,训练成本大。所以在构建深度学习应用时,通常会使用预训练模型

需求

训练一个模型来分类蚂蚁ants和蜜蜂bees

步骤

  1. 加载数据集
  2. 编写函数(训练并寻找最优模型)
  3. 编写函数(查看模型效果)
  4. 使用torchvision微调模型
  5. 使用tensorboard可视化训练情况

Pytorch保存和加载模型的两种方式

1.完整保存模型、加载模型

torch.save(net, 'mnist.pth')
net = torch.load('mnist.pth', 'map_location="cpu"')
# # Model class must be defined somewhere

load() 默认会将该张量加载到保存时所在的设备上,map_location 可以强制加载到指定的设备上

2. 保存、加载模型的状态字典(模型中的参数)

torch.save(model.state_dict(), PATH)
model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))

迁移学习全过程

1.加载数据集

ants和bees各有约120张训练图片。

每个类有75张验证图片,从零开始在 如此小的数据集上进行训练通常是很难泛化的。

由于我们使用迁移学习,模型的泛化能力会相当好。

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
cudnn.benchmark = True  
plt.ion()

lr_scheduler:学习率调度器,用于在训练过程中动态调整学习率

torch.backends.cudnn.benchmark = True:大部分情况下,设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,从而加速运算

plt.ion():这允许你在一个交互式环境中运行matplotlib

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        t
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

鹤入云霄

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值