Machine Learning Algorithms (9): Classification with Linear Discriminant Analysis
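2. LDA can reduce the data to at most k-1 dimensions, where k is the number of classes. If the target dimensionality is greater than k-1, LDA cannot be used. Some extended variants of LDA work around this limitation.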
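The k-1 limit can be checked directly. The following is a minimal sketch, assuming scikit-learn is installed; the synthetic 3-class dataset and the names `X_demo`, `y_demo`, and `Z_demo` are made up for illustration. With k = 3 classes the projection has at most 2 columns, even though the data has 5 features.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Synthetic data (an assumption for illustration): 5 features, k = 3 classes
X_demo, y_demo = make_classification(n_samples=300, n_features=5, n_informative=3,
                                     n_redundant=0, n_classes=3,
                                     n_clusters_per_class=1, random_state=0)

k = len(np.unique(y_demo))
Z_demo = LinearDiscriminantAnalysis().fit_transform(X_demo, y_demo)

# The projected data has at most k - 1 columns
print("classes:", k, "projected shape:", Z_demo.shape)
```

In recent scikit-learn releases, asking for more than min(n_features, k-1) dimensions through `n_components` is rejected with an error.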
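LDA is widely used in pattern recognition (for example face recognition, ship recognition, and other image recognition tasks), so it is worth understanding how the algorithm works. Before studying it, we need to distinguish it from the LDA of natural language processing. In NLP, LDA stands for Latent Dirichlet Allocation, a topic model for documents. This article discusses Linear Discriminant Analysis, so every later mention of LDA refers to Linear Discriminant Analysis.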
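- Understand the basic principles of the LDA algorithm
- Learn to apply LDA in hands-on code
Part 1 Demo practice
- Step1: Import libraries
- Step2: Train the model
- Step3: Inspect the model parameters
- Step4: Visualize the data and the model
- Step5: Make predictions
Part 2 Handwritten digit classification with LDA
- Step1: Import libraries
- Step2: Load the data
- Step3: Inspect and visualize the data
- Step4: Train and predict on handwritten digits with LDA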
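4.1 Demo practice
- Step1: Import libraries
# Import the basic array library
import numpy as np
# Import the plotting library
import matplotlib.pyplot as plt
# Import the 3D plotting toolkit
from mpl_toolkits.mplot3d import Axes3D
# Import the LDA model
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# Import the demo data generator (sklearn.datasets.samples_generator has been removed from recent scikit-learn releases)
from sklearn.datasets import make_classification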
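- Step2: Train the model
# Generate 1000 samples in four classes (about 250 samples per class)
X, y = make_classification(n_samples=1000, n_features=3, n_redundant=0,
                           n_classes=4, n_informative=2, n_clusters_per_class=1,
                           class_sep=3, random_state=10)
# Show the four classes in 3D
fig = plt.figure()
# Create the 3D axes through the figure so that they are attached to it on current matplotlib releases
ax = fig.add_subplot(projection='3d')
ax.view_init(elev=20, azim=20)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', c=y)
# Build the LDA model
lda = LinearDiscriminantAnalysis()
# Train the model
lda.fit(X, y)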
LinearDiscriminantAnalysis(n_components=None, priors=None, shrinkage=None,
                           solver='svd', store_covariance=False, tol=0.0001)
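As a quick sanity check on the generated data and the fitted model, the sketch below counts the samples in each class and reports the accuracy on the training data. It reuses the generator settings from the demo; the names `class_counts` and `train_acc` are illustrative only.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Same generator settings as the demo above
X, y = make_classification(n_samples=1000, n_features=3, n_redundant=0,
                           n_classes=4, n_informative=2, n_clusters_per_class=1,
                           class_sep=3, random_state=10)

# Number of samples in each of the four classes
class_counts = np.bincount(y)
print("samples per class:", class_counts)

# Mean accuracy of the fitted model on the data it was trained on
lda = LinearDiscriminantAnalysis().fit(X, y)
train_acc = lda.score(X, y)
print("training accuracy:", train_acc)
```

Accuracy on the training data is an optimistic estimate; Part 2 uses a held-out test set instead.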
- Step3: Inspect the model parameters
# Inspect the hyperparameters of the LDA model
lda.get_params()
{'n_components': None,
 'priors': None,
 'shrinkage': None,
 'solver': 'svd',
 'store_covariance': False,
 'tol': 0.0001}
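The dictionary above lists only the hyperparameters passed to the constructor. What the model learned is stored in fitted attributes, which scikit-learn marks with a trailing underscore. The sketch below rebuilds the demo model and prints a few of them.

```python
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Same generator settings as the demo above
X, y = make_classification(n_samples=1000, n_features=3, n_redundant=0,
                           n_classes=4, n_informative=2, n_clusters_per_class=1,
                           class_sep=3, random_state=10)
lda = LinearDiscriminantAnalysis().fit(X, y)

# Attributes estimated from the data during fit()
print("classes_:", lda.classes_)                   # class labels seen in y
print("priors_:", lda.priors_)                     # class proportions estimated from y
print("means_ shape:", lda.means_.shape)           # (n_classes, n_features)
print("coef_ shape:", lda.coef_.shape)             # one weight vector per class
print("intercept_ shape:", lda.intercept_.shape)   # one intercept per class
```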
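- Step4: Visualize the data and the model
# Project the data onto the discriminant axes found by LDA
X_new = lda.transform(X)
# Visualize the first two dimensions of the projected data
plt.scatter(X_new[:, 0], X_new[:, 1], marker='o', c=y)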
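To make clear that `transform` is a linear projection rather than a prediction, the sketch below reproduces it by hand. It assumes the default `svd` solver, whose fitted attributes `xbar_` (overall mean) and `scalings_` (projection matrix) are documented by scikit-learn; `xbar_` is not available with the other solvers. The name `X_manual` is illustrative.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Same generator settings as the demo above
X, y = make_classification(n_samples=1000, n_features=3, n_redundant=0,
                           n_classes=4, n_informative=2, n_clusters_per_class=1,
                           class_sep=3, random_state=10)
lda = LinearDiscriminantAnalysis().fit(X, y)  # default solver='svd'
X_new = lda.transform(X)

# Manual projection: center with the overall mean, then apply the scaling matrix
X_manual = (X - lda.xbar_) @ lda.scalings_
# Keep only as many columns as transform() returns
X_manual = X_manual[:, :X_new.shape[1]]

print("projected shape:", X_new.shape)
print("matches transform():", np.allclose(X_manual, X_new))
```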
- Step5: Make predictions
# Predict on new test points
a = np.array([[-1, 0.1, 0.1]])
print(f"{a} predicted class: ", lda.predict(a))
print(f"{a} class probabilities: ", lda.predict_proba(a))
a = np.array([[-12, -100, -91]])
print(f"{a} predicted class: ", lda.predict(a))
print(f"{a} class probabilities: ", lda.predict_proba(a))
a = np.array([[-12, -0.1, -0.1]])
print(f"{a} predicted class: ", lda.predict(a))
print(f"{a} class probabilities: ", lda.predict_proba(a))
a = np.array([[0.1, 90.1, 9.1]])
print(f"{a} predicted class: ", lda.predict(a))
print(f"{a} class probabilities: ", lda.predict_proba(a))
[[-1. 0.1 0.1]] predicted class: [0]
[[-1. 0.1 0.1]] class probabilities: [[9.37611354e-01 1.88760664e-05 3.36891510e-02 2.86806189e-02]]
[[ -12 -100 -91]] predicted class: [1]
[[ -12 -100 -91]] class probabilities: [[1.08769337e-028 1.00000000e+000 1.54515810e-221 9.05666876e-183]]
[[-12. -0.1 -0.1]] predicted class: [2]
[[-12. -0.1 -0.1]] class probabilities: [[1.60268201e-07 1.46912978e-39 9.99999840e-01 3.57001075e-28]]
[[ 0.1 90.1 9.1]] predicted class: [3]
[[ 0.1 90.1 9.1]] class probabilities: [[8.42065614e-08 9.45021749e-11 8.63060269e-02 9.13693889e-01]]
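The two calls used above are related: `predict` returns the class with the largest posterior probability reported by `predict_proba`. The sketch below checks this on two of the test points; the array name `points` is illustrative.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Same generator settings as the demo above
X, y = make_classification(n_samples=1000, n_features=3, n_redundant=0,
                           n_classes=4, n_informative=2, n_clusters_per_class=1,
                           class_sep=3, random_state=10)
lda = LinearDiscriminantAnalysis().fit(X, y)

points = np.array([[-1, 0.1, 0.1], [-12, -0.1, -0.1]])
proba = lda.predict_proba(points)
pred = lda.predict(points)

# Each row of probabilities sums to 1
print("row sums:", proba.sum(axis=1))
# The predicted class is the one with the largest posterior probability
print("argmax of proba:", lda.classes_[proba.argmax(axis=1)])
print("predict():      ", pred)
```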
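Part 2 Handwritten digit classification with LDA
- Step1: Import libraries
# Import the handwritten digits dataset (scikit-learn's 8x8 digits data, a copy of the UCI test set; the code calls it MNIST for short, but it is not the full 28x28 MNIST)
from sklearn.datasets import load_digits
# Import the train/test split helper
from sklearn.model_selection import train_test_split
# Import the LDA model
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# Import the classification report and confusion matrix functions
from sklearn.metrics import classification_report, confusion_matrix
# Import the plotting packages
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt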
- Step2: Load the data
# Load the digits dataset
mnist = load_digits()
# Print the dataset contents
print('The Mnist dataset:\n', mnist)
# Split the data into a training set and a test set
x, test_x, y, test_y = train_test_split(mnist.data, mnist.target, test_size=0.1, random_state=2)
The Mnist dataset:
{'data': array([[ 0., 0., 5., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 10., 0., 0.],
[ 0., 0., 0., ..., 16., 9., 0.],
[ 0., 0., 1., ..., 6., 0., 0.],
[ 0., 0., 2., ..., 12., 0., 0.],
[ 0., 0., 10., ..., 12., 1., 0.]]), 'target': array([0, 1, 2, ..., 8, 9, 8]), 'target_names': array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 'images': array([[[ 0., 0., 5., ..., 1., 0., 0.],
[ 0., 0., 13., ..., 15., 5., 0.],
[ 0., 3., 15., ..., 11., 8., 0.],
[ 0., 4., 11., ..., 12., 7., 0.],
[ 0., 2., 14., ..., 12., 0., 0.],
[ 0., 0., 6., ..., 0., 0., 0.]],
[[ 0., 0., 0., ..., 5., 0., 0.],
[ 0., 0., 0., ..., 9., 0., 0.],
[ 0., 0., 3., ..., 6., 0., 0.],
[ 0., 0., 1., ..., 6., 0., 0.],
[ 0., 0., 1., ..., 6., 0., 0.],
[ 0., 0., 0., ..., 10., 0., 0.]],
[[ 0., 0., 0., ..., 12., 0., 0.],
[ 0., 0., 3., ..., 14., 0., 0.],
[ 0., 0., 8., ..., 16., 0., 0.],
[ 0., 9., 16., ..., 0., 0., 0.],
[ 0., 3., 13., ..., 11., 5., 0.],
[ 0., 0., 0., ..., 16., 9., 0.]],
[[ 0., 0., 1., ..., 1., 0., 0.],
[ 0., 0., 13., ..., 2., 1., 0.],
[ 0., 0., 16., ..., 16., 5., 0.],
[ 0., 0., 16., ..., 15., 0., 0.],
[ 0., 0., 15., ..., 16., 0., 0.],
[ 0., 0., 2., ..., 6., 0., 0.]],
[[ 0., 0., 2., ..., 0., 0., 0.],
[ 0., 0., 14., ..., 15., 1., 0.],
[ 0., 4., 16., ..., 16., 7., 0.],
[ 0., 0., 0., ..., 16., 2., 0.],
[ 0., 0., 4., ..., 16., 2., 0.],
[ 0., 0., 5., ..., 12., 0., 0.]],
[[ 0., 0., 10., ..., 1., 0., 0.],
[ 0., 2., 16., ..., 1., 0., 0.],
[ 0., 0., 15., ..., 15., 0., 0.],
[ 0., 4., 16., ..., 16., 6., 0.],
[ 0., 8., 16., ..., 16., 8., 0.],
[ 0., 1., 8., ..., 12., 1., 0.]]]), 'DESCR': ".. _digits_dataset:\n\nOptical recognition of handwritten digits dataset\n--------------------------------------------------\n\nData Set Characteristics:\n\n :Number of Instances: 5620\n :Number of Attributes: 64\n :Attribute Information: 8x8 image of integer pixels in the range 0..16.\n :Missing Attribute Values: None\n :Creator: E. Alpaydin (alpaydin '@' boun.edu.tr)\n :Date: July; 1998\n\nThis is a copy of the test set of the UCI ML hand-written digits datasets\nhttps://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits\n\nThe data set contains images of hand-written digits: 10 classes where\neach class refers to a digit.\n\nPreprocessing programs made available by NIST were used to extract\nnormalized bitmaps of handwritten digits from a preprinted form. From a\ntotal of 43 people, 30 contributed to the training set and different 13\nto the test set. 32x32 bitmaps are divided into nonoverlapping blocks of\n4x4 and the number of on pixels are counted in each block. This generates\nan input matrix of 8x8 where each element is an integer in the range\n0..16. This reduces dimensionality and gives invariance to small\ndistortions.\n\nFor info on NIST preprocessing routines, see M. D. Garris, J. L. Blue, G.\nT. Candela, D. L. Dimmick, J. Geist, P. J. Grother, S. A. Janet, and C.\nL. Wilson, NIST Form-Based Handprint Recognition System, NISTIR 5469,\n1994.\n\n.. topic:: References\n\n - C. Kaynak (1995) Methods of Combining Multiple Classifiers and Their\n Applications to Handwritten Digit Recognition, MSc Thesis, Institute of\n Graduate Studies in Science and Engineering, Bogazici University.\n - E. Alpaydin, C. Kaynak (1998) Cascading Classifiers, Kybernetika.\n - Ken Tang and Ponnuthurai N. Suganthan and Xi Yao and A. Kai Qin.\n Linear dimensionalityreduction using relevance weighted LDA. School of\n Electrical and Electronic Engineering Nanyang Technological University.\n 2005.\n - Claudio Gentile. A New Approximate Maximal Margin Classification\n Algorithm. NIPS. 2000."}
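Printing the whole dataset object is hard to read. A more compact way to inspect it is to print the array shapes, as in the sketch below. The variable name `digits` is illustrative; the split settings are the ones used above.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()

# data holds one flattened 8x8 image per row; images holds the same pixels as 8x8 arrays
print("data shape:  ", digits.data.shape)
print("images shape:", digits.images.shape)
print("target shape:", digits.target.shape)

# Same split settings as above: 10% of the samples are held out for testing
x, test_x, y, test_y = train_test_split(digits.data, digits.target,
                                        test_size=0.1, random_state=2)
print("train samples:", x.shape[0], "test samples:", test_x.shape[0])
```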
- Step3: Inspect and visualize the data
## Show sample images
images = range(0, 9)
for i in images:
    # Place each image on a 3x3 grid
    plt.subplot(330 + 1 + i)
    plt.imshow(x[i].reshape(8, 8), cmap=matplotlib.cm.binary, interpolation="nearest")
# show the plot
plt.show()
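Besides looking at sample images, it can help to check how balanced the classes are and what range the pixel values cover. The following is a small sketch under the same split settings; the name `counts` is illustrative.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()
x, test_x, y, test_y = train_test_split(digits.data, digits.target,
                                        test_size=0.1, random_state=2)

# Number of training samples for each digit 0..9
counts = np.bincount(y, minlength=10)
for digit, n in enumerate(counts):
    print(f"digit {digit}: {n} training samples")

# Pixel intensities are block counts, not 0-255 grey levels
print("pixel value range:", x.min(), "to", x.max())
```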
- Step4: Train and predict on handwritten digits with LDA
# Build the LDA model
m_lda = LinearDiscriminantAnalysis()
# Train the model
m_lda.fit(x, y)
LinearDiscriminantAnalysis(n_components=None, priors=None, shrinkage=None,
                           solver='svd', store_covariance=False, tol=0.0001)
# Project the training data onto the discriminant axes
x_new = m_lda.transform(x)
# Visualize the first two dimensions of the projected data
plt.scatter(x_new[:, 0], x_new[:, 1], marker='o', c=y)
plt.title('MNIST with LDA Model')
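The scatter plot shows only two of the projected dimensions. With 10 classes the projection has 9 columns, and the `explained_variance_ratio_` attribute (available with the default `svd` solver) indicates how much of the between-class variance each axis carries. The sketch below is a hedged illustration; the name `ratio` is illustrative.

```python
from sklearn.datasets import load_digits
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split

digits = load_digits()
x, test_x, y, test_y = train_test_split(digits.data, digits.target,
                                        test_size=0.1, random_state=2)

m_lda = LinearDiscriminantAnalysis().fit(x, y)
x_new = m_lda.transform(x)

# 10 classes give at most 10 - 1 = 9 discriminant axes
print("projected shape:", x_new.shape)

# Share of between-class variance carried by each axis, largest first
ratio = m_lda.explained_variance_ratio_
print("share carried by the first two axes:", ratio[:2].sum())
```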
# Predict the classes of the test set
y_test_pred = m_lda.predict(test_x)
print("True labels of the test set:\n", test_y)
print("Predicted labels of the test set:\n", y_test_pred)
True labels of the test set:
[4 0 9 1 4 7 1 5 1 6 6 7 6 1 5 5 4 6 2 7 4 6 4 1 5 2 9 5 4 6 5 6 3 4 0 9 9
8 4 6 8 8 5 7 9 6 9 6 1 3 0 1 9 7 3 3 1 1 8 8 9 8 5 4 4 7 3 5 8 4 3 1 3 8
7 3 3 0 8 7 2 8 5 3 8 7 6 4 6 2 2 0 1 1 5 3 5 7 6 8 2 2 6 4 6 7 3 7 3 9 4
7 0 3 5 8 5 0 3 9 2 7 3 2 0 8 1 9 2 1 9 1 0 3 4 3 0 9 3 2 2 7 3 1 6 7 2 8
3 1 1 6 4 8 2 1 8 4 1 3 1 1 9 5 4 8 7 4 8 9 5 7 6 9 0 0 4 0 0 4]
Predicted labels of the test set:
[4 0 9 1 8 7 1 5 1 6 6 7 6 2 5 5 8 6 2 7 4 6 4 1 5 2 9 5 4 6 5 6 3 4 0 9 9
8 4 6 8 1 5 7 9 6 9 6 1 3 0 1 9 7 3 3 1 1 8 8 9 8 5 8 4 9 3 5 8 4 3 9 3 8
7 3 3 0 8 7 2 8 5 3 8 7 6 4 6 2 2 0 1 1 5 3 5 7 1 8 2 2 6 4 6 7 3 7 3 9 4
7 0 3 5 1 5 0 3 9 2 7 3 2 0 8 1 9 2 1 9 9 0 3 4 3 0 8 3 2 2 7 3 1 6 7 2 8
3 1 1 6 4 8 2 1 8 4 1 3 1 1 9 5 4 9 7 4 8 9 5 7 6 9 6 0 4 0 0 9]
# Compute the prediction metrics: per-class precision, recall, and F1 score
print(classification_report(test_y, y_test_pred))
              precision    recall  f1-score   support

           0       1.00      0.93      0.96        14
           1       0.86      0.86      0.86        22
           2       0.93      1.00      0.97        14
           3       1.00      1.00      1.00        22
           4       1.00      0.81      0.89        21
           5       1.00      1.00      1.00        16
           6       0.94      0.94      0.94        18
           7       1.00      0.94      0.97        18
           8       0.80      0.84      0.82        19
           9       0.75      0.94      0.83        16

    accuracy                           0.92       180
   macro avg       0.93      0.93      0.93       180
weighted avg       0.93      0.92      0.92       180
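The `accuracy` row of the report is the fraction of test samples whose predicted label equals the true label. The sketch below recomputes it by hand and compares it with `accuracy_score`; the name `manual_acc` is illustrative, and the exact value may vary with the scikit-learn version.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

digits = load_digits()
x, test_x, y, test_y = train_test_split(digits.data, digits.target,
                                        test_size=0.1, random_state=2)
y_test_pred = LinearDiscriminantAnalysis().fit(x, y).predict(test_x)

# Fraction of test samples that were classified correctly
manual_acc = np.mean(y_test_pred == test_y)
print("manual accuracy:  ", manual_acc)
print("accuracy_score(): ", accuracy_score(test_y, y_test_pred))
```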
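# Compute the confusion matrix
C2 = confusion_matrix(test_y, y_test_pred)
# Print the confusion matrix
print(C2)
# Display the confusion matrix as a heatmap
f, ax = plt.subplots()
# Draw the heatmap
sns.heatmap(C2, cmap="YlGnBu_r", annot=True, ax=ax)
# Title
ax.set_title('confusion matrix')
# The x axis shows the predicted class
ax.set_xlabel('predict')
# The y axis shows the true class
ax.set_ylabel('true')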
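The per-class precision and recall in the classification report can be read off the confusion matrix: its rows are true classes and its columns are predicted classes. The sketch below derives both from the matrix and compares them with scikit-learn's `precision_score` and `recall_score`; the names `C`, `correct`, `precision`, and `recall` are illustrative.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import confusion_matrix, precision_score, recall_score
from sklearn.model_selection import train_test_split

digits = load_digits()
x, test_x, y, test_y = train_test_split(digits.data, digits.target,
                                        test_size=0.1, random_state=2)
y_test_pred = LinearDiscriminantAnalysis().fit(x, y).predict(test_x)

# Rows are true classes, columns are predicted classes
C = confusion_matrix(test_y, y_test_pred)
correct = np.diag(C)

# Recall: correct predictions divided by the number of true samples of the class
recall = correct / C.sum(axis=1)
# Precision: correct predictions divided by the number of times the class was predicted
# (np.maximum guards against dividing by zero if a class is never predicted)
precision = correct / np.maximum(C.sum(axis=0), 1)

print("precision:", np.round(precision, 2))
print("recall:   ", np.round(recall, 2))
```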
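As the formula above shows, the numerator is the difference between the class means after projection, and the denominator is the sum of the class variances after projection. The goal of LDA is to maximize this ratio. Maximizing the numerator pushes the classes farther apart, while minimizing the denominator makes the variance inside each class smaller, so that the data of each class can be separated as clearly as possible after projection by the projection matrix.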
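The criterion described here can be sketched numerically for the two-class case. The code below is an illustration under stated assumptions, not the article's own derivation: the synthetic data, the function name `fisher_criterion`, and the symbols `m0`, `m1`, `Sw`, and `w_star` are all made up for this example. It measures the denominator as within-class scatter (sum of squared deviations) and uses the standard closed-form direction proportional to Sw^{-1}(m0 - m1).

```python
import numpy as np

rng = np.random.default_rng(0)
# Two synthetic 2-D classes (an assumption for illustration)
X0 = rng.normal(loc=[0.0, 0.0], scale=1.0, size=(200, 2))
X1 = rng.normal(loc=[3.0, 1.0], scale=1.0, size=(200, 2))

def fisher_criterion(w, X0, X1):
    """Squared gap between projected means over the summed projected scatter."""
    p0, p1 = X0 @ w, X1 @ w
    between = (p0.mean() - p1.mean()) ** 2
    within = ((p0 - p0.mean()) ** 2).sum() + ((p1 - p1.mean()) ** 2).sum()
    return between / within

# Within-class scatter matrix and the closed-form Fisher direction
m0, m1 = X0.mean(axis=0), X1.mean(axis=0)
Sw = (X0 - m0).T @ (X0 - m0) + (X1 - m1).T @ (X1 - m1)
w_star = np.linalg.solve(Sw, m0 - m1)

# Compare against random directions: none should score higher
random_scores = [fisher_criterion(rng.normal(size=2), X0, X1) for _ in range(1000)]
print("criterion at w_star:           ", fisher_criterion(w_star, X0, X1))
print("best criterion among random w: ", max(random_scores))
```

Rescaling `w_star` does not change the criterion, so only its direction matters.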