Python:K折交叉验证,将数据集分成训练集与测试集

该代码实现了对图像数据集的交叉验证,用于分类任务的验证。通过sklearn的KFold进行k折划分,根据指定的根目录和交叉验证次数,将数据集分为训练集和测试集,保存到新的路径下,方便后续训练过程使用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

注意文件夹格式:父文件夹/类别/图像(同torch读取图像格式保存一致),传入路径为父文件夹路径。


"""
对图像进行交叉验证, 用于检验分类效果
对每个类别的n张图像进行交叉验证分类 获取数据集 从而在训练网络时进行交叉验证
输入:数据集路径 保存数据集的位置 k折交叉验证


输出:k个数据集
将一个数据集分成k份,其中由k-1份组成训练集
余下1份组成测试集

2022年04月14日20:13:48
进行如下修改: 移除删除特定的隐藏文件, 判断i是否为目录, 当为目录时才进行操作; 如想对item进行类型判断 请使用item.endwith('.jpg')
"""
import os
import shutil
import time

from sklearn.model_selection import KFold


def make_path(path_targe):
    """
    输入一个路径, 如果存在就删除 不存在就生成
    :param path_targe:
    :return:
    """
    # 判断是否存在并重新创建文件夹
    if os.path.exists(path_targe):
        shutil.rmtree(path_targe)
        os.makedirs(path_targe)
        print("succeed : ", path_targe)
    else:
        os.makedirs(path_targe)
        print("succeed : ", path_targe)


if __name__ == '__main__':
    # ==================设置超参数===================== #
    #  root_data : 数据路径    格式:文件名->n个类别->m张图像(需符合pytorch读取训练数据集要求)
    root_data = "/Users/yida/Desktop/土壤_wwr/实验二/dataset"
    #  save_path:存放数据路径
    save_path = "/Users/yida/Desktop"
    #  k_num : 设置交叉验证数
    k_num = 5
    # =============================================== #
    # 定义交叉验证:设置参数  shuffle=True 打乱, random_state随机数种子
    kf = KFold(n_splits=k_num, shuffle=True, random_state=0)

    # 生成存放文件目录
    save_path = os.path.join(save_path, time.strftime("%Y-%m-%d"))
    file = os.listdir(root_data)
    start_time = time.time()  # 记录操作时间
    for i in file:
        root_sub = os.path.join(root_data, i)
        flag = os.path.isdir(root_sub)  # 验证root_sub是否为目录
        if flag:
            file_sub = os.listdir(root_sub)
            print(file_sub)
            for n, data in enumerate(kf.split(file_sub)):
                train, test = data
                save_path_sub_train = os.path.join(save_path, str(n), "train", i)
                save_path_sub_test = os.path.join(save_path, str(n), "test", i)
                # 生成路径
                make_path(save_path_sub_train)
                make_path(save_path_sub_test)
                print(len(train), len(test))
                for item in train:
                    img_name = file_sub[item]
                    # 图像路径
                    path_train = os.path.join(root_sub, img_name)
                    # 目标路径
                    targe_path = os.path.join(save_path_sub_train, img_name)
                    # 开始移动
                    shutil.copy(path_train, targe_path)
                    print("{} to {}...".format(path_train, targe_path))
                for item in test:
                    img_name = file_sub[item]
                    # 图像路径
                    path_train = os.path.join(root_sub, img_name)
                    # 目标路径
                    targe_path = os.path.join(save_path_sub_test, img_name)
                    # 开始移动
                    shutil.copy(path_train, targe_path)
                    print("{} to {}...".format(path_train, targe_path))
    print("End...{}折交叉验证已完成 ,  数据集路径为: {}".format(k_num, save_path))
    end_time = time.time()


【推荐阅读】

Pytorch-GPU安装教程大合集(Perfect完美系列)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

陈嘿萌

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值