Pytorch自定义Dataset和使用DataLoader装载数据

本文介绍如何使用Pytorch自定义Dataset类和DataLoader加载lfw人脸数据库的人脸图像,涵盖图像读取、预处理及数据组织,如分批、shuffle等操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文以lfw人脸数据库为例,用Pytorch自定义Dataset和使用DataLoader装载人脸图像。Dataset主要功能是读取数据源,而DataLoader在Dataset基础上组织数据供给深度算法使用,比如对图像的分批、shuaffle、扩展样本等操作。本文用的图片放在facesBmp目录下面,如下图所示:

下面是实现代码。代码比较简单,可以看里面注释

# -*-coding: utf-8 -*-
import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os,sys
from PIL import Image
import matplotlib.pyplot as plt

class  LfwDataset(Dataset):
    def __init__(self, image_dir, resize_height=64, resize_width=64):
        '''
        :param image_dir: 图片路径:image_dir+imge_name.jpg构成图片的完整路径
        :param resize_height 为图像高,
        :param resize_width  为图像宽        
        '''
        # 所有图片的绝对路径
        imgs=os.listdir(image_dir)
        self.imgs=[os.path.join(image_dir,k) for k in imgs]
      # 相关预处理的初始化
      #  self.transforms=transform
        self.transforms=True
        self.transform= transforms.Compose([
        transforms.ToTensor(),  # 将图片转换为Tensor,归一化至[0,1]
         transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])  # 标准化至[-1,1]
])
 
    def __getitem__(self, i):
        img_path = self.imgs[i]
        pil_img = Image.open(img_path)
        if self.transforms:
            data =self.transform(pil_img)
        else:
            pil_img = np.asarray(pil_img)
            data = torch.from_numpy(pil_img)
        return data
 
    def __len__(self):
        return len(self.imgs)

   

 
if __name__=='__main__':
    image_dir="../facesBmp" #该文件夹下面直接是图像,与原始文件不一样,原始文件是有人名的二级目录
 
    epoch_num=1   #总样本循环次数 反对
    batch_size=1000  #训练时的一组数据的大小
    train_data_nums=13233
    max_iterate=int((train_data_nums+batch_size-1)/batch_size*epoch_num) #总迭代次数 
    train_data = LfwDataset(image_dir=image_dir)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
 
    # [1]使用epoch方法迭代,LfwDataset的参数repeat=1
    for epoch in range(epoch_num):
        for batch_image in train_loader:
            image=batch_image[0,:]
            image=image.numpy()#       
            # plt.imshow(image)
            # plt.show()
            print("batch_image.shape:{}".format(batch_image.shape))    
 
    '''
    以两种方式实现训练集迭代,退出循环由max_iterate设定
    '''
    train_data = LfwDataset(image_dir=image_dir)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    # [2]第2种迭代方法
    print('第2种迭代方法')
    print(enumerate(train_loader))
    for step, batch_image in enumerate(train_loader):
        image=batch_image[0,:]
        image=image.numpy()#image=np.array(image)    
        # plt.imshow(image)
        # plt.show()
        print("step:{},batch_image.shape:{}".format(step,batch_image.shape))
        if step>=max_iterate:
            break

 代码中使用了两种方法迭代训练数据集,输出结果如下图所示:

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值