本文将分享一个使用PyTorch实现的手写数字识别项目,基于MNIST数据集构建全连接神经网络模型。
项目概述
本项目实现了一个三层全连接神经网络,用于识别手写数字(0-9)。包含完整的数据处理、模型训练、验证评估和预测功能。
代码结构解析
1. 数据预处理与加载
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) # MNIST的均值和标准差 ])
关键点:
-
ToTensor()
: 将PIL图像转换为PyTorch张量,并自动归一化到[0,1]范围 -
Normalize()
: 使用MNIST数据集的均值和标准差进行标准化,提高训练稳定性
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True) eval_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
数据划分:
-
训练集:60,000张图像
-
验证集:10,000张图像
-
自动下载并缓存到本地
./data
目录
下载好后的目录结构
2. 神经网络模型设计
class MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 256) # 输入层→隐藏层1 self.bn1 = nn.BatchNorm1d(256) # 批归一化层 self.relu = nn.ReLU() # 激活函数 self.fc2 = nn.Linear(256, 128) # 隐藏层1→隐藏层2 self.bn2 = nn.BatchNorm1d(128) # 批归一化层 self.fc3 = nn.Linear(128, 10) # 隐藏层2→输出层
网络架构特点:
-
输入层: 784个神经元(28×28像素展平)
-
隐藏层1: 256个神经元 + 批归一化 + ReLU激活
-
隐藏层2: 128个神经元 + 批归一化 + ReLU激活
-
输出层: 10个神经元(对应0-9数字分类)
批归一化(BatchNorm)的作用:
-
加速训练收敛
-
减少对初始化的敏感性
-
提供一定的正则化效果
3. 训练过程优化
def train(model, train_loader, epochs): model.train() for epoch in range(epochs): total_loss = 0 correct = 0 for batch_idx, (data, target) in enumerate(train_loader): # 前向传播 output = model(data) loss = criterion(output, target) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() # 统计指标 total_loss += loss.item() _, predicted = torch.max(output.data, 1) correct += (predicted == target).sum().item()
训练关键要素:
-
损失函数: CrossEntropyLoss(适合多分类问题)
-
优化器: Adam(自适应学习率,训练效果较好)
-
学习率: 0.001(常用初始值)
-
批量大小: 64(平衡训练效率和内存使用)
4. 验证评估机制
def eval(model, eval_loader): model.eval() # 切换到评估模式 with torch.no_grad(): # 禁用梯度计算 for data, target in eval_loader: output = model(data) # ...计算损失和准确率...
评估模式特点:
-
model.eval()
: 关闭Dropout和BatchNorm的训练/评估模式切换 -
torch.no_grad()
: 减少内存消耗,加速计算 -
使用完整验证集进行评估,确保结果可靠性
5. 预测功能实现
def predict(img_path, model): img = Image.open(img_path).convert('L') # 转换为灰度图 transform = transforms.Compose([ transforms.Resize((28, 28)), # 调整尺寸 transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) # 与训练一致的预处理 ]) t_img = transform(img).unsqueeze(0) # 添加批次维度
预测注意事项:
-
图像必须预处理为与训练数据相同的格式
-
保持相同的归一化参数至关重要
-
unsqueeze(0)
为单张图像添加批次维度
完整代码:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
from PIL import Image
# 数据预处理 - 添加归一化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # MNIST的均值和标准差
])
# 数据准备
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
eval_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=False) # 验证时不需要shuffle
# 定义网络结构
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.bn1 = nn.BatchNorm1d(256)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(256, 128)
self.bn2 = nn.BatchNorm1d(128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 28 * 28)
x = self.fc1(x)
x = self.bn1(x) # 批归一化在激活函数之前
x = self.relu(x)
x = self.fc2(x)
x = self.bn2(x) # 批归一化在激活函数之前
x = self.relu(x)
x = self.fc3(x)
return x
model = MyNet()
# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练
def train(model, train_loader, epochs):
model.train()
for epoch in range(epochs):
total_loss = 0
correct = 0
for batch_idx, (data, target) in enumerate(train_loader):
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
_, predicted = torch.max(output.data, 1)
correct += (predicted == target).sum().item()
if batch_idx % 100 == 0:
print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] '
f'Loss: {loss.item():.6f}')
avg_loss = total_loss / len(train_loader)
accuracy = 100. * correct / len(train_loader.dataset)
print(f'Train Epoch: {epoch} Average loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')
# 验证
def eval(model, eval_loader):
model.eval()
eval_loss = 0
correct = 0
with torch.no_grad():
for data, target in eval_loader:
output = model(data)
eval_loss += criterion(output, target).item() # 累加批损失
_, predicted = torch.max(output.data, 1)
correct += (predicted == target).sum().item()
eval_loss /= len(eval_loader) # 除以批数量,不是数据集大小
accuracy = 100. * correct / len(eval_loader.dataset)
print(f'\nValidation set: Average loss: {eval_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
# 保存模型
def save_model():
torch.save(model.state_dict(), 'mnist_fc_model.pt')
# 预测 - 使用与训练相同的预处理
def predict(img_path, model):
model.eval()
img = Image.open(img_path).convert('L')
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # 添加与训练相同的归一化
])
t_img = transform(img).unsqueeze(0)
with torch.no_grad():
output = model(t_img)
_, predicted = torch.max(output.data, 1)
print(f'Predicted digit: {predicted.item()}')
epochs = 5
train(model, train_loader, epochs)
eval(model, eval_loader)
save_model()
下载好对应的依赖库后,代码可直接运行
可以看到,经过五轮的训练后,我们的准确率达到了98%以上,并且对于测试集的预测正确率也达到了98%左右
这个项目展示了深度学习项目的基本流程,是学习PyTorch和图像分类的入门示例。通过理解每个组件的功能和作用,可以为进一步的深度学习项目打下坚实基础。