【机器学习】【深入浅出】混淆矩阵全解析:搞懂 TP、FP、TN、FN 与分类模型评估

1. 混淆矩阵是什么

混淆矩阵(Confusion Matrix)是用于二分类或多分类问题评估模型性能的常见工具。它通过一个矩阵来展示模型在预测时各类别之间的“混淆”情况——即真实标签和预测标签的对应关系。

二分类中的混淆矩阵示例

对于二分类问题(如“是否患病”、“好瓜/坏瓜”等),混淆矩阵通常是一个 的表格:

 

预测:负类 (Negative)

预测:正类 (Positive)

真实:负类 (Negative)

True Negative (TN)

False Positive (FP)

真实:正类 (Positive)

False Negative (FN)

True Positive (TP)

True Negative (TN):真实为负类,模型也预测为负类

False Positive (FP):真实为负类,但模型错误地预测为正类(“假正例”)

False Negative (FN):真实为正类,但模型错误地预测为负类(“假负例”)

True Positive (TP):真实为正类,模型预测也为正类

 

如果是多分类,混淆矩阵就会变成 的表格( 为类别数),行表示“真实类别”,列表示“预测类别”,对角线上的数字表示被正确预测的样本数,而非对角线上的数字表示不同类别间的混淆。


2. 怎么计算混淆矩阵

1. 预测结果与真实标签对比

• 在测试集或验证集上,模型会给出每个样本的预测标签;

• 同时,我们也知道该样本的真实标签。

2. 分类对比统计

• 对于每个样本,查看其“真实标签”和“预测标签”组合;

• 在二分类情况下,如果真实为正、预测也为正,则对“TP”计数加 1;真实为负、预测为负,则对“TN”计数加 1,依此类推;

• 对所有测试样本进行统计后,就能得到混淆矩阵四个格子的数值 (TN, FP, FN, TP)。


例子演示

假设我们有一个二分类模型,测试集中有 10 个样本。真实标签和预测标签如下表所示:

样本

真实标签 (Actual)

预测标签 (Predicted)

1

正 (1)

正 (1)

2

正 (1)

负 (0)

3

负 (0)

负 (0)

4

正 (1)

正 (1)

5

负 (0)

负 (0)

6

正 (1)

负 (0)

7

负 (0)

正 (1)

8

正 (1)

正 (1)

9

负 (0)

负 (0)

10

负 (0)

负 (0)

让我们统计一下:

TP(真正例):真实为正、预测也为正。表中样本 #1, #4, #8 共 3 个。

TN(真负例):真实为负、预测也为负。表中样本 #3, #5, #9, #10 共 4 个。

FP(假正例):真实为负、预测却为正。表中样本 #7 共 1 个。

FN(假负例):真实为正、预测却为负。表中样本 #2, #6 共 2 个。

所以混淆矩阵如下:

 

预测:负 (0)

预测:正 (1)

真实:负(0)

TN = 4

FP = 1

真实:正(1)

FN = 2

TP = 3

3. 混淆矩阵有什么作用

1. 精细化分析模型错误类型

• 混淆矩阵不仅能告诉我们模型总体上对多少样本预测正确,也能告诉我们“模型容易把正例预测成负例”还是“容易把负例预测成正例”。在很多实际任务中,这个区别非常重要(如医疗诊断中,FP 与 FN 的代价不一样)。

2. 衍生多种评价指标

• 通过混淆矩阵的 TP, TN, FP, FN,能计算准确率 (Accuracy)精确率 (Precision)召回率 (Recall)F1-score 等多种指标。

• 这些指标能更全面地评估模型在不同角度(如区分不同错误类型)上的表现。

• 在多分类情形下,通过对混淆矩阵对角线和非对角线元素进行统计,也能衍生类似指标或宏平均、加权平均等方法。



3. 帮助进行模型改进或决策

• 如果发现 FP 特别多,说明模型经常把负例误判为正例,可能需要进一步降低“假警报”;

• 如果 FN 特别多,说明模型经常把正例漏掉,可能需要提高召回率;

• 不同场景下,FP 和 FN 的代价可能不一样,混淆矩阵能帮助你更好地做出折中决定。


4. 案例讲解

1. 以医疗诊断为例:

