决策树


龙蜥操作系统应该安装mysql的什么版本_决策树



和支持向量机一样, 决策树是一种多功能机器学习算法, 即可以执行分类任务也可以执行回归任务, 甚至包括多输出(

multioutput)任务。它是一种功能很强大的算法,可以对很复杂的数据集进行拟合。

决策树也是随机森林的基本组成部分,而随机森林是当今最强大的机器学习算法之一。

熵和信息增益
熵(entropy):

量化不确定性的程度(混乱程度),是一种衡量方式,公式:
龙蜥操作系统应该安装mysql的什么版本_人工智能_02

信息增益(龙蜥操作系统应该安装mysql的什么版本_信息增益_03):

衡量熵的减少程度。表示:父节点的熵 和 所有子节点加权平均商的差。
龙蜥操作系统应该安装mysql的什么版本_信息增益_04

使用信息增益的经典示例,4个特征来分类猫和狗,训练数据如下表:

龙蜥操作系统应该安装mysql的什么版本_子节点_05

迭代二叉树 3 代(ID3)的算法:

该算法由罗斯·昆兰发明,是一种用于训练决策树的优先算法。

  • 二分类14个样本中,6只猫8只狗,其熵为 龙蜥操作系统应该安装mysql的什么版本_决策树_06
  • 用特征玩滚筒作问题:
  • 左子节点得到:9个样本(2狗7猫),其熵为:龙蜥操作系统应该安装mysql的什么版本_子节点_07
  • 右子节点得到:5个样本(4狗1猫),其熵为:龙蜥操作系统应该安装mysql的什么版本_人工智能_08
  • 加权平均得:0.7491
  • 信息增益得:0.2361
  • 用其他特征做问题,同理,有表如下:

龙蜥操作系统应该安装mysql的什么版本_机器学习_09

如表示例,使用猫粮检测的信息增益增加最多,它依然是最优检测,在剩余的特征中如此这般,继续下去。

龙蜥操作系统应该安装mysql的什么版本_信息增益_10

如表示例,脾气暴躁喜欢玩滚筒的检测的信息增益增加最多(信息增益相同),ID3算法会随机选择:若选择的是脾气暴躁,在剩余的特征中如此这般,继续下去。

龙蜥操作系统应该安装mysql的什么版本_决策树_11

如表所示,剩余的特征检测中,信息增益头相同,随机选择即可,最终,构建了决策树。

ID3 算法并不是唯一能用于训练决策树的算法。C4.5 算法是 ID3 的一个修改版本,它能够和连续解释变量一起使用,同时能为特征提供缺失的值。C4.5 算法也可以用于给树剪 枝。剪枝通过使用叶节点替代几乎不能对实例进行分类的分支来减少树的体积。CART 算法是另一种支持剪枝的学习算法,它同时也是 scikit-learn类库用来实现决策树的算法。

与信息增益类似,另一个常用于构造决策树的启发性算法:基尼不纯度 。
基尼不纯度

龙蜥操作系统应该安装mysql的什么版本_子节点_12

决策树的训练和可视化

如图,根据基尼不纯度对鸢尾花数据集中的花瓣长度、花瓣宽度这两个特征进行三分类,构建决策树:

龙蜥操作系统应该安装mysql的什么版本_机器学习_13

示例代码:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import pandas as pd

iris = load_iris()
X = iris.data[:, 2:] # petal length and width
y = iris.target
#训练决策树模型,最大深度2
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)
#观察下鸢尾花数据
data_iris= pd.DataFrame(iris.data,columns=iris.feature_names)
data_iris['target']= pd.Series(iris.target)
data_iris

#引入可视化工具graphviz,相关内容见http://graphviz.gitlab.io/download/
from graphviz import Source
from sklearn.tree import export_graphviz
#到处图片
export_graphviz(
        tree_clf,
        out_file=os.path.join(IMAGES_PATH, "iris_tree.dot"),
        feature_names=iris.feature_names[2:],
        class_names=iris.target_names,
        rounded=True,
        filled=True
    )
#查看图片
Source.from_file(os.path.join(IMAGES_PATH, "iris_tree.dot"))

测试查看轮廓

from matplotlib.colors import ListedColormap

