python中CIFAR10数据集的使用

本文详细介绍了如何下载和使用CIFAR10数据集,包括通过torchvision.datasets.CIFAR10加载数据,查看数据集内容,并通过transforms将数据转换为适合PyTorch使用的Tensor格式。此外,还展示了使用SummaryWriter将数据写入TensorBoard的过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文主要解决了如何把数据集与transforms结合在一起的问题。


一、CIFAR10的官方解释

torchvision.datasets.CIFAR10(
root: str, 
train: bool = True, 
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False)

注释:

root (string) -- 存在 cifar-10-batches-py 目录的数据集的根目录,如果下载设置为 True,则将保存到该目录。

train (bool, optional) -- 如果为True,则从训练集创建数据集, 如果为False,从测试集创建数据集。

transform (callable, optional) – 它接受一个 PIL 图像并返回一个转换后的版本。 例如,transforms.RandomCrop/transforms.ToTensor

target_transform (callable, optional) – 接收目标并对其进行转换的函数/转换。

download (bool, optional) – 如果为 true,则从 Internet 下载数据集并将其放在根目录中。 如果数据集已经下载,则不会再次下载。

二、实战操作

1.CIAFR10数据集的下载

代码如下:

import torchvision   #导入torchvision这个类

train_set = torchvision.datasets.CIFAR10(root = "./dataset", train = True, 
download= True)  #从训练集创建数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False,
 download=True)    #从测试集创建数据集

root = "./dataset",将下载的数据集保存在这个文件夹下;download= True,从 Internet 下
载数据集并将其放在根目录中,这里就是在相对路径中,创建dataset文件夹,将数据集保存
在dataset中。

2.查看下载的CIAFR10数据集

运行程序,开始下载数据集。下载成功后,可以进行一些查看。代码如下:

接着输入:
print(train_set[0])  #查看train_set训练集中的第一个数据
print(train_set.classes)   #查看train_set训练集中有多少个类别

img, target = train_set[0]
print(img)
print(target)
print(train_set.classes[target])
img.show()  #显示图片
输出结果:
(<PIL.Image.Image image mode=RGB size=32x32 at 0x161E924B8D0>, 6)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship',
'truck']
<PIL.Image.Image image mode=RGB size=32x32 at 0x161E924B710>
6
frog

注释:可以看见,train_set数据集中有10个类别,train_set中第0个元素的target是6,也就是说,这个元素是属于第7个类别frog的。
 

3.数据转换

因为这些图片类型都是PIL Image,如果要供给pytorch使用的话,需要将数据全都转化成tensor类型。

完整代码如下:

import torchvision   #导入torchvision这个类
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms
dataset_transforms = transforms.ToTensor()

# dataset_transforms = torchvision.transforms.Compose([
#     torchvision.transforms.ToTensor()
# ])    第3  4 行代码可以用compose直接写
train_set = torchvision.datasets.CIFAR10(root = "./dataset", train = True, transform=dataset_transforms, download= True) #训练集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transforms, download=True)   #测试集

writer = SummaryWriter("logs")

# print(train_set[0])  #查看train_set训练集中的第一个数据
# print(train_set.classes)   #查看train_set训练集中有多少个类别

# img, target = train_set[0]
# print(img)
# print(target)
# print(train_set.classes[target])
# img.show()
for i in range(20):
    img, target = train_set[i]
    writer.add_image("cifar10_test2", img, i)

writer.close()
 


总结

CIFAR10数据集内存很小,只有100多m,下载方便。对我们学习数据集非常友好,练习的时候,我们可以使用SummaryWriter来将数据写入tensorboard中。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

晓亮.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值