目标:判断某病患是否患某种疾病(“阳性/患病” 记作 1,“阴性/健康” 记作 0)。

模型预测:如果把病人真实情况是患病,但模型预测成“健康”,这是 FN(假阴性),对医疗场景来说代价很大(漏诊)。

混淆矩阵

1. TP(真正例):患病且模型预测患病。

2. TN(真负例):健康且模型预测健康。

3. FP(假正例):健康但模型错误地预测为患病(假警报)。

4. FN(假阴性):患病却模型预测健康(漏诊)。


举个数字示例:

• 测试集共有 100 个病人,实际患病 20 人,健康 80 人;

• 模型预测后统计得到:TP=15, TN=75, FP=5, FN=5。

• 混淆矩阵可写为:

 

预测:健康(0)

预测:患病(1)

真实:健康(0)

TN=75

FP=5

真实:患病(1)

FN=5

TP=15

从中可以看出,FN=5,说明有 5 个实际患病的人被漏诊了;FP=5 表示有 5 个健康的人被误诊为患病。这对于医疗决策很重要:如果 FN 代价大,可能需要提高模型的敏感度(召回率),即尽量减少漏诊。


5. 乳腺癌案例代码

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from sklearn.datasets import load_breast_cancer  # 导入乳腺癌数据集
from sklearn.model_selection import train_test_split, GridSearchCV  # 数据集划分与交叉验证调参工具
from sklearn.preprocessing import StandardScaler  # 特征标准化工具
from sklearn.linear_model import LogisticRegression  # 逻辑回归模型
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report  # 模型评估指标
from joblib import dump, load  # 模型保存和加载工具

# mac电脑设置全局字体为PingFang HK,确保中文显示正常
plt.rcParams['font.family'] = 'PingFang HK'
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

