% ========== Training stage: unified preprocessing pipeline ==========
% Loads ROI training samples, applies a Savitzky-Golay filter, fits PCA once
% on the training set, trains a class-weighted RBF SVM on the leading
% principal components, saves the model plus preprocessing parameters, and
% finally runs batch prediction on the test and validation folders.
%
% Depends on helpers defined elsewhere: load_all_rois, apply_sg_filter,
% predict_folder.
clc; clear;

% ========== Configuration ==========
test_dir = 'F:\dataset\test';
val_dir = 'F:\dataset\val';
save_dir = 'F:\results';
% Forwarded unchanged to predict_folder (presumably the image cube
% dimensions - confirm against predict_folder).
rows = 480; cols = 480; bands = 300;
sg_params.window = 11;
sg_params.order = 3;
% Upper bound on the number of principal components fed to the SVM.
max_num_pc = 10;

% ========== Load training data ==========
roi_dir = 'F:\dataset_roi_masks';
[X_train, Y_train] = load_all_rois(roi_dir);
if iscell(X_train), X_train = cell2mat(X_train); end
X_train = double(X_train);

% Savitzky-Golay filtering
X_train = apply_sg_filter(X_train, sg_params.window, sg_params.order);

% Drop every sample (row) that contains a NaN, together with its label
nan_rows = any(isnan(X_train), 2);
X_train(nan_rows, :) = [];
Y_train(nan_rows, :) = [];
if isempty(X_train)
    error('train:emptyTrainingSet', ...
        'No training samples remain after NaN removal (roi_dir = %s).', roi_dir);
end

% PCA: fitted exactly once, on the training data only
[coeff, ~, ~, ~, explained, mu] = pca(X_train);
cum_explained = cumsum(explained);
% Smallest component count reaching 95%% cumulative explained variance,
% capped at max_num_pc. The cap MUST be applied before projecting: the
% same num_pc is saved and handed to predict_folder, so the training
% features and the prediction features need the same dimensionality.
num_pc = find(cum_explained >= 95, 1);
num_pc = min(num_pc, max_num_pc);
X_train_pca = (X_train - mu) * coeff(:, 1:num_pc);
pca_coeff = coeff;
pca_mu = mu;
fprintf('选择主成分数 num_pc = %d,累计贡献率 = %.2f%%\n', num_pc, cum_explained(num_pc));

% Per-sample weights that balance the two classes (inverse class frequency)
num_0 = sum(Y_train == 0);
num_1 = sum(Y_train == 1);
if num_0 == 0 || num_1 == 0
    % A missing class makes the corresponding weight below Inf (division
    % by zero); it is never assigned to a sample, but the classifier
    % cannot learn a two-class boundary from such data.
    warning('train:missingClass', ...
        'Training labels contain %d samples of class 0 and %d of class 1.', ...
        num_0, num_1);
end
weight_0 = length(Y_train) / (2 * num_0);
weight_1 = length(Y_train) / (2 * num_1);
sample_weights = zeros(size(Y_train));
sample_weights(Y_train == 0) = weight_0;
sample_weights(Y_train == 1) = weight_1;

% ========== SVM training ==========
C = 0.1;
% NOTE(review): this value is passed as 'KernelScale', which fitcsvm uses
% as a divisor for the predictors before applying the kernel - it is not
% the RBF gamma. If a gamma of 0.1 was intended, the equivalent KernelScale
% would be 1/sqrt(0.1). Left unchanged; confirm the intent.
Gamma = 0.1;
SVMModel = fitcsvm(X_train_pca, Y_train, ...
    'KernelFunction', 'rbf', ...
    'BoxConstraint', C, ...
    'KernelScale', Gamma, ...
    'Weights', sample_weights, ...
    'Standardize', true, ...
    'ClassNames', [0; 1]);
fprintf('训练完成,支持向量数量: %d\n', size(SVMModel.SupportVectors, 1));

% Persist the model together with every preprocessing parameter needed to
% reproduce the feature pipeline at prediction time
save('SVMModel.mat', 'SVMModel', 'pca_coeff', 'pca_mu', 'sg_params', 'num_pc');

% ========== Batch prediction ==========
predict_folder(test_dir, SVMModel, rows, cols, bands, ...
    sg_params, pca_coeff, pca_mu, num_pc, save_dir, 'test_');
predict_folder(val_dir, SVMModel, rows, cols, bands, ...
    sg_params, pca_coeff, pca_mu, num_pc, save_dir, 'val_');
% NOTE: trailing web-page residue, not part of the script: "分析代码" / "最新发布"