继续跟着白皮书学习,对上面的代码做了不少改动,现在能正确绘制了。

先不谈决策树的算法,现在仅仅是依据字典表示树来绘制决策树的图形。

go.py

引导脚本。

#!/usr/local/bin/python3.5
import treePlot
myTree0=treePlot.getTstTree(0)
myTree1=treePlot.getTstTree(1)
myTree0['no surfacing'][1]['flippers'][0]=myTree1['no surfacing'][0]
treePlot.mainPlot(myTree0)

treePlot.py

#!/usr/local/bin/python3.5
#-*-coding:utf-8-*-
import matplotlib.pyplot as plt

#建立存决策结点格式的字典{'fc': '0.8', 'boxstyle': 'sawtooth'}
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
#建立存叶结点格式的字典{'fc': '0.8', 'boxstyle': 'round4'}
leafNode=dict(boxstyle="round4",fc="0.8")
#建立存箭头格式的字典{'arrowstyle': '->'}
arrow_args=dict(arrowstyle='-')

#绘制结点(结点名称,结点位置,箭头起点,结点类型)
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    #下面的pyplot.annotate()用于做文本注释
    #参数s:传注释的文本字符串nodeTxt
    #参数xy:传被注释的坐标元组(x,y)
    #参数xytext:传插入文本的坐标元组(x,y),如果和xy不一样,会产生箭头
    #参数xycoords:指定传入的参数xy所依据的坐标系统
    #这个参数的值取'axes fraction'时表示从左下角的坐标轴
    #参数textcoords:指定传入的参数xytext所依据的坐标系统,规则同xycoords
    #参数arrowprops:传入一个字典,如果字典中有键为arrowstyle的键值对,那么其对应的值可以指定箭头的类型
    #这个参数中的键arrowstyle的值还可以取'-|>','-['等..用的时候查官方文档吧
    createPlot.ax1.annotate(\
        s=nodeTxt,\
        xy=parentPt,\
        xytext=centerPt,\
        xycoords='axes fraction',\
        textcoords='axes fraction',\
        va="center",\
        ha="center",\
        bbox=nodeType,\
        arrowprops=arrow_args)

#用来测试pyplot.annotate()绘制注释的函数
def createPlot():
    #下面的pyplot.annotate()用于创建一个新绘图对象
    #参数num:若不提供则创建一个新图形;若提供了存在的num值则返回其引用;否则创建它并在窗口标题上显示
    #提供的num为数字会显示'Figure 数字';字符串会直接显示这个字符串
    #参数facecolor:指定背景颜色,可以使用颜色名或16进制颜色
    fig=plt.figure(num='绘制注释',facecolor='#99CC66')
    fig.clf() #清除figure对象fig上的图形
    #这里createPlot.ax1是对createPlot这个函数定义了一个属性ax1
    #python可以用这种方式来实现全局变量
    #在这个属性中,用plt.subplot创建子图并获取了这个子图的引用
    #就可以在其它函数中通过访问该函数这个属性直接操作这个子图了
    #frameon指定子图是否独立出来,默认是True
    #子图不独立出来时,将继承figure对象的facecolor
    createPlot.ax1=plt.subplot(111,frameon=True)
    #调用绘制结点的函数,在函数体内用createPlot.ax1访问到此处建立的子图
    plotNode('Decision node',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('Leaf nodes',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

#用来测试的已经建立好的字典形式的树
def getTstTree(i):
    #只提供了两棵树
    treeList=[\
            {'no surfacing':\
                {0:'no',1:\
                    {'flippers':\
                        {0:'no',1:'yes'}\
                    }\
                }\
            },\
            {'no surfacing':\
                {0:\
                    {'head':\
                        {0:'no',1:'yes'}\
                    },\
                1:'no'\
                }\
            }\
            ]
    return treeList[i]

#获取树myTree的叶结点数目
def getNumLeafs(myTree):
    numLeafs=0 #初始化叶子结点数目为0
    firstStr=list(myTree.keys())[0] #获取当前子树的树根key
    secondDict=myTree[firstStr] #获取对应的所有可能划分的子树字典
    for key in secondDict.keys(): #对于其中的每一个划分出的子树
        #如果这棵子树下还有树,即其对应的value值还是一个字典对象
        if type(secondDict[key]).__name__=='dict':
            #将这个字典对象传入,递归调用求其叶结点数目加到总数中
            numLeafs+=getNumLeafs(secondDict[key])
        else: #如果这棵子树已经是叶结点了,即不再包含字典了
            numLeafs+=1 #递归出口,记录叶结点数增加了1
    return numLeafs #返回这个树下总的叶结点数目

#获取树myTree的高度
def getTreeDepth(myTree):
    maxDepth=0 #初始化树的高度为0
    firstStr=list(myTree.keys())[0] #获取当前子树的树根key
    secondDict=myTree[firstStr] #获取对应的所有可能划分的子树字典
    for key in secondDict.keys(): #对于其中的每一个划分出的子树
        #如果这棵子树下还有树,即其对应的value值还是一个字典对象
        if type(secondDict[key]).__name__=='dict':
            #将这个字典对象传入,递归调用求其树高,加上子树根的高度1
            thisDepth=getTreeDepth(secondDict[key])+1
        else: #如果这棵子树已经是叶结点了,即不再包含字典了
            thisDepth=1 #递归出口,记录单一结点的树高是1
        if thisDepth>maxDepth: #如果这次找出的树高更高了
            maxDepth=thisDepth #更新最高值
    return maxDepth #返回这个树的树高

#在树的父子结点之间填充文本信息
#(子结点坐标[x,y],父结点坐标[x,y],文本信息字符串)
def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]+cntrPt[0])/2.0 #横坐标中心
    yMid=(parentPt[1]+cntrPt[1])/2.0 #纵坐标中心
    #利用mainPlot函数的属性ax1在subplot对象上添加文本
    mainPlot.ax1.text(xMid,yMid,txtString)

