#-*- coding: utf-8 -*-
'''
Created on Oct 12, 2010
Decision Tree Source Code for Machine Learning in Action Ch. 3
@author: Peter Harrington
'''
from math import log
import operator
# Training data set used throughout this module.
def createDataSet():
    """Return a toy data set and its feature names.

    Each sample is [feature0, feature1, classLabel]; the class label
    takes one of two values: 'yes' or 'no'.
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    # Names of the two feature columns, in order.
    featureNames = ['no surfacing', 'flippers']
    return samples, featureNames
# Shannon entropy of the class-label column of a data set.
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy (base 2) of dataSet's class labels.

    dataSet: list of samples; the last element of each sample is its
             class label.
    Returns the entropy as a float; 0.0 for an empty data set (the
    original raised ZeroDivisionError on empty input).
    """
    numEntries = len(dataSet)
    if numEntries == 0:
        return 0.0
    # Count how many samples fall into each class:
    # key = class label, value = number of occurrences.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]  # class label is the last column
        # dict.get avoids the double lookup of `key in d.keys()`.
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    # H = -sum(p * log2(p)) over the class proportions p.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
# Partition helper: select the samples whose feature `axis` equals
# `value`, and drop that feature column from each selected sample.
# dataSet: data set to partition
# axis:    index of the feature to partition on
# value:   the feature value to select
def splitDataSet(dataSet, axis, value):
    """Return matching samples with column `axis` removed."""
    # Slice concatenation excises column `axis` without mutating the
    # original rows; the filter keeps only rows matching `value`.
    return [sample[:axis] + sample[axis + 1:]
            for sample in dataSet
            if sample[axis] == value]
# Feature selection: pick the split with the largest information gain.
# dataSet: data set to be partitioned
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the best feature to split on.

    The best feature is the one maximizing information gain
    (base entropy minus conditional entropy after the split);
    returns -1 when no feature yields a positive gain.
    """
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the whole set
    bestInfoGain = 0.0
    bestFeature = -1
    for featIndex in range(numFeatures):
        # Distinct values taken by this feature across the data set.
        uniqueVals = set(sample[featIndex] for sample in dataSet)
        # Conditional entropy: weighted entropy of each value's subset.
        condEntropy = 0.0
        for val in uniqueVals:
            subset = splitDataSet(dataSet, featIndex, val)
            weight = len(subset) / float(len(dataSet))
            condEntropy += weight * calcShannonEnt(subset)
        infoGain = baseEntropy - condEntropy
        # Track the feature with the largest gain seen so far.
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = featIndex
    return bestFeature
# Fallback when all features are exhausted: majority vote over the
# class labels of the remaining samples.
# classList: list of class labels
def majorityCnt(classList):
    """Return the class label occurring most often in classList.

    Ties are resolved in favour of the label that reaches the maximum
    count earliest in the dict's iteration order (same as the original
    stable reverse sort).
    """
    # <key, value> = <class label, number of samples with that label>.
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # NOTE: dict.iteritems() was Python-2-only and raises AttributeError
    # on Python 3; items() behaves identically for this use on both.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    # First entry holds the most frequent label.
    return sortedClassCount[0][0]
# Recursively build an ID3 decision tree as nested dicts.
# dataSet: data set (last column of each sample is the class label)
# labels:  feature names matching the data columns
def createTree(dataSet, labels):
    """Build a decision tree from dataSet.

    Returns either a class label (leaf) or a nested dict of the form
    {featureName: {featureValue: subtree, ...}}.

    Unlike the original implementation, `labels` is NOT modified in
    place (the original `del labels[bestFeat]` forced callers to back
    the list up before calling).
    """
    classList = [sample[-1] for sample in dataSet]
    # Stop: all remaining samples share a single class -> leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop: no features left (only the class column) -> majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # Recursive case: split on the feature with max information gain.
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    # The chosen feature name becomes this subtree's root.
    myTree = {bestFeatLabel: {}}
    # Remove the chosen feature from a COPY so the caller's list is
    # left untouched.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    # Distinct values of the chosen feature (deduplicated).
    featValues = set(sample[bestFeat] for sample in dataSet)
    for value in featValues:
        # Each value's subset becomes one branch, built recursively.
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
# Classify one sample by walking the trained decision tree.
# inputTree:  trained decision tree (nested dicts from createTree)
# featLabels: feature names, in the same order as testVec's columns
# testVec:    one sample's feature values (no class column)
def classify(inputTree, featLabels, testVec):
    """Return the class label the tree predicts for testVec."""
    # Root node name, e.g. 'no surfacing'.  next(iter(...)) works on
    # both Python 2 and 3; the original dict.keys()[0] raises
    # TypeError on Python 3.  Leftover debug prints were removed.
    firstStr = next(iter(inputTree))
    # Children of the root, keyed by the root feature's values,
    # e.g. {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}.
    secondDict = inputTree[firstStr]
    # Column index of the root feature within the test vector.
    featIndex = featLabels.index(firstStr)
    # The test sample's value for the root feature selects a branch.
    valueOfFeat = secondDict[testVec[featIndex]]
    if isinstance(valueOfFeat, dict):
        # Internal node: recurse into the matching subtree.
        return classify(valueOfFeat, featLabels, testVec)
    # Leaf node: the stored value is the final class label.
    return valueOfFeat
def storeTree(inputTree, filename):
    """Serialize a decision tree to `filename` using pickle.

    inputTree: nested-dict tree (as built by createTree).
    filename:  destination path; overwritten if it exists.
    """
    import pickle
    # Binary mode is required by pickle on Python 3 (text mode 'w'
    # raises TypeError there); `with` guarantees the file is closed.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return a pickled decision tree from `filename`."""
    import pickle
    # Binary mode matches storeTree's 'wb' and is required on
    # Python 3; `with` closes the file even if unpickling fails.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
