一、加载数据
1.使用torchvision.datasets的方法加载经典数据集
在此网址查看支持哪些经典数据集:Datasets — Torchvision 0.18 documentation (pytorch.org)
data_train = torchvision.datasets.CIFAR10(root="CIFAR10", train=True, transform=ToTensor(),
target_transform=None, download=True)
data_test = torchvision.datasets.CIFAR10(root="CIFAR10", train=False, transform=ToTensor(),
target_transform=None, download=True)
下面三个参数是所有加载经典数据集的函数共有的参数:
- root:存储数据集的目录
- transform:通常为对图像数据进行一系列转换操作的函数
- transform_target:通常为对目标数据进行一系列转换操作的函数
2.自己收集的数据集
(1)使用列表缓存图像和标签
root = "./data"
x = []
y = []
label_to_int = {}
int_to_label = {}
for kind in os.listdir(root):
label = len(label_to_int)
label_to_int[kind] = label
int_to_label[label] = kind
kind_root = os.path.join(root, kind)
images_path = os.listdir(kind_root)
for img_path in images_path:
img = Image.open(os.path.join(kind_root,img_path)).convert("RGB")
img = torchvision.transforms.ToTensor()(img)
x.append(img)
y.append(label)
(2)使用自定义Dataset 动态加载
class MyDataset(Dataset):
def __init__(self, root):
self.root = root
self.image_paths = []
self.labels = []
self.label_to_int = {}
self.int_to_label = {}
for kind in os.listdir(self.root):
label = len(self.label_to_int)
s