1. DataLoader的核心概念
DataLoader是PyTorch中一个重要的类,用于将数据集(dataset)和数据加载器(sampler)结合起来,以实现批量数据加载和处理。它可以高效地处理数据加载、多线程加载、批处理和数据增强等任务。
核心参数
dataset
: 数据集对象,必须是继承自torch.utils.data.Dataset
的类。batch_size
: 每个批次的大小。shuffle
: 是否在每个epoch
开始时打乱数据。sampler
: 定义数据加载顺序的对象,通常与shuffle
互斥。num_workers
: 使用多少个子进程加载数据。collate_fn
: 如何将单个样本合并为一个批次的函数。pin_memory
: 是否将数据加载到CUDA
固定内存中。
2. 基本使用方法
定义数据集类
首先定义一个数据集类,该类需要继承自torch.utils.data.Dataset
并实现__len__
和__getitem__
方法。
import torch
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = {
'data': self.data[idx], 'label': self.labels[idx]}
return sample
# 创建一些示例数据
data = torch.randn(100, 3, 64, 64) # 100个样本,每个样本为3x64x64的图像
labels = torch.randint(0, 2, (100,)) # 100个标签,0或1
dataset = CustomDataset(data, labels)
创建DataLoader
使用自定义数据集类创建DataLoader对象。
batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
迭代DataLoader
遍历DataLoader获取批量数据。
for batch in dataloader:
data, labels = batch['data'], batch['label']
print(data.shape, labels.shape)
3. 进阶技巧
自定义collate_fn
如果需要自定义如何将样本合并为批次,可以定义自己的collate_fn
函数。
def custom_collate_fn(batch):
data = [item['data'] for item in batch]
labels = [item['label'] for item in batch]
return {
'data': torch.stack(data), 'label': torch.tensor(labels)}
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)
使用Sampler
Sampler
定义了数据加载的顺序。可以自定义一个Sampler来实现更复杂的数据加载策略。
from torch.utils.data import Sampler
class CustomSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.</