38 实战Kaggle比赛:图像分类 (CIFAR-10)
比赛链接:CIFAR-10 - Object Recognition in Images | Kaggle
导入包
import os
import glob
import pandas as pd
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torch import nn
from d2l import torch as d2l
import liliPytorch as lp
import csv
预处理:数据集分析
# Register the trimmed-down CIFAR-10 archive (download URL, hash string)
# in the d2l data hub.
#@save
d2l.DATA_HUB['cifar10_tiny'] = (
    d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
    '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd',
)

# Set demo to False to work with the complete Kaggle competition dataset.
demo = True
data_dir = d2l.download_extract('cifar10_tiny') if demo else '../data/cifar-10/'

# NOTE(review): the two paths below always point at the tiny dataset,
# independently of `demo` / `data_dir`.
train_path = '../data/kaggle_cifar10_tiny/train.csv'
file_path = '../data/kaggle_cifar10_tiny/'

# Load the training labels and show how many samples each class has.
train_data = pd.read_csv(train_path)
print(train_data['label'].value_counts())
# Output recorded for the tiny dataset:
#
# label
# automobile 112
# frog 107
# truck 103
# horse 102
# airplane 102
# deer 99
# bird 99
# ship 99
# cat 92
# dog 85
1.数据处理与加载
# CSV manifests and root folder of the tiny CIFAR-10 dataset.
train_path = '../data/kaggle_cifar10_tiny/train.csv'
test_path = '../data/kaggle_cifar10_tiny/test.csv'
file_path = '../data/kaggle_cifar10_tiny/'

# Distinct class names found in the training labels, in sorted order.
cifar_labels = sorted(set(train_data['label']))

# Class name -> integer index.
labels_to_num = {label: index for index, label in enumerate(cifar_labels)}
# print(labels_to_num)
# Recorded output:
# {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
#  'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

# Integer index -> class name, used later when turning predictions
# back into labels.
num_to_labels = dict(enumerate(cifar_labels))
# print(num_to_labels)
# Recorded output:
# {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
#  5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
def get_image_filenames(folder_path, extensions=('.png', '.jpg', '.jpeg')):
    """Return the base names of the image files located directly in a folder.

    Args:
        folder_path: Directory to scan; sub-directories are not searched.
        extensions: Iterable of filename suffixes to match, each including
            the leading dot. The default is an immutable tuple so that it
            cannot be modified between calls.

    Returns:
        A list of file names without their directory part, grouped in the
        order of ``extensions`` and sorted alphabetically within each group.
    """
    image_files = []
    for ext in extensions:
        pattern = os.path.join(folder_path, f'*{ext}')
        # glob yields matches in arbitrary order; sorting keeps the
        # resulting file list reproducible across runs and machines.
        image_files.extend(sorted(glob.glob(pattern)))
    # Strip the directory part so only the file names remain.
    return [os.path.basename(image) for image in image_files]
def save_filenames_to_csv(filenames, csv_path):
    """Write file names to a one-column CSV file with the header ``id``.

    Args:
        filenames: Iterable of file names; each becomes one row.
        csv_path: Destination path; an existing file is overwritten.
    """
    with open(csv_path, mode='w', newline='', encoding='utf-8') as handle:
        out = csv.writer(handle)
        # Header row first, then one single-cell row per file name.
        out.writerow(['id'])
        out.writerows([name] for name in filenames)
# Collect the names of the test images.
test_images_path = '../data/kaggle_cifar10_tiny/test'
image_filenames = get_image_filenames(test_images_path)
# Write them to test.csv inside the dataset folder.
save_filenames_to_csv(image_filenames, f'{file_path}test.csv')
class CifarDataset(Dataset):
def __init__(self, csv_path, file_path, mode='train', valid_ratio=0.2, resize_height