1. Introduction
Before training a model we usually need a dataset ready to go. The easiest option is to download a public dataset, but some public datasets only come with annotations and have not yet been split into subsets, so we have to split them ourselves. Besides splitting, other utility scripts are often needed as well, such as label format converters (xml to txt, json to txt, etc.) and code for extracting a single class from a public dataset; I will upload and share those later, which also serves as a backup for myself.
Before splitting, prepare the image directory and the label directory, and pick an output directory; if the output directory does not exist, it is created automatically. After the script runs, the output directory contains the copied images and labels, divided into training, validation, and test sets, along with a .txt path list for each split. The split ratios are adjustable; if you do not need a test set, simply set its ratio to 0. A sketch of the file trees before and after splitting is shown below.
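Roughly, using the example NEU-DET paths set in main() below, the before/after trees look like this (a sketch; the actual folder names are whatever image_dir, label_dir, and output_dir are set to):

Before:
NEU-DET/
├── IMAGES/          # all images (.jpg/.jpeg/.png)
└── txt/             # all .txt labels

After:
NEU-DET/splitdata/
├── images/
│   ├── train/
│   ├── val/
│   └── test/
├── labels/
│   ├── train/
│   ├── val/
│   └── test/
├── train.txt        # absolute image paths of the training split, one per line
├── val.txt
└── test.txt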
2. Code
import os
import random
import shutil
from tqdm import tqdm
def find_files(directory, extensions):
    # Recursively collect every file under `directory` whose extension is in `extensions`
    files = []
    for root, _, filenames in os.walk(directory):
        for filename in filenames:
            if any(filename.lower().endswith(ext) for ext in extensions):
                files.append(os.path.join(root, filename))
    return files

def split_dataset(image_dir, label_dir, output_dir, split_ratios):
    random.seed(123)  # fixed seed so the split is reproducible
    image_extensions = [".jpg", ".jpeg", ".png"]
    label_extensions = [".txt", ".xml"]
    image_files = find_files(image_dir, image_extensions)
    label_files = find_files(label_dir, label_extensions)  # not used for pairing; label paths are derived from image paths below
    num_samples = len(image_files)
    num_train = int(num_samples * split_ratios["train"])
    num_val = int(num_samples * split_ratios["val"])
    num_test = num_samples - num_train - num_val  # the test split receives whatever is left over
    # Shuffle only the image list; each label path is derived from its image path, so label_files needs no shuffling
    random.shuffle(image_files)
    train_images = image_files[:num_train]
    val_images = image_files[num_train:num_train + num_val]
    test_images = image_files[num_train + num_val:]

    def to_label_path(img_path):
        # A label file shares its image's name but lives under label_dir with a .txt extension
        return os.path.splitext(img_path.replace(image_dir, label_dir))[0] + ".txt"

    train_labels = [to_label_path(img) for img in train_images]
    val_labels = [to_label_path(img) for img in val_images]
    test_labels = [to_label_path(img) for img in test_images]

    # Create the output directory tree
    os.makedirs(output_dir, exist_ok=True)
    for split_name in ["train", "val", "test"]:
        os.makedirs(os.path.join(output_dir, "images", split_name), exist_ok=True)
        os.makedirs(os.path.join(output_dir, "labels", split_name), exist_ok=True)

    # Copy each image and its label into the folder for its split
    for images, labels, split_name in [(train_images, train_labels, "train"),
                                       (val_images, val_labels, "val"),
                                       (test_images, test_labels, "test")]:
        for image_file, label_file in tqdm(zip(images, labels), desc=f"Copying {split_name} data"):
            shutil.copy(image_file, os.path.join(output_dir, "images", split_name))
            shutil.copy(label_file, os.path.join(output_dir, "labels", split_name))

    # Generate a <split>.txt path file for each split
    for split_name in ["train", "val", "test"]:
        split_img_dir = os.path.join(output_dir, "images", split_name)
        img_paths = []
        # Walk the split's image folder and collect the absolute path of every image
        for root, _, files in os.walk(split_img_dir):
            for file in files:
                if any(file.lower().endswith(ext) for ext in image_extensions):
                    img_paths.append(os.path.abspath(os.path.join(root, file)))
        # Write the path list for this split
        txt_path = os.path.join(output_dir, f"{split_name}.txt")
        with open(txt_path, "w") as f:
            f.write("\n".join(img_paths))
        print(f"Generated {txt_path} with {len(img_paths)} entries")

    print("Dataset split completed!")

def main():
    image_dir = r"E:\yolov7-main\NEU-DET\IMAGES"
    label_dir = r"E:\yolov7-main\NEU-DET\txt"
    output_dir = r"E:\yolov7-main\NEU-DET\splitdata"
    split_ratios = {
        "train": 0.8,
        "val": 0.1,
        "test": 0.1
    }
    split_dataset(image_dir, label_dir, output_dir, split_ratios)

if __name__ == "__main__":
    main()
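If a test set is not needed, only the split_ratios dictionary in main() has to change; a minimal sketch with the test ratio set to 0:

split_ratios = {
    "train": 0.9,
    "val": 0.1,
    "test": 0.0
}

The test folders and test.txt are still created, but they only receive whatever images are left over after the train and val counts are truncated to integers (often none), since test_images takes the remainder of the shuffled list.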