# Check whether the dataset suffers from class imbalance; the findings also help choose suitable data preprocessing.
#%%
from glob import glob
import pandas as pd
import numpy as np
import os
import cv2
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm
#%%
# Explore the training set.
# Files are matched one directory level below TRAIN_DATASET_PATH, and each
# file's label is taken to be the name of its parent directory.
TRAIN_DATASET_PATH = '../train_data'
image_fns = glob(os.path.join(TRAIN_DATASET_PATH, '*', '*.*'))
# Derive the label with os.path instead of splitting on '/': glob returns
# paths using the platform separator, so a hard-coded '/' picks the wrong
# component on Windows.
label_names = [os.path.basename(os.path.dirname(fn)) for fn in image_fns]
# Sorted so the label order is reproducible between runs (set order is not).
unique_labels = sorted(set(label_names))
#%%
# Number of classes
print(len(unique_labels))
#%%
# 图片总数
print(