def binary_classification_demo():
    """
    使用乳腺癌数据集,通过逻辑回归实现二分类预测(良性 vs. 恶性),
    并利用 GridSearchCV 调参,最后输出预测结果和评估指标,同时保存模型。
    """
    # 1. 加载乳腺癌数据集
    data = load_breast_cancer()
    X = data.data         # 特征矩阵
    y = data.target       # 目标:0=恶性, 1=良性

    # 打印英文版特征名称和目标类别
    print("乳腺癌数据集英文特征名称:", data.feature_names)
    print("乳腺癌数据集英文目标类别:", data.target_names)
    
    # 定义中文版本的特征名称和目标类别
    chinese_feature_names = [
        '平均半径', '平均纹理', '平均周长', '平均面积', '平均平滑度',
        '平均紧凑度', '平均凹度', '平均凹点', '平均对称性', '平均分形维度',
        '半径误差', '纹理误差', '周长误差', '面积误差', '平滑度误差',
        '紧凑度误差', '凹度误差', '凹点误差', '对称性误差', '分形维度误差',
        '最差半径', '最差纹理', '最差周长', '最差面积', '最差平滑度',
        '最差紧凑度', '最差凹度', '最差凹点', '最差对称性', '最差分形维度'
    ]
    chinese_target_names = ['恶性', '良性']
    
    # 打印中文版特征名称和目标类别
    print("乳腺癌数据集中文特征名称:", chinese_feature_names)
    print("乳腺癌数据集中文目标类别:", chinese_target_names)
    # 2. 打印原始数据前5条(带中文标题)
    # 将特征矩阵转换为DataFrame,并为列名添加中文说明
    # 定义完整的中文版本特征名称映射字典
    feature_mapping = {
        'mean radius': '平均半径',
        'mean texture': '平均纹理',
        'mean perimeter': '平均周长',
        'mean area': '平均面积',
        'mean smoothness': '平均平滑度',
        'mean compactness': '平均紧凑度',
        'mean concavity': '平均凹度',
        'mean concave points': '平均凹点',
        'mean symmetry': '平均对称性',
        'mean fractal dimension': '平均分形维度',
        'radius error': '半径误差',
        'texture error': '纹理误差',
        'perimeter error': '周长误差',
        'area error': '面积误差',
        'smoothness error': '平滑度误差',
        'compactness error': '紧凑度误差',
        'concavity error': '凹度误差',
        'concave points error': '凹点误差',
        'symmetry error': '对称性误差',
        'fractal dimension error': '分形维度误差',
        'worst radius': '最差半径',
        'worst texture': '最差纹理',
        'worst perimeter': '最差周长',
        'worst area': '最差面积',
        'worst smoothness': '最差平滑度',
        'worst compactness': '最差紧凑度',
        'worst concavity': '最差凹度',
        'worst concave points': '最差凹点',
        'worst symmetry': '最差对称性',
        'worst fractal dimension': '最差分形维度'
    }
    
    # 2.1 将特征矩阵转换为 DataFrame,并将英文列名替换为中文
    df_features = pd.DataFrame(X, columns=data.feature_names)
    df_features.rename(columns=feature_mapping, inplace=True)
    # 将目标值转换为 DataFrame,添加中文标题
    df_target = pd.DataFrame(y, columns=['肿瘤类型 (0=恶性, 1=良性)'])
    
    print("\n【原始数据预览】前5个样本的特征:")
    print(df_features.head(5))
    print("\n【原始数据预览】前5个样本的目标值:")
    print(df_target.head(5))
    
    # 3. 划分训练集和测试集(80%训练,20%测试)
    # 使用 train_test_split 函数将数据分为训练集和测试集,其中 test_size=0.2 表示 20% 的数据作为测试集,
    # random_state=42 用于固定随机种子,使得数据划分结果每次运行都相同,便于结果复现。
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    
    # 4. 标准化处理:对训练集和测试集的特征进行标准化
    # 实际上,不同特征的取值范围可能存在较大差异,使用 StandardScaler 将数据转换为均值为0、标准差为1的分布,
    # 这有助于提高模型训练时的稳定性和收敛速度。
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)  # 使用训练集数据计算均值和标准差,并对训练集进行转换
    X_test_scaled = scaler.transform(X_test)        # 使用训练集计算得到的均值和标准差对测试集进行转换
    
    # 5. 构建逻辑回归模型,并使用 GridSearchCV 调参
    # 创建逻辑回归模型实例 log_reg,设置 max_iter=10000 确保模型在收敛前有足够的迭代次数,
    # random_state=42 用于固定随机种子以保证结果可重复。逻辑回归默认使用 L2 正则化。
    log_reg = LogisticRegression(max_iter=10000, random_state=42)
    
    # 定义一个参数网格 param_grid,用于调节正则化参数 C。
    # 在逻辑回归中,C 是正则化参数的倒数,值越大表示正则化越弱;这里尝试的取值有 0.01, 0.1, 1, 10。
    param_grid = {
        'C': [0.01, 0.1, 1, 10]  
    }
    
    # 使用 GridSearchCV 对逻辑回归模型进行超参数调优。
    # 参数说明:
    #   - log_reg: 要调参的模型实例。
    #   - param_grid: 参数搜索范围。
    #   - cv=3: 使用3折交叉验证,确保模型在不同数据划分下都能稳定表现。
    #   - scoring='accuracy': 以准确率作为评价指标,准确率越高越好。
    #   - n_jobs=-1: 使用所有可用的 CPU 核心加速运算。
    grid = GridSearchCV(log_reg, param_grid, cv=3, scoring='accuracy', n_jobs=-1)
    
    # 在标准化后的训练数据上训练模型,并自动搜索最佳参数组合
    grid.fit(X_train_scaled, y_train)
    
    # 得到经过调参后的最佳逻辑回归模型
    best_log_reg = grid.best_estimator_
    
    # 输出最佳参数组合
    print("\n逻辑回归最佳参数:", grid.best_params_)
    
    # 6. 在测试集上预测,并计算准确率及评估指标
    
    # y_pred = best_log_reg.predict(X_test_scaled)
    # 使用最佳逻辑回归模型 best_log_reg 对经过标准化的测试集数据 X_test_scaled 进行预测,得到预测标签 y_pred。
    
    y_pred = best_log_reg.predict(X_test_scaled)
    
    # acc = accuracy_score(y_test, y_pred)
    # 计算准确率(accuracy),即预测正确的样本数占测试集中样本总数的比例。
    # 对于二分类问题,accuracy_score 能直观反映模型对正负类整体预测的准确程度。
    acc = accuracy_score(y_test, y_pred)
    
    # print("\n逻辑回归测试集准确率:{:.4f}".format(acc))
    # 打印测试集准确率,四舍五入到小数点后 4 位。
    print("\n逻辑回归测试集准确率:{:.4f}".format(acc))
    
    # print("\n混淆矩阵:\n", confusion_matrix(y_test, y_pred))
    # 混淆矩阵能够展示模型在各类别上的预测正确与错误情况:
    #   - 行表示真实标签,列表示预测标签
    #   - 对于二分类任务:矩阵形状为 2x2:
    #       [ [TN, FP],
    #         [FN, TP] ]
    #     其中:
    #       TN:真负(真实为负,预测也为负)
    #       FP:假正(真实为负,预测为正)
    #       FN:假负(真实为正,预测为负)
    #       TP:真正(真实为正,预测为正)
    # 通过混淆矩阵可以更细致地分析模型错误类型。
    print("\n混淆矩阵:\n", confusion_matrix(y_test, y_pred))
    
    # print("\n分类报告:\n", classification_report(y_test, y_pred))
    # 分类报告(classification_report)会输出模型在每个类别上的精准率(precision)、召回率(recall)、
    # F1-score 以及支持(support)信息:
    #   - precision:预测为该类的样本中,有多少是真实属于该类
    #   - recall:真实属于该类的样本中,有多少被正确预测为该类
    #   - f1-score:precision 与 recall 的调和平均
    #   - support:该类别在真实标签中出现的样本数
    # 另外还包括整体上的 accuracy、macro avg、weighted avg 等指标,用于综合衡量模型表现。
    print("\n分类报告:\n", classification_report(y_test, y_pred))
    
    # 7. 可视化混淆矩阵
    
    # plt.figure(figsize=(6, 4))
    # 设置画布大小为 6 (宽) x 4 (高) 英寸
    plt.figure(figsize=(6, 4))
    
    # cm = confusion_matrix(y_test, y_pred)
    # 再次计算混淆矩阵,得到一个 2x2 数组,方便后续可视化。
    cm = confusion_matrix(y_test, y_pred)
    
    # plt.matshow(cm, cmap=plt.cm.Blues)
    # 使用 matplotlib 的 matshow 方法可视化混淆矩阵,指定颜色映射 cmap=plt.cm.Blues
    # 这样矩阵中的数值越大,颜色越深。
    plt.matshow(cm, cmap=plt.cm.Blues)
    
    # plt.title("混淆矩阵", pad=20)
    # 设置图像标题为 "混淆矩阵",并在标题下方留 20 像素的空间。
    plt.title("混淆矩阵", pad=20)
    
    # plt.colorbar()
    # 显示颜色条,用于表明颜色深浅对应的数值范围。
    plt.colorbar()
    
    # plt.xlabel("预测标签")
    # plt.ylabel("真实标签")
    # 设置 x 轴和 y 轴标签,方便阅读:
    #   - x 轴为预测标签(列)
    #   - y 轴为真实标签(行)
    plt.xlabel("预测标签")
    plt.ylabel("真实标签")
    
    # plt.show()
    # 显示绘制好的混淆矩阵图
    plt.show()
    
    # 8. 保存模型和标准化器
    
    # dump(best_log_reg, "logistic_regression_model.joblib")
    # 使用 joblib.dump 将训练得到的最佳逻辑回归模型保存到本地文件 logistic_regression_model.joblib,
    # 以便后续无需重新训练即可加载并进行预测。
    dump(best_log_reg, "logistic_regression_model.joblib")
    
    # dump(scaler, "scaler_breast.joblib")
    # 同理,将标准化器 scaler 保存为 scaler_breast.joblib,
    # 这样在实际部署或后续使用时可保证测试数据与训练数据使用相同的标准化方式。
    dump(scaler, "scaler_breast.joblib")
    
    # print("\n逻辑回归模型已保存。")
    # 提示用户模型已成功保存。
    print("\n逻辑回归模型已保存。")

