机器学习之决策树(ID3)算法与Python实现
机器学习中,决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。 数据挖掘中决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用来作预测。
一、决策树与ID3概述1.决策树
决策树,其结构和树非常相似,因此得其名决策树。决策树具有树形的结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。
例如:
按照豆腐脑的冷热、甜咸和是否含有大蒜构建决策树,对其属性的测试,在最终的叶节点决定该豆腐脑吃还是不吃。
分类树(决策树)是一种十分常用的将决策树应用于分类的机器学习方法。他是一种监管学习,所谓监管学习就是给定一堆样本,每个样本都有一组属性(特征)和一个类别(分类信息/目标),这些类别是事先确定的,那么通过学习得到一个分类器,这个分类器能够对新出现的对象给出正确的分类。
其原理在于,每个决策树都表述了一种树型结构,它由它的分支来对该类型的对象依靠属性进行分类。每个决策树可以依靠对源数据库的分割进行数据测试。这个过程可以递归式的对树进行修剪。 当不能再进行分割或一个单独的类可以被应用于某一分支时,递归过程就完成了。
机器学习中,决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。数据挖掘中决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用来作预测。从数据产生决策树的机器学习技术叫做决策树学习, 通俗说就是决策树。
目前常用的决策树算法有ID3算法、改进的C4.5算法和CART算法。
决策树的特点
1.多层次的决策树形式易于理解;
2.只适用于标称型数据,对连续性数据处理得不好;
2、ID3算法
ID3算法最早是由罗斯昆(J. Ross Quinlan)于1975年在悉尼大学提出的一种分类预测算法,算法以信息论为基础,其核心是“信息熵”。ID3算法通过计算每个属性的信息增益,认为信息增益高的是好属性,每次划分选取信息增益最高的属性为划分标准,重复这个过程,直至生成一个能完美分类训练样例的决策树。
信息熵(Entropy):
,其中p(xi)是选择i的概率。
熵越高,表示混合的数据越多。信息增益(Information Gain):
T是划分之后的分支集合,p(t)是该分支集合在原本的父集合中出现的概率,H(t)是该子集合的信息熵。
3.ID3算法与决策树的流程
(1)数据准备:需要对数值型数据进行离散化
(2)ID3算法构建决策树:
如果数据集类别完全相同,则停止划分
否则,继续划分决策树:
计算信息熵和信息增益来选择最好的数据集划分方法;
划分数据集
创建分支节点:
对每个分支进行判定是否类别相同,如果相同停止划分,不同按照上述方法进行划分。
二、Python算法实现
创建 trees.py文件,在其中创建构建决策树的函数。
首先构建一组测试数据:
0. 构造函数createDataSet:
def createDataSet():
dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
labels=['no surfacing','flippers']
return dataSet,labels在Python控制台测试构造函数
#测试下构造的数据
import trees
myDat,labels = trees.createDataSet()
myDat
Out[4]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels
Out[5]: ['no surfacing', 'flippers']
2.1 计算信息熵
from math import log
def calcShannonEnt(dataSet):
numEntries = len(dataSet) #nrows
#为所有的分类类目创建字典
labelCounts ={}
for featVec in dataSet:
currentLable=featVec[-1] #取得最后一列数据
if currentLable not in labelCounts.keys():
labelCounts[currentLable]=0
labelCounts[currentLable] =1
#计算香农熵
shannonEnt=0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt利用构造的数据测试calcShannonEnt:
#Python console
In [6]: trees.calcShannonEnt(myDat)
...:
Out[6]: 0.97095059445466862.2 按照最大信息增益划分数据集
#定义按照某个特征进行划分的函数splitDataSet
#输入三个变量(待划分的数据集,特征,分类值)
def splitDataSet(dataSet,axis,value):
retDataSet=[]
for featVec in dataSet:
if featVec[axis]==value :
reduceFeatVec=featVec[:axis]
reduceFeatVec.extend(featVec[axis 1:])
retDataSet.append(reduceFeatVec)
return retDataSet #返回不含划分特征的子集
#定义按照最大信息增益划分数据的函数
def chooseBestFeatureToSplit(dataSet):
numFeature=len(dataSet[0])-1
baseEntropy=calcShannonEnt(dataSet)#香农熵
bestInforGain=0
bestFeature=-1
for i in range(numFeature):
featList=[number[i] for number in dataSet] #得到某个特征下所有值(某列)
uniqualVals=set(featList) #set无重复的属性特征值
newEntropy=0
for value in uniqualVals:
subDataSet=splitDataSet(dataSet,i,value)
prob=len(subDataSet)/float(len(dataSet)) #即p(t)
newEntropy =prob*calcShannonEnt(subDataSet)#对各子集香农熵求和
infoGain=baseEntropy-newEntropy #计算信息增益
#最大信息增益
if (infoGain>bestInforGain):
bestInforGain=infoGain
bestFeature=i
return bestFeature #返回特征值
在控制台中测试这两个函数:
#测试按照特征划分数据集的函数
In [8]: from imp import reload
In [9]: reload(trees)
Out[9]:
In [10]: myDat,labels=trees.createDataSet()
...:
In [11]: trees.splitDataSet(myDat,0,0)
...:
Out[11]: [[1, 'no'], [1, 'no']]
In [12]: trees.splitDataSet(myDat,0,1)
...:
Out[12]: [[1, 'yes'], [1, 'yes'], [0, 'no']]
#测试chooseBestFeatureToSplit函数
In [13]: reload(trees)
...:
Out[13]:
In [14]: trees.chooseBestFeatureToSplit(myDat)
...:
Out[14]: 0
2.3 创建决策树构造函数createTree
import operater
#投票表决代码
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():classCount[vote]=0
classCount[vote] =1
sortedClassCount=sorted(classCount.items,key=operator.itemgetter(1),reversed=True)
return sortedClassCount[0][0]
def createTree(dataSet,labels):
classList=[example[-1] for example in dataSet]
#类别相同,停止划分
if classList.count(classList[-1])==len(classList):
return classList[-1]
#长度为1,返回出现次数最多的类别
if len(classList[0])==1:
return majorityCnt(classList)