PyTorch入门:快速掌握深度学习框架

本教程介绍了在PyTorch中实现机器学习工作流,包括使用FashionMNIST数据集训练神经网络。内容涵盖数据加载、张量操作、模型创建、自动微分、优化循环以及保存和加载模型。教程还提供了下载数据、处理数据、创建数据加载器、定义模型和训练模型的代码示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

PyTorch 的定义

PyTorch 是一个开源的机器学习框架,基于 Torch 库开发,由 Facebook 的 AI 研究团队(现为 Meta AI)主导维护。它专为深度学习任务设计,支持动态计算图(Dynamic Computation Graph),提供灵活的张量计算和自动微分功能,广泛用于学术研究和工业应用。

PyTorch 的核心特性

动态计算图
PyTorch 使用动态计算图(称为“即时执行”模式),允许用户在运行时修改计算流程。这种特性便于调试和实现复杂的模型结构(如递归神经网络)。

张量计算与 GPU 加速
提供类似 NumPy 的多维张量(Tensor)操作,支持自动并行化计算和 GPU 加速。通过 CUDA 接口,可以高效利用 NVIDIA 显卡进行训练和推理。

自动微分(Autograd)
内置的 autograd 模块可自动计算梯度,简化反向传播的实现。用户只需定义前向传播,梯度会通过动态图自动追踪和计算。

丰富的生态系统

  • TorchVision:支持计算机视觉任务(如数据集加载、预训练模型)。
  • TorchText:处理自然语言处理(NLP)任务。
  • TorchAudio:提供音频数据处理工具。

PyTorch 的基本用法示例

以下是一个简单的线性回归模型代码片段:

import torch
import torch.nn as nn

# 定义模型
model = nn.Linear(1, 1)  # 输入和输出维度均为1
criterion = nn.MSELoss()  # 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 优化器

# 模拟数据
x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[2.0], [4.0], [6.0]])

# 训练循环
for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

PyTorch 与其他框架的对比

  • TensorFlow:早期采用静态计算图(TensorFlow 2.x 已支持动态图),适合生产部署。
  • JAX:提供函数式编程和自动微分,但生态规模较小。
  • Keras:高层 API,通常作为 TensorFlow 的封装。

PyTorch 因其易用性和灵活性,成为学术界的主流选择,同时在工业界的应用也逐渐扩大。

学习基础知识

大多数机器学习工作流程都涉及处理数据、创建模型、优化模型参数和保存经过训练的模型。本教程向您介绍在 PyTorch 中实现的完整 ML 工作流,并提供链接以了解有关每个概念的更多信息。

我们将使用 Fashion MNIST 数据集来训练一个神经网络,该网络预测输入图像是否属于以下类别之一:T 恤/上衣、裤子、套头衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包或脚踝开机。

Fashion MNIST 简介

Fashion MNIST 是一个广泛使用的图像分类数据集,专为机器学习和计算机视觉任务设计。它包含 10 类时尚单品的灰度图像,每张图像分辨率为 28x28 像素,共 70,000 张(60,000 张训练集 + 10,000 张测试集)。该数据集常被用作经典 MNIST 数据集(手写数字)的替代品,因其更具挑战性且更贴近真实场景。

数据集类别

Fashion MNIST 包含以下 10 类时尚单品标签:

  • T-shirt/top(T恤/上衣)
  • Trouser(裤子)
  • Pullover(套头衫)
  • Dress(连衣裙)
  • Coat(外套)
  • Sandal(凉鞋)
  • Shirt(衬衫)
  • Sneaker(运动鞋)
  • Bag(包)
  • Ankle boot(短靴)

特点与用途

  • 复杂度适中:比 MNIST 更具区分难度,适合测试模型鲁棒性。
  • 轻量级:图像尺寸小,便于快速实验和原型开发。
  • 通用基准:常用于验证卷积神经网络(CNN)等模型的性能。

数据加载示例(Python)

通过主流深度学习框架可直接加载数据集:

# 使用 TensorFlow 加载  
import tensorflow as tf  
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()  

# 使用 PyTorch 加载  
from torchvision import datasets  
data = datasets.FashionMNIST(root='./data', download=True)  

应用场景

  • 图像分类任务入门
  • 模型性能对比基准
  • 数据增强(如旋转、缩放)的实验

Fashion MNIST 因其简单性和实用性,成为机器学习教育及研究中常用的工具之一。

Class 10 mnist 分类

本教程假定您基本熟悉 Python 和深度学习概念。

运行教程代码

您可以通过以下几种方式运行本教程:

  • 在云端:这是最简单的入门方式!每个部分的顶部都有一个“在 Microsoft Learn 中运行”链接,该链接在 Microsoft Learn 中打开一个集成笔记本,其中包含完全托管环境中的代码。
  • 本地:此选项要求您首先在本地机器上设置 PyTorch 和 TorchVision(安装说明)。下载笔记本或将代码复制到您最喜欢的 IDE 中。

例子还是考虑用本地的方式

如何使用本指南

如果您熟悉其他深度学习框架,请先查看0. Quickstart,以快速熟悉 PyTorch 的 API。

如果您不熟悉深度学习框架,请直接进入我们分步指南的第一部分:1. 张量

脚本总运行时间:(0分0.000秒)

快速开始

本节贯穿机器学习中常见任务的 API。请参阅每个部分中的链接以深入了解。

处理数据

PyTorch 有两个处理数据的原语: torch.utils.data.DataLoadertorch.utils.data.DatasetDataset存储样本及其对应的标签,并DataLoaderDataset.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

PyTorch 提供特定领域的库,例如TorchText、 TorchVisionTorchAudio,所有这些库都包含数据集。在本教程中,我们将使用 TorchVision 数据集。

torchvision.datasets模块包含Dataset许多真实世界视觉数据的对象,如 CIFAR、COCO(此处为完整列表)。在本教程中,我们使用 FashionMNIST 数据集。每个 TorchVision 都Dataset包含两个参数:transform和 target_transform分别修改样本和标签。

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

出去:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

我们将Dataset作为参数传递给

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

KENYCHEN奉孝

您的鼓励是我的进步源泉

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值