if __name__ == "__main__":
    binary_classification_demo()

输出:

第一部分输出:

第二部分输出:

第三部分输出:


6. 关键点总结

1. 混淆矩阵 在分类问题中是一个极为重要的工具,可以直观展示真实标签与预测标签之间的匹配与差异。

2. 二分类场景:常用 TN, FP, FN, TP 四个要素来描述模型的预测结果;多分类场景:矩阵规模变成 ,同样展示类别间的混淆情况。

3. 计算方式:对于测试集中每个样本,比较其“真实标签”和“预测标签”组合,在相应的混淆矩阵格子里加 1。

4. 作用

• 帮助发现模型倾向于哪种错误(FP or FN);

• 可以派生 Accuracy、Precision、Recall、F1 等评价指标;

• 帮助在实际业务中做“错误代价”评估,例如医疗中 FN 和 FP 的不同后果。

5. 案例:以医疗诊断为例,混淆矩阵能直观体现“漏诊”和“误诊”的数量,指导模型阈值或改进方向。

结论:混淆矩阵是分类模型评估的核心工具之一,通过它你可以更深入地理解模型在不同类别之间的预测分布,从而在实际业务场景中更好地做决策和模型改进。


7.常见问题:

问题1:既然容易出现漏诊和误诊,那是不是生活中都是使用的 f1 score 来评分就好了呢?

