CNN 猫狗分类
一、数据准备
准备待分类的图像数据,例如猫狗图像数据集。
二、数据载入
1.引入库
import os
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
# Fix PyTorch's global RNG seed so repeated experiment runs give consistent
# results. Only torch's generator is seeded here; other randomness sources
# (e.g. Python's `random`, CUDA/cuDNN nondeterminism) are not touched.
torch.manual_seed(seed=1)
2.加载数据
class MyDataset(Dataset):
    """Map-style dataset over a list of ``(image_path, label)`` pairs.

    Each item is read from disk on access, converted to 3-channel RGB, and
    (optionally) passed through ``transform``.

    Args:
        data: Indexable sequence whose elements unpack to ``(path, label)``,
            e.g. ``[[path, 0], [path, 1], ...]``.
        transform: Callable applied to the loaded RGB PIL image (for example
            a ``torchvision.transforms.Compose``). If ``None``, the RGB PIL
            image is returned unchanged.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __getitem__(self, item):
        """Return ``(image, label)`` for the sample at index ``item``."""
        path, label = self.data[item]
        # Open through a context manager so the file handle is released
        # immediately. ``convert`` loads the pixels and returns a new image,
        # so ``img`` stays valid after the file is closed. Without this, the
        # handle lingers until garbage collection, which can exhaust file
        # descriptors when DataLoader workers read many images.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of samples in ``data``."""
        return len(self.data)
def load_data():
transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.3),
transforms.RandomVerticalFlip(p=0.3),
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # normalization
])
#数据集存放路径
cat_img_path = 'data/cat_dog/cat'
dog_img_path = 'data/cat_dog/dog'
train_img_list = [[os.path.join(cat_img_path, i), 0] for i in os.</