弱监督度量学习在all-umass/metric-learn项目中的应用指南
什么是弱监督度量学习
弱监督度量学习是度量学习的一个分支,它不需要像监督学习那样精确的标签数据,而是通过更"弱"的监督信息来学习数据点之间的距离度量。在all-umass/metric-learn项目中,弱监督算法接受的是数据点元组(如相似/不相似的点对)作为输入,而不是传统的标签数据。
弱监督度量学习的核心概念
输入数据结构
弱监督算法主要处理两种形式的输入数据:
-
3D数组形式:直接表示元组中的每个数据点
- 形状:(n_tuples, tuple_size, n_features)
- 适用于小规模数据,但内存效率不高
-
2D索引数组+预处理器:更高效的表示方式
- 只存储数据点在原始数据集中的索引
- 需要配合预处理器(preprocessor)使用,指定原始数据X
基本API工作流程
# 初始化算法
ml_algorithm = MetricLearningAlgorithm()
# 拟合模型
ml_algorithm.fit(tuples, y)
# 转换新数据
transformed_data = ml_algorithm.transform(new_data)
# 计算点对距离
distances = ml_algorithm.pair_distance(pairs)
基于点对的度量学习
算法拟合
点对学习是最常见的弱监督形式,算法接受点对和标签(+1表示相似,-1表示不相似):
from metric_learn import MMC
# 示例数据
pairs = np.array([[[1.2, 3.2], [2.3, 5.5]],
[[4.5, 2.3], [2.1, 2.3]]])
y_pairs = np.array([1, -1]) # 第一对相似,第二对不相似
# 训练模型
mmc = MMC(random_state=42)
mmc.fit(pairs, y_pairs)
预测与评分
训练好的模型可以预测新点对的相似性:
new_pairs = np.array([[[0.6, 1.6], [1.15, 2.75]],
[[3.2, 1.1], [5.4, 6.1]]])
predictions = mmc.predict(new_pairs) # 返回+1或-1
模型还提供决策函数和评分功能:
# 获取决策分数(距离的相反数)
scores = mmc.decision_function(new_pairs)
# 模型评分(默认使用ROC AUC)
test_pairs = [...]
test_y = [...]
score = mmc.score(test_pairs, test_y)
阈值校准
预测相似性需要设置距离阈值,有三种校准方式:
-
训练时自动校准:
mmc.fit(pairs, y, threshold_params={'method': 'accuracy'})
-
验证集手动校准:
mmc.calibrate_threshold(val_pairs, val_y)
-
直接设置阈值:
mmc.set_threshold(0.5)
核心算法解析
ITML (信息理论度量学习)
ITML通过最小化两个多元高斯分布之间的KL散度来学习距离度量,其优化目标是:
min_M tr(MM₀⁻¹) - log det(MM₀⁻¹) - n
约束条件:
- 相似点对的距离 ≤ u
- 不相似点对的距离 ≥ l
其中M₀是先验距离度量(默认为单位矩阵)。
特点:
- 不依赖特征值计算或半正定规划
- 可以处理多种约束类型
- 可以融入先验距离知识
示例代码:
from metric_learn import ITML
itml = ITML()
itml.fit(pairs, y)
SDML (稀疏高维度量学习)
SDML通过双重正则化在高维空间中学习稀疏度量:
- 对马氏矩阵非对角元素的L1惩罚
- M与M₀之间的对数行列式散度
优化目标: min_M tr((M₀ + ηXLXᵀ)M) - log det M + λ||M||₁,off
特点:
- 适合高维数据
- 产生稀疏解
- 计算效率高
示例代码:
from metric_learn import SDML
sdml = SDML()
sdml.fit(pairs, y)
实际应用建议
-
数据准备:
- 对于大规模数据,优先使用索引+预处理器方式
- 确保元组标签的语义清晰(相似/不相似)
-
模型选择:
- 高维数据考虑SDML
- 需要融入先验知识考虑ITML
-
评估验证:
- 使用交叉验证评估模型性能
- 注意阈值校准对预测结果的影响
-
与scikit-learn集成:
- 所有算法兼容scikit-learn的模型选择工具
- 可以用于Pipeline和GridSearchCV
from sklearn.model_selection import cross_val_score
scores = cross_val_score(mmc, pairs_indices, y, cv=5)
弱监督度量学习为缺乏精确标签的场景提供了有效的解决方案,all-umass/metric-learn项目实现了多种先进算法,开发者可以根据具体需求选择合适的算法和配置。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考