在实际应用中,并不一定都是使用 F1 指标。选择哪种评估指标主要取决于以下因素:

1. 业务场景和错误代价

• 如果你在医疗诊断场景中,漏诊(FN)带来的代价极高,则会更加关注 召回率(Recall)

• 如果在垃圾邮件过滤等场景中,错把正常邮件当垃圾邮件(FP)带来的影响可能比漏掉垃圾邮件(FN)小或大,需要根据实际情况确定是更关心精确率(Precision)还是召回率(Recall)。

2. 类别不平衡

• 如果数据中正例远少于负例(如罕见疾病检测),准确率(Accuracy)往往会失去意义,因为即便模型把所有样本预测为负例也能得到很高的 Accuracy;

• 此时更关注的是 Precision、Recall、F1-score 或者 AUC(ROC 曲线下面积)等能反映不平衡数据下模型表现的指标。

3. 综合度量需求

F1-scorePrecisionRecall 的调和平均,能够在二者之间取得折中,如果你同时关心模型不要过多漏诊(召回率)也不要过多误报(精确率),F1-score 会是一个平衡点;

• 但如果在你的场景中,漏诊(FN)或误报(FP)的代价有明显差别,则可能需要关注特定的指标(如更关注召回率、或更关注精确率)。

4. 可解释性与业务需求

• 有时业务方只需要一个简单易懂的指标(如 Accuracy);

• 有时为了更精准地衡量少数类,你可能需要查看 RecallPrecision

• 也可能需要看 ROC 曲线PR 曲线 等图形化指标来做决策。

小结

F1-score 在很多情况下确实被广泛使用,因为它在精确率和召回率之间做了折中,适合大多数场景的初步评估;

• 但在真正的业务决策中,往往会结合多种指标(如 Precision、Recall、Accuracy、AUC 等)以及实际错误代价来选择或制定最合适的度量方式。

• 因此,并不是所有场景都用 F1-score,也没有“唯一正确”的指标,指标的选取应当与具体的业务需求、错误代价、数据分布等因素相匹配。


问题2: 这里的TFPN 这几个是什么 缩写呢,是对什么的修饰呢,怎么理解呢,然后应该怎么方便记住呢,特别是生活中的 假阴性,假阳性又怎么理解呢?

TFPN 通常指的是 True Positive (TP)、False Positive (FP)、True Negative (TN)、False Negative (FN) 这四个指标,常见于二分类任务的混淆矩阵。下面详细解释它们的含义、如何理解和记忆,尤其是假阳性(FP)和假阴性(FN)

1. TP、FP、TN、FN 的含义

1.1 TP(True Positive,真正例)

“预测为正类”,且真实也是正类

• 例子:在医疗诊断中,“正”可以代表“患病”,TP 就是模型预测患病,且病人确实患病


1.2 FP(False Positive,假阳性)

“预测为正类”,但真实是负类

• 例子:在医疗诊断中,模型认为“患病”,但病人实际上并没有患病;相当于“虚惊一场”。

• 又称**“Type I Error”(第一类错误)或“False Alarm”**(假警报)。


