上次我们基于CIFAR-10训练一个图像分类器,梳理了一下训练模型的全过程,并且对卷积神经网络有了一定的理解,我们再在GPU上搭建一个手写的数字识别cnn网络,加深巩固一下
步骤
- 加载数据集
- 定义神经网络
- 定义损失函数
- 训练网络
- 测试网络
MNIST数据集简介
MINIST是一个手写数字数据库(官网地址:http://yann.lecun.com/exdb/mnist/),它有6w张训练样本和1w张测试样本,每张图的像素尺寸为28*28,如下图一共4个图片,这些图片文件均被保存为二进制格式
训练全过程
1.加载数据集
import torch
import torchvision
from torchvision import transforms
trainset = torchvision.datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
trainloader = torch.utils.data.DataLoader(trainset,
batch_size=64,
shuffle=True
)
testset = torchvision.datasets.MNIST('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = torch.utils.data.DataLoader(testset
, batch_size=64, shuffle=True)
展示一些训练图片
import nu