def plot_decision_boundary(clf, X, y, axes=[0, 7.5, 0, 3], iris=True, legend=False, plot_training=True):
    x1s = np.linspace(axes[0], axes[1], 100) #100个petal length数据,行向量
    x2s = np.linspace(axes[2], axes[3], 100) #100个petal width数据,行向量
    x1, x2 = np.meshgrid(x1s, x2s) # 构建网格坐标
    X_new = np.c_[x1.ravel(), x2.ravel()]  #  转置后拼接成测试数据
    y_pred = clf.predict(X_new).reshape(x1.shape)  #预测,转成网格坐标
    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0']) #颜色图谱
    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)  #描轮廓
    if not iris:
        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])
        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)
    if plot_training:
        plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo", label="Iris setosa")
        plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs", label="Iris versicolor")
        plt.plot(X[:, 0][y==2], X[:, 1][y==2], "g^", label="Iris virginica")
        plt.axis(axes)
    if iris:
        plt.xlabel("Petal length", fontsize=14)
        plt.ylabel("Petal width", fontsize=14)
    else:
        plt.xlabel(r"$x_1$", fontsize=18)
        plt.ylabel(r"$x_2$", fontsize=18, rotation=0)
    if legend:
        plt.legend(loc="lower right", fontsize=14)

plt.figure(figsize=(8, 4))
plot_decision_boundary(tree_clf, X, y)
plt.plot([2.45, 2.45], [0, 3], "k-", linewidth=2)
plt.plot([2.45, 7.5], [1.75, 1.75], "k--", linewidth=2)
plt.plot([4.95, 4.95], [0, 1.75], "k:", linewidth=2)
plt.plot([4.85, 4.85], [1.75, 3], "k:", linewidth=2)
plt.text(1.40, 1.0, "Depth=0", fontsize=15)
plt.text(3.2, 1.80, "Depth=1", fontsize=13)
plt.text(4.05, 0.5, "(Depth=2)", fontsize=11)

龙蜥操作系统应该安装mysql的什么版本_子节点_14

Scikit-Learn 用的是 CART 算法, CART 算法仅产生二叉树:每一个非叶节点总是只有两个子节点(只有是或否两个结果)。然而,像 ID3这样的算法可以产生超过两个子节点的决策树模型。

估计分类概率
tree_clf.predict_proba([[5, 1.5]])  #估计分类概率
#array([[0.        , 0.33333333, 0.66666667]])   
tree_clf.predict_proba([[4, 1.5]])
#array([[0.        , 0.97916667, 0.02083333]])
tree_clf.predict([[5, 1.5]]) # 
#array([2])
CART 训练算法

Scikit-Learn 用分类和回归树(Classification And Regression Tree,简称 CART)算法训练决策树(也叫“增长树”)。

该算法的工作原理:

首先使用单个特征 龙蜥操作系统应该安装mysql的什么版本_机器学习_15 和 阈值 (例如,花瓣长度≤2.45cm)将训练集分成两个子集。

它如何选择 龙蜥操作系统应该安装mysql的什么版本_机器学习_15龙蜥操作系统应该安装mysql的什么版本_决策树_17

它寻找能够产生最纯粹的子集(通过子集大小加权计算)一对龙蜥操作系统应该安装mysql的什么版本_信息增益_18

CART 算法分类的损失函数(成本函数):

龙蜥操作系统应该安装mysql的什么版本_子节点_19

正则化超参数:

DecisionTreeClassifier类还有一些其他的参数用于限制树模型的形状:

  • min_samples_split(节点在被分裂之前必须具有的最小样本数),
  • min_samples_leaf(叶节点必须具有的最小样本数),
  • min_weight_fraction_leaf(和min_samples_leaf相同,但表示为加权总数的一小部分实例),
  • max_leaf_nodes(叶节点的最大数量)
  • max_features(在每个节点被评估是否分裂的时候,具有的最大特征数量)。

增加min_* hyperparameters或者减少max_* hyperparameters会使模型正则化。

一些其他算法的工作原理是在没有任何约束条件下训练决策树模型,让模型自由生长,然后再对不需要的节点进行剪枝。当一个节点的全部子节点都是叶节点时,如果它对纯度的提升不具有统计学意义,我们就认为这个分支是不必要的。