1.3 TN(True Negative,真负例)

“预测为负类”,且真实也是负类

• 例子:模型预测“健康”,而病人确实健康。


1.4 FN(False Negative,假阴性)

“预测为负类”,但真实是正类

• 例子:在医疗诊断中,模型预测“健康”,但病人实际上患病;相当于漏诊

• 又称**“Type II Error”(第二类错误)或“Missed Detection”**(漏报)。


2. 为什么叫“正 (Positive)” 和 “负 (Negative)”

• “Positive (正)” 通常指模型预测样本属于我们关心的那一类(如“患病” / “有故障” / “垃圾邮件”等)。

• “Negative (负)” 则是模型预测样本不属于那个目标类(如“健康” / “正常” / “正常邮件”)。

• 在实际业务中,哪一类是“正”,哪一类是“负”,通常由数据编码或业务需求决定;一般将“少数且更重要的类”编码为 1(正),将另一个类编码为 0(负)。

3. 记忆技巧

1. True/False 表示模型预测与真实标签是否一致:

True:预测与真实相同

False:预测与真实不同

2. Positive/Negative 表示模型预测为正类还是负类:

Positive:模型预测为“正”

Negative:模型预测为“负”


于是:

TP = 预测为正,且预测正确 => 真实也为正

FP = 预测为正,但预测错误 => 真实是负 => “假阳性”

TN = 预测为负,且预测正确 => 真实也为负

FN = 预测为负,但预测错误 => 真实是正 => “假阴性”


如何快速记住?

• “假阳性(FP)” 可以想成“模型报了警,但实际上没有问题”——即虚惊一场

• “假阴性(FN)” 可以想成“模型说没事,但实际上出事了”——即漏诊

4. 应用案例示例

医疗诊断为例:判断一个病人是否患有某种疾病。

正类 (Positive):患病

负类 (Negative):健康


假设测试集中有 10 个病人,真实患病 4 人、健康 6 人。模型预测后统计到:

TP = 3:真实患病,预测也患病

FP = 2:真实健康,但模型预测患病 => “假阳性”

TN = 4:真实健康,模型预测也健康

FN = 1:真实患病,但模型预测健康 => “假阴性”


因此得到的混淆矩阵为:

 

预测:健康(0)

预测:患病(1)

真实:健康(0)

TN=4

FP=2

真实:患病(1)

FN=1

TP=3

• FP=2 表示有 2 位病人其实没病,但被模型误判为“患病”,算是“虚惊一场”。

• FN=1 表示有 1 位病人其实患病,却被模型漏诊(模型预测“健康”),可能带来严重后果。

5. 常见衍生指标

1. Accuracy(准确率)

• 衡量模型整体预测正确的比例。

2. Precision(精确率)

• 在所有预测为正的样本中,真正为正的比例;反映“假阳性”多少。

3. Recall(召回率)

• 在所有真实为正的样本中,成功预测为正的比例;反映“假阴性”多少。

4. F1-score:精确率和召回率的调和平均,用于平衡二者的影响。

6. 小结

1. TFPN 是指 True Positive、False Positive、True Negative、False Negative 四个概念,常用于二分类混淆矩阵的四个格子:

TP(真正例):真实为正、预测也为正

FP(假阳性):真实为负、预测却为正

TN(真负例):真实为负、预测也为负

FN(假阴性):真实为正、预测却为负

2. 记忆方法

• “Positive”/“Negative”表示模型的预测标签;

• “True”/“False”表示预测是否与真实相符;

• “FP” = “假阳性” => 模型报了警(预测正),却是虚惊;

• “FN” = “假阴性” => 模型没报警(预测负),却真出事了。

3. 作用

• 混淆矩阵能更细致地分析模型错误类型;

• 由 TP, FP, TN, FN 可以计算 Accuracy、Precision、Recall、F1 等多种指标,帮助评估模型的不同方面表现。


通过理解 TFPN 及混淆矩阵,你可以在分类任务中更准确地判断模型是在哪些方面表现良好,在哪些方面需要改进,并针对具体业务场景(如医疗诊断、垃圾邮件过滤等)制定更合适的策略或阈值。


