Dataset的使用:
主要代码:
import os
dir_path="D:\\tensorflowdata\\number_gen"
img_path_list=os.listdir(dir_path)#读取dir_path路径下number_gan文件内的文件名的列表
img_path_list[0]
#输出:'00000.png'
整体代码:
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir#获取文件目录
self.label_dir=label_dir#获取标签(以标签命名的文件名)
self.path=os.path.join(self.root_dir, self.label_dir)#拼接两个路径
self.img_path=os.listdir(self.path)#获得path路径下的文件名列表
def __getitem__(self, index):
img_name=self.img_path[index]#获取每张图片的文件名
img_item_path=os.path.join(self.root_dir, self.label_dir,img_name)#获取每张图片的地址
img=Image.open(img_item_path)#图片信息
label=self.label_dir#标签
return img,label#返回图片和标签
def __len__(self):
return len(self.img_path)#输出图片的张数
root_dir="D:\\tensorflowdata"
label_dir="number_gen"
gan_dataset=MyData(root_dir, label_dir)
print(gan_dataset[0])
img,label=gan_dataset[0]#传递第一张图片及其标签
img.show() #显示图片
#输出为:
#(<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=288x288 at 0x2B89BE97130>, 'number_gan')
#100