import torch
import torchvision
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from d2l import torch as d2l
def load_data_fashion_mnist(batch_size, resize=None, root='./data', download=True):
    """Build train/test DataLoaders for the Fashion-MNIST dataset.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        resize: Optional size forwarded to ``transforms.Resize``. When falsy
            (e.g. ``None``, the default) no resize step is added and images
            keep their native size.
        root: Directory where the dataset is stored / downloaded to.
            Defaults to ``'./data'`` (the previously hard-coded location).
        download: Forwarded to ``torchvision.datasets.FashionMNIST``; when
            True the dataset is fetched if it is not already under ``root``.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. The training loader
        has ``shuffle=True``; the test loader has ``shuffle=False`` so
        evaluation order is fixed.
    """
    trans = []
    if resize:
        # The resize step, if requested, is applied before the ToTensor conversion.
        trans.append(transforms.Resize(size=resize))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)

    # Prepare the datasets (train and test splits share the same transform).
    train_data = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=transform, download=download)
    test_data = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=transform, download=download)

    # Wrap the datasets in data loaders.
    train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)
    return train_dataloader, test_dataloader
#搭建网络
class AlexNet(nn.Module):
def __init__(self):
super(AlexNet, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(1,96,kernel_size=11,stride=4,padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3,stride=2),
nn.Conv2d(96,256,kernel_size=5,padding=2),
nn.ReLU(),
nn.MaxPool2d
CNN——AlexNet
最新推荐文章于 2024-12-14 05:00:00 发布