喜欢本文的朋友请点赞、收藏、转发,并关注我的博客,让更多人加入我们的数据科学狂欢吧!


希望这篇文章和代码讲解对你有所帮助!如果有任何疑问或建议,欢迎在评论区留言讨论,共同进步!

 

### 目标检测中计算TPFPFN的代码实现 对于目标检测任务而言,理解如何计算真阳性(True Positive, TP)、假阳性(False Positive, FP)以及假阴性(False Negative, FN)至关重要。这不仅有助于评估模型性能,还能帮助定位具体问题所在。 #### 计算逻辑说明 为了判断预测框是否为TP/FP/FN,通常会设定一个交并比阈值(Intersection over Union, IoU)。当预测框真实框之间的IoU超过该阈值,则认为是一个有效的匹配;反之则不是。基于此原则: - **TP**:预测框成功匹配到对应的真实框; - **FP**:未能找到匹配的真实框或被错误分类的目标; - **FN**:实际存在但未被正确检测出来的对象[^1]。 #### Python代码示例 下面给出一段Python代码用于计算上述三个量,在这里假设已经获得了所有测试图片上的ground truth bounding boxes列表`gt_boxes`和对应的detection results `pred_boxes`: ```python def calculate_tp_fp_fn(gt_boxes, pred_boxes, iou_threshold=0.5): """ Calculate the number of true positives (TP), false positives (FP), and false negatives (FN) based on given ground-truth and prediction data. Parameters: gt_boxes : list[List[float]] List containing lists where each sublist represents a single GT box with format [xmin,ymin,xmax,ymax]. pred_boxes : list[tuple] List containing tuples where each tuple contains two elements, first element is confidence score(float), second one is predicted bbox(list). iou_threshold : float Threshold value used to determine whether an object has been correctly detected. Returns: dict: Dictionary holding counts for 'tp', 'fp' and 'fn'. """ from itertools import chain # Sort predictions by their scores in descending order sorted_pred_boxes = sorted(pred_boxes, key=lambda x:x[0], reverse=True) matched_gt_indices = set() tp_count = fp_count = fn_count = 0 for conf_score, pred_bbox in sorted_pred_boxes: best_iou = -float('inf') match_index = None for idx, gt_bbox in enumerate(gt_boxes): current_iou = compute_iou(pred_bbox, gt_bbox) if current_iou >= iou_threshold and current_iou > best_iou \ and idx not in matched_gt_indices: best_iou = current_iou match_index = idx if best_iou != -float('inf'): tp_count += 1 matched_gt_indices.add(match_index) else: fp_count += 1 total_gts = len(set(chain(*[[i]*len(b) for i,b in enumerate(gt_boxes)]))) unmatched_gts = total_gts - len(matched_gt_indices) fn_count = unmatched_gts return {"tp": tp_count, "fp": fp_count, "fn": fn_count} def compute_iou(boxA, boxB): """Compute intersection-over-union between two bounding boxes.""" # Determine coordinates of intersecting rectangle xA = max(boxA[0], boxB[0]) yA = max(boxA[1], boxB[1]) xB = min(boxA[2], boxB[2]) yB = min(boxA[3], boxB[3]) interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) area_A = (boxA[2]-boxA[0]+1)*(boxA[3]-boxA[1]+1) area_B = (boxB[2]-boxB[0]+1)*(boxB[3]-boxB[1]+1) union_area = float(area_A + area_B - interArea) return interArea / union_area if union_area!=0 else 0 if __name__ == "__main__": # Example usage gt_bboxes = [[100, 78, 394, 282]] # Format:[xmin, ymin, xmax, ymax] det_results = [(0.9, [98, 76, 396, 284]), (.4, [200, 200, 400, 400])] result = calculate_tp_fp_fn(gt_bboxes, det_results) print(f"Results:\nTP={result['tp']}\nFP={result['fp']}\nFN={result['fn']}") ``` 这段程序定义了一个函数`calculate_tp_fp_fn()`来接收真实的边界框(`gt_boxes`)和预测的结果(`pred_boxes`)作为输入参数,并返回字典形式的结果,其中包含了TPFPFN 的数量。此外还提供了一个辅助方法`compute_iou()`用来计算两个矩形区域间的IOU值[^2]。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值