I've been using PyTorch a lot recently, which means frequently reading all kinds of data and writing custom Dataset and DataLoader classes. Here is a small Dataset/DataLoader demo so you can get up and running quickly. Code example:
import torch
from torch.utils.data import DataLoader, Dataset


class ImageDataset(Dataset):
    'Characterizes a dataset for PyTorch'

    def __init__(self, list_IDs, labels):
        'Initialization'
        self.labels = labels
        self.list_IDs = list_IDs

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        # ID = self.list_IDs[index]
        # Load data and get label
        # X = torch.load('data/' + ID + '.pt')
        X = torch.rand((3, 224, 224))
        y = torch.randint(0, 5, (1,))
        return X, y


if __name__ == "__main__":
    params = {'batch_size': 2,
              'shuffle': True,
              'num_workers': 1}
    max_epochs = 100

    # Datasets
    partition = list(range(1000))  # IDs
    labels = list(range(1000))     # Labels

    # Generators
    training_set = ImageDataset(partition, labels)
    training_loader = DataLoader(training_set, **params)

    for batch in training_loader:
        x, y = batch
        print(x.shape)
        print(y)
        break
How the data is actually read varies from project to project, so that part is left for you to implement yourself.
References
A detailed example of how to generate your data in parallel with PyTorch