目录

  • 前言
  • 0、混淆矩阵的定义
  • 一、原理详解
  • 1-1、多分类样例
  • 1-2、二分类样例
  • 1-3、不同的分类性能指标介绍
  • 二、混淆矩阵的相关API介绍以及样例
  • 2-1、混淆矩阵介绍
  • 2-2、混淆矩阵样例
  • 总结



前言

混淆矩阵用来评估分类的准确性。



0、混淆矩阵的定义

混淆矩阵(Confusion Matrix)是在机器学习中,用于对分类模型的性能进行评估的一种方法。混淆矩阵展示了模型在分类任务中的预测结果与实际标签之间的对应关系。

混淆矩阵通常是一个二维矩阵,其中每一行代表着实际标签的类别,每一列代表着预测结果的类别。在二分类问题中,混淆矩阵包括四个元素,它们分别是:

  • 真正例(True Positive, TP):表示模型将正样本正确地预测为正样本的数量。
  • 假负例(False Negative, FN):表示模型将正样本错误地预测为负样本的数量。
  • 假正例(False Positive, FP):表示模型将负样本错误地预测为正样本的数量。
  • 真负例(True Negative, TN):表示模型将负样本正确地预测为负样本的数量。

在多分类问题中,混淆矩阵的维度将会更高,它的每一个元素表示实际标签为某一类别,而预测结果为另一类别的数量。

通过观察混淆矩阵,我们可以计算出多种模型性能指标,如准确率(Accuracy)、精确率(Precision)、召回率(Recall)和 F1 值(F1-score)等,这些指标可以帮助我们更好地评估模型的分类效果。

一、原理详解

混淆矩阵:混淆矩阵是将真实值与预测值匹配以及不匹配的项一起放入到矩阵中,它可以清楚的反映出真实值和预测值相同的地方,也可以反映出与预测值不相同的地方,

1-1、多分类样例

混淆举证的Python代码 混淆矩阵 pytorch_混淆矩阵


如图所示:图示为一个情感多分类例子的混淆矩阵,从图中我们可以看出,真实样例为生气,并且被预测为生气的例子一共有98例,同一列的其他行表示真实样例为生气,但是被预测为其他情感的例子数量。混淆矩阵的正对角线表示的是真实值与预测值相互匹配的样例数。

1-2、二分类样例

假设我们有一个二元分类器,可以将电子邮件归类为垃圾邮件(positive)或非垃圾邮件(negative)。我们可以将分类器的预测结果与实际类别进行比较,并使用混淆矩阵来衡量模型的性能。

混淆举证的Python代码 混淆矩阵 pytorch_混淆举证的Python代码_02

  • 在上表中,真实标签为positive和negative的邮件分别为130和80封。分类器的预测结果包括100封垃圾邮件被正确分类为positive,20封非垃圾邮件错误地被分类为positive,30封垃圾邮件错误地被分类为negative,以及50封非垃圾邮件被正确分类为negative。
  • 基于混淆矩阵的这些信息,我们可以计算出不同的分类性能指标,例如准确率、精确率、召回率和F1分数等,这些指标有助于我们评估模型的性能和优化模型的参数。

1-3、不同的分类性能指标介绍

当我们使用混淆矩阵来评估分类器性能时,可以根据混淆矩阵中的真阳性(True Positive)、假阳性(False Positive)、真阴性(True Negative)和假阴性(False Negative)的数量计算出不同的分类性能指标。下面是一些常用的指标:

  • 准确率(Accuracy):准确率表示模型正确分类的样本数占总样本数的比例,即:

混淆举证的Python代码 混淆矩阵 pytorch_python_03

准确率对于不同类别之间的样本数量不平衡的数据集可能会有误导性。在这种情况下,可以使用其他指标来更好地评估分类器性能。
  • 精确率(Precision):精确率表示被分类器正确分类为正例的样本数占分类器预测为正例的样本总数的比例,即:

混淆举证的Python代码 混淆矩阵 pytorch_python_04

精确率的计算方式强调了分类器正确识别正例的能力。
  • 召回率(Recall):召回率表示被分类器正确分类为正例的样本数占真实正例的样本总数的比例,即:

混淆举证的Python代码 混淆矩阵 pytorch_python_05

召回率的计算方式强调了分类器正确识别所有真实正例的能力。
  • F1分数(F1 Score):F1分数是精确率和召回率的调和平均数,它可以用以下公式计算:

混淆举证的Python代码 混淆矩阵 pytorch_混淆矩阵_06

F1分数综合考虑了精确率和召回率两个指标,适用于不平衡的数据集。如果一个分类器在F1分数上得分很高,那么它既能够保持低误判率,又能够识别出大部分真正的正例。

二、混淆矩阵的相关API介绍以及样例

2-1、混淆矩阵介绍

在Python中,可以使用scikit-learn库来计算混淆矩阵。该库提供了许多用于分类问题的函数和类,包括混淆矩阵相关的API。下面介绍一些常用的API及其用法。

confusion_matrix(y_true, y_pred, labels=None, sample_weight=None, normalize=None)
该函数可以计算给定真实标签和预测标签的混淆矩阵。参数说明如下

  • y_true: 真实标签数组
  • y_pred: 预测标签数组
  • labels: 所有标签的列表(可选参数,默认为None)
  • sample_weight: 样本权重的数组(可选参数,默认为None)
  • normalize: 是否将混淆矩阵中的值归一化(可选参数,默认为None)

2-2、混淆矩阵样例

该函数返回一个二维数组,表示混淆矩阵。例如,以下代码演示了如何使用confusion_matrix函数:

# 导入相关API
from sklearn.metrics import confusion_matrix
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
# Parameters: 参数介绍
# y_true: 真实值
# y_pred:预测值
# labels: 标签列表,默认为空,按照标签列表来重新排列混淆矩阵。
confusion_matrix(y_true, y_pred, labels=None)
array([[2, 0, 0],
       [0, 0, 1],
       [1, 0, 2]])

参考文章:
sklearn官网——混淆矩阵.


总结

嗯,今天开启贤者模式。