#==================================================================
####Plotting: draw the constructed tree
# import matplotlib.pyplot as plt
#
# decisionNode = dict(boxstyle="sawtooth", fc="0.8")
# leafNode = dict(boxstyle="round4", fc="0.8")
# arrow_args = dict(arrowstyle="<-")
#
# def getNumLeafs(myTree):
# numLeafs = 0
# firstStr = myTree.keys()[0]
# secondDict = myTree[firstStr]
# for key in secondDict.keys():
# if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
# numLeafs += getNumLeafs(secondDict[key])
# else: numLeafs +=1
# return numLeafs
#
# def getTreeDepth(myTree):
# maxDepth = 0
# firstStr = myTree.keys()[0]
# secondDict = myTree[firstStr]
# for key in secondDict.keys():
# if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
# thisDepth = 1 + getTreeDepth(secondDict[key])
# else: thisDepth = 1
# if thisDepth > maxDepth: maxDepth = thisDepth
# return maxDepth
#
# def plotNode(nodeTxt, centerPt, parentPt, nodeType):
# createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
# xytext=centerPt, textcoords='axes fraction',
# va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
#
# def plotMidText(cntrPt, parentPt, txtString):
# xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
# yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
# createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
#
# def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
# numLeafs = getNumLeafs(myTree) #this determines the x width of this tree
# depth = getTreeDepth(myTree)
# firstStr = myTree.keys()[0] #the text label for this node should be this
# cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
# plotMidText(cntrPt, parentPt, nodeTxt)
# plotNode(firstStr, cntrPt, parentPt, decisionNode)
# secondDict = myTree[firstStr]
# plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
# for key in secondDict.keys():
# if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
# plotTree(secondDict[key],cntrPt,str(key)) #recursion
# else: #it's a leaf node print the leaf node
# plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
# plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
# plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
# plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
# #if you do get a dictonary you know it's a tree, and the first element will be another dict
#
# def createPlot(inTree):
# fig = plt.figure(1, facecolor='white')
# fig.clf()
# axprops = dict(xticks=[], yticks=[])
# createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
# #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
# plotTree.totalW = float(getNumLeafs(inTree))
# plotTree.totalD = float(getTreeDepth(inTree))
# plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
# plotTree(inTree, (0.5,1.0), '')
# plt.show()
##def createPlot():
## fig = plt.figure(1, facecolor='white')
## fig.clf()
## createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
## plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
## plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
## plt.show()
##def retrieveTree(i):
## listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
## {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
## ]
## return listOfTrees[i]
##
###createPlot(thisTree)
####Plotting: draw the constructed tree (end)
#==================================================================
if __name__ == "__main__":
    mydat, labels = createDataSet()
    # Keep a real COPY of the feature names.  The original code
    # aliased the list (featlabels = labels), so createTree's in-place
    # deletion clobbered the backup too — the very problem the
    # original trailing comment warned about.
    featlabels = labels[:]
    mytree = createTree(mydat, labels)
    # print(...) works on both Python 2 and 3; the bare `print mytree`
    # statement was Python-2-only syntax.
    print(mytree)
    # ---- Contact-lenses data set (from the book) ----
    ##    fr = open('lenses.txt', 'r')
    ##    lenses = [line.strip().split('\t') for line in fr.readlines()]
    ##    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    ##    lensesTree = createTree(lenses, lensesLabels)
    ##    print(lensesTree)
    ##    createPlot(lensesTree)  # plot the constructed tree