PyTorch 的定义
PyTorch 是一个开源的机器学习框架,基于 Torch 库开发,由 Facebook 的 AI 研究团队(现为 Meta AI)主导维护。它专为深度学习任务设计,支持动态计算图(Dynamic Computation Graph),提供灵活的张量计算和自动微分功能,广泛用于学术研究和工业应用。
PyTorch 的核心特性
动态计算图
PyTorch 使用动态计算图(称为“即时执行”模式),允许用户在运行时修改计算流程。这种特性便于调试和实现复杂的模型结构(如递归神经网络)。
张量计算与 GPU 加速
提供类似 NumPy 的多维张量(Tensor)操作,支持自动并行化计算和 GPU 加速。通过 CUDA 接口,可以高效利用 NVIDIA 显卡进行训练和推理。
自动微分(Autograd)
内置的 autograd
模块可自动计算梯度,简化反向传播的实现。用户只需定义前向传播,梯度会通过动态图自动追踪和计算。
丰富的生态系统
- TorchVision:支持计算机视觉任务(如数据集加载、预训练模型)。
- TorchText:处理自然语言处理(NLP)任务。
- TorchAudio:提供音频数据处理工具。
PyTorch 的基本用法示例
以下是一个简单的线性回归模型代码片段:
import torch
import torch.nn as nn
# 定义模型
model = nn.Linear(1, 1) # 输入和输出维度均为1
criterion = nn.MSELoss() # 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 优化器
# 模拟数据
x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[2.0], [4.0], [6.0]])
# 训练循环
for epoch in range(100):
optimizer.zero_grad()
outputs = model(x)
loss = criterion(outputs, y)
loss.backward()
optimizer.step()
PyTorch 与其他框架的对比
- TensorFlow:早期采用静态计算图(TensorFlow 2.x 已支持动态图),适合生产部署。
- JAX:提供函数式编程和自动微分,但生态规模较小。
- Keras:高层 API,通常作为 TensorFlow 的封装。
PyTorch 因其易用性和灵活性,成为学术界的主流选择,同时在工业界的应用也逐渐扩大。
学习基础知识
大多数机器学习工作流程都涉及处理数据、创建模型、优化模型参数和保存经过训练的模型。本教程向您介绍在 PyTorch 中实现的完整 ML 工作流,并提供链接以了解有关每个概念的更多信息。
我们将使用 Fashion MNIST 数据集来训练一个神经网络,该网络预测输入图像是否属于以下类别之一:T 恤/上衣、裤子、套头衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包或脚踝开机。
Fashion MNIST 简介
Fashion MNIST 是一个广泛使用的图像分类数据集,专为机器学习和计算机视觉任务设计。它包含 10 类时尚单品的灰度图像,每张图像分辨率为 28x28 像素,共 70,000 张(60,000 张训练集 + 10,000 张测试集)。该数据集常被用作经典 MNIST 数据集(手写数字)的替代品,因其更具挑战性且更贴近真实场景。
数据集类别
Fashion MNIST 包含以下 10 类时尚单品标签:
- T-shirt/top(T恤/上衣)
- Trouser(裤子)
- Pullover(套头衫)
- Dress(连衣裙)
- Coat(外套)
- Sandal(凉鞋)
- Shirt(衬衫)
- Sneaker(运动鞋)
- Bag(包)
- Ankle boot(短靴)
特点与用途
- 复杂度适中:比 MNIST 更具区分难度,适合测试模型鲁棒性。
- 轻量级:图像尺寸小,便于快速实验和原型开发。
- 通用基准:常用于验证卷积神经网络(CNN)等模型的性能。
数据加载示例(Python)
通过主流深度学习框架可直接加载数据集:
# 使用 TensorFlow 加载
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# 使用 PyTorch 加载
from torchvision import datasets
data = datasets.FashionMNIST(root='./data', download=True)
应用场景
- 图像分类任务入门
- 模型性能对比基准
- 数据增强(如旋转、缩放)的实验
Fashion MNIST 因其简单性和实用性,成为机器学习教育及研究中常用的工具之一。
Class 10 mnist 分类
本教程假定您基本熟悉 Python 和深度学习概念。
运行教程代码
您可以通过以下几种方式运行本教程:
- 在云端:这是最简单的入门方式!每个部分的顶部都有一个“在 Microsoft Learn 中运行”链接,该链接在 Microsoft Learn 中打开一个集成笔记本,其中包含完全托管环境中的代码。
- 本地:此选项要求您首先在本地机器上设置 PyTorch 和 TorchVision(安装说明)。下载笔记本或将代码复制到您最喜欢的 IDE 中。
例子还是考虑用本地的方式
如何使用本指南
如果您熟悉其他深度学习框架,请先查看0. Quickstart,以快速熟悉 PyTorch 的 API。
如果您不熟悉深度学习框架,请直接进入我们分步指南的第一部分:1. 张量。
脚本总运行时间:(0分0.000秒)
快速开始
本节贯穿机器学习中常见任务的 API。请参阅每个部分中的链接以深入了解。
处理数据
PyTorch 有两个处理数据的原语: torch.utils.data.DataLoader
和torch.utils.data.Dataset
. Dataset
存储样本及其对应的标签,并DataLoader
在Dataset
.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
PyTorch 提供特定领域的库,例如TorchText、 TorchVision和TorchAudio,所有这些库都包含数据集。在本教程中,我们将使用 TorchVision 数据集。
该torchvision.datasets
模块包含Dataset
许多真实世界视觉数据的对象,如 CIFAR、COCO(此处为完整列表)。在本教程中,我们使用 FashionMNIST 数据集。每个 TorchVision 都Dataset
包含两个参数:transform
和 target_transform
分别修改样本和标签。
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
出去:
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
我们将Dataset
作为参数传递给