标准的假设检验,例如卡方检测,通常会被用于评估一个概率值 – 即改进是否纯粹是偶然性的结果(也叫原假设)

如果 p 值比给定的阈值更高(通常设定为 5%,也就是 95% 置信度,通过超参数设置),那么节点就被认为是非必要的,它的子节点会被删除。

这种剪枝方式将会一直进行,直到所有的非必要节点都被删光。

代码示例:超参数:min_samples_leaf

from sklearn.datasets import make_moons
Xm, ym = make_moons(n_samples=100, noise=0.25, random_state=53)

deep_tree_clf1 = DecisionTreeClassifier(random_state=42)
deep_tree_clf2 = DecisionTreeClassifier(min_samples_leaf=4, random_state=42)
deep_tree_clf1.fit(Xm, ym)
deep_tree_clf2.fit(Xm, ym)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)
plt.sca(axes[0])
plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=[-1.5, 2.4, -1, 1.5], iris=False)
plt.title("No restrictions", fontsize=16)
plt.sca(axes[1])
plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=[-1.5, 2.4, -1, 1.5], iris=False)
plt.title("min_samples_leaf = {}".format(deep_tree_clf2.min_samples_leaf), fontsize=14)
plt.ylabel("")

save_fig("min_samples_leaf_plot")
plt.show()

龙蜥操作系统应该安装mysql的什么版本_信息增益_20

回归

决策树也能够执行回归任务,让我们使用Scikit-LearnDecisionTreeRegressor类构建一个回归树

代码示例:

np.random.seed(42)
m = 200
X = np.random.rand(m, 1)
y = 4 * (X - 0.5) ** 2
y = y + np.random.randn(m, 1) / 10

from sklearn.tree import DecisionTreeRegressor

tree_reg1 = DecisionTreeRegressor(random_state=42, max_depth=2)
tree_reg2 = DecisionTreeRegressor(random_state=42, max_depth=3)  
tree_reg1.fit(X, y)
tree_reg2.fit(X, y)

def plot_regression_predictions(tree_reg, X, y, axes=[0, 1, -0.2, 1], ylabel="$y$"):
    x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)
    y_pred = tree_reg.predict(x1)
    plt.axis(axes)
    plt.xlabel("$x_1$", fontsize=18)
    if ylabel:
        plt.ylabel(ylabel, fontsize=18, rotation=0)
    plt.plot(X, y, "b.")
    plt.plot(x1, y_pred, "r.-", linewidth=2, label=r"$\hat{y}$")

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)
plt.sca(axes[0])
plot_regression_predictions(tree_reg1, X, y)
for split, style in ((0.1973, "k-"), (0.0917, "k--"), (0.7718, "k--")):
    plt.plot([split, split], [-0.2, 1], style, linewidth=2)
plt.text(0.21, 0.65, "Depth=0", fontsize=15)
plt.text(0.01, 0.2, "Depth=1", fontsize=13)
plt.text(0.65, 0.8, "Depth=1", fontsize=13)
plt.legend(loc="upper center", fontsize=18)
plt.title("max_depth=2", fontsize=14)

plt.sca(axes[1])
plot_regression_predictions(tree_reg2, X, y, ylabel=None)
for split, style in ((0.1973, "k-"), (0.0917, "k--"), (0.7718, "k--")):
    plt.plot([split, split], [-0.2, 1], style, linewidth=2)
for split in (0.0458, 0.1298, 0.2873, 0.9040):
    plt.plot([split, split], [-0.2, 1], "k:", linewidth=1)
plt.text(0.3, 0.5, "Depth=2", fontsize=13)
plt.title("max_depth=3", fontsize=14)

save_fig("tree_regression_plot")
plt.show()

export_graphviz(
        tree_reg1,
        out_file=os.path.join(IMAGES_PATH, "regression_tree.dot"),
        feature_names=["x1"],
        rounded=True,
        filled=True
    )
Source.from_file(os.path.join(IMAGES_PATH, "regression_tree.dot"))

龙蜥操作系统应该安装mysql的什么版本_信息增益_21

得到的回归决策树

龙蜥操作系统应该安装mysql的什么版本_子节点_22

不稳定性

对训练集旋转敏感

对训练集细节敏感

可以用随机森林,对许多树进行平均预测来限制这种不稳定性。