#在mainPlot函数的属性ax1对应的subplot对象上绘制结点
#(结点名称,结点位置,箭头起点,结点类型)
def mainPlotNode(nodeTxt,centerPt,parentPt,nodeType):
    #下面的pyplot.annotate()用于做文本注释
    #参数s:传注释的文本字符串nodeTxt
    #参数xy:传被注释的坐标元组(x,y)
    #参数xytext:传插入文本的坐标元组(x,y),如果和xy不一样,会产生箭头
    #参数xycoords:指定传入的参数xy所依据的坐标系统
    #这个参数的值取'axes fraction'时表示从左下角的坐标轴
    #参数textcoords:指定传入的参数xytext所依据的坐标系统,规则同xycoords
    #参数arrowprops:传入一个字典,如果字典中有键为arrowstyle的键值对,那么其对应的值可以指定箭头的类型
    #这个参数中的键arrowstyle的值还可以取'-|>','-['等..用的时候查官方文档吧
    mainPlot.ax1.annotate(\
        s=nodeTxt,\
        xy=parentPt,\
        xytext=centerPt,\
        xycoords='axes fraction',\
        textcoords='axes fraction',\
        va="center",\
        ha="center",\
        bbox=nodeType,\
        arrowprops=arrow_args)

#绘制决策(子)树,也是一个递归的函数
#(字典表示树,父结点坐标[x,y],填充的文本信息)
def plotTree(myTree,parentPt,txtString):
    numLeafs=getNumLeafs(myTree) #计算叶结点数(表征子树宽)
    depth=getTreeDepth(myTree) #计算子树的高度
    firstStr=list(myTree.keys())[0] #获取当前子树的树根key
    #按比例计算树当前子树根结点的摆放位置
    cntrPt=(plotTree.xOff+\
        (1.0+float(numLeafs))/2.0/plotTree.totalW,\
        plotTree.yOff)
    #在树的父子结点之间填充文本信息
    #(子结点坐标=当前根结点坐标,父结点坐标,填充的文本信息)
    plotMidText(cntrPt,parentPt,txtString)
    #绘制决策结点(结点名称,结点位置,箭头起点=父结点,结点类型=决策结点)
    mainPlotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr] #获取对应的所有可能划分的子树字典
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD #为了绘制子树,y轴偏移量按比例减少
    for key in secondDict.keys(): #对于其中的每一个划分出的子树
        #如果这棵子树下还有树,即其对应的value值还是一个字典对象
        if type(secondDict[key]).__name__=='dict':
            #将这个字典对象传入,递归调用绘制其子树
            #(子树下的字典对象,子树父结点坐标=当前根结点坐标,子树文本)
            plotTree(secondDict[key],cntrPt,str(key))
            #为了不影响兄弟结点的高度(yOff)
            #在一个结点上的递归绘制子树完成返回之前
            #需要把本层按比例减掉的yOff(全局纵坐标值)按比例加回来
            #plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
            plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
        else: #如果这棵子树已经是叶结点了,即不再包含字典了
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW #对子树,x轴偏移量按比例增加
            #绘制新的结点(子树文本,结点位置=当前xy平移后的位置,箭头起点=当前树根结点位置,结点类型=叶子结点)
            mainPlotNode(secondDict[key],\
                        (plotTree.xOff,plotTree.yOff),\
                        cntrPt,\
                        leafNode)
            #在父子结点间填充文本信息.在绘制树时,这一步在plotTree里
            #但在这里是绘制叶结点,没有封装进函数里而是单独拿出来做
            #(子结点坐标=当前偏移后位置,父结点坐标=当前根结点位置,子树文本)
            plotMidText((plotTree.xOff,plotTree.yOff),\
                        cntrPt,
                        str(key))

#绘制字典表示树inTree
def mainPlot(inTree):
    fig=plt.figure('绘制决策树',facecolor='#CCCCFF')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    mainPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    #plotTree函数的totalW属性存储树的总宽度
    plotTree.totalW=float(getNumLeafs(inTree))
    #plotTree函数的totalD属性存储树的总高度
    plotTree.totalD=float(getTreeDepth(inTree))
    #初始化其xOff和yOff位置
    #以使树inTree的根节点位置尽量合适
    plotTree.xOff=-0.5/plotTree.totalW
    plotTree.yOff=1.0
    #在合适的位置选取虚拟的父节点
    #以使树inTree的根节点位置尽量合适
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

运行结果

python 决策树绘图 怎么用python画决策树_matplotlib