ML in Action: Decision Trees
Published: 2019-06-12

Project Address:

dataset in ML/ML_ation/tree

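 Decision trees

  • Pros: low computational cost, insensitive to missing intermediate values, can handle irrelevant features
  • Cons: prone to overfitting
  • Works with: numeric and nominal values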

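Decision tree pseudocode: createbranch

    check whether every item in the dataset belongs to the same class
        if so: return class_tag
        else:
            find the best feature to split the dataset on
            split the dataset
            create a branch node
            for each subset of the split:
                call createbranch recursively and add the result to the branch node
            return the branch node

The recursion stops when all features have been used up, or when every item in the dataset belongs to the same class.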

Shannon entropy

    from math import log

    def calcShannonEnt(dataSet):
        numEntries = len(dataSet)
        # Count how often each class label (last column) occurs
        labelCounts = {}
        for featVec in dataSet:
            currentLabel = featVec[-1]
            if currentLabel not in labelCounts:
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
        # Entropy is the negated sum of p * log2(p) over all labels
        shannonEnt = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key]) / numEntries
            shannonEnt -= prob * log(prob, 2)
        return shannonEnt
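
As a quick check of the entropy definition, H = -sum(p_i * log2(p_i)), here is a minimal sketch on a hand-made label list. The name `shannon_entropy` and the sample labels are assumptions for illustration only; they are not part of the project code.

```python
from collections import Counter
from math import log2


def shannon_entropy(class_labels):
    """Shannon entropy, in bits, of a sequence of class labels."""
    total = len(class_labels)
    counts = Counter(class_labels)
    return -sum((n / total) * log2(n / total) for n in counts.values())


# Hypothetical samples: a mixed set (2 'yes', 3 'no') and a pure set
mixed = ['yes', 'yes', 'no', 'no', 'no']
pure = ['no', 'no', 'no']

print(round(shannon_entropy(mixed), 4))  # 0.971
print(shannon_entropy(pure) == 0)        # True
```

The more evenly the labels are mixed, the higher the entropy; a set with a single class has entropy zero.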

Splitting the dataset and choosing the best feature (lowest entropy after the split)

    def splitDataSet(dataSet, axis, value):
        # Keep the rows where feature `axis` equals `value`, dropping that column
        retDataSet = []
        for featVec in dataSet:
            if featVec[axis] == value:
                reduceFeatVec = featVec[:axis]
                reduceFeatVec.extend(featVec[axis + 1:])
                retDataSet.append(reduceFeatVec)
        return retDataSet

    def chooseBestFeatureToSplit(dataSet):
        numFeatures = len(dataSet[0]) - 1      # the last column is the class label
        baseEntropy = calcShannonEnt(dataSet)
        bestInfoGain = 0.0
        bestFeature = -1
        for i in range(numFeatures):
            featList = [example[i] for example in dataSet]
            uniqueVals = set(featList)
            # Weighted entropy of the subsets produced by splitting on feature i
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain        # was misspelled baseInfoGain, so the best gain was never updated
                bestFeature = i
        return bestFeature
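
To make the selection rule concrete, the following sketch computes the information gain of each feature on a tiny dataset. The helper names (`entropy`, `info_gain`) and the five sample rows are assumptions chosen for illustration; the rows follow the "features first, class label last" layout the functions above expect.

```python
from collections import Counter
from math import log2


def entropy(rows):
    """Entropy of the class labels stored in the last column of rows."""
    counts = Counter(row[-1] for row in rows)
    return -sum((n / len(rows)) * log2(n / len(rows)) for n in counts.values())


def info_gain(rows, axis):
    """Entropy before the split minus the weighted entropy after splitting on axis."""
    after = 0.0
    for value in {row[axis] for row in rows}:
        subset = [row for row in rows if row[axis] == value]
        after += len(subset) / len(rows) * entropy(subset)
    return entropy(rows) - after


# Assumed sample rows: [no surfacing, flippers, class label]
sample = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

gains = [info_gain(sample, axis) for axis in range(2)]
print([round(g, 3) for g in gains])   # [0.42, 0.171]
print(gains.index(max(gains)))        # 0
```

Feature 0 wins because splitting on it leaves one pure subset and one small mixed subset, which lowers the weighted entropy the most.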

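When all features have been used up and the class label still cannot be determined, a majority vote decides the leaf's class:

    import operator

    def majorityCnt(classList):
        classCount = {}
        for vote in classList:
            if vote not in classCount:
                classCount[vote] = 0
            classCount[vote] += 1
        # Sort (label, count) pairs by count, highest first; items() replaces the Python 2 iteritems()
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]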
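
The same vote can be written more compactly with `collections.Counter` from the standard library. This is only an alternative sketch; `majority_label` is a hypothetical name, not a function in the project.

```python
from collections import Counter


def majority_label(class_list):
    """Return the most frequent label in class_list."""
    return Counter(class_list).most_common(1)[0][0]


print(majority_label(['yes', 'no', 'no']))  # no
```

When two labels are tied, `most_common` returns them in the order they were first seen, so the earlier label wins.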

Building the tree

    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]
        # Stop when every row has the same class
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # Stop when no features are left; fall back to a majority vote
        if len(dataSet[0]) == 1:
            return majorityCnt(classList)
        bestFeat = chooseBestFeatureToSplit(dataSet)
        bestFeatureLabel = labels[bestFeat]
        myTree = {bestFeatureLabel: {}}
        # Note: this removes the label from the caller's list as well;
        # pass a copy (labels[:]) if the full list is needed afterwards
        del(labels[bestFeat])
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]
            myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree
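
For comparison, here is a compact, self-contained variant of the same recursion that builds a new label list at each level instead of deleting from the caller's list. It is a sketch under assumptions, not the project's code: the names `build_tree`, `split` and `entropy` are hypothetical, ties between features go to the lowest index, and the sample rows are made up for illustration.

```python
from collections import Counter
from math import log2


def entropy(rows):
    counts = Counter(row[-1] for row in rows)
    return -sum((n / len(rows)) * log2(n / len(rows)) for n in counts.values())


def split(rows, axis, value):
    """Rows whose feature at axis equals value, with that column removed."""
    return [row[:axis] + row[axis + 1:] for row in rows if row[axis] == value]


def build_tree(rows, labels):
    classes = [row[-1] for row in rows]
    if classes.count(classes[0]) == len(classes):
        return classes[0]                                # pure subset: leaf
    if len(rows[0]) == 1:
        return Counter(classes).most_common(1)[0][0]     # no features left: majority vote

    def weighted_entropy(axis):
        values = {row[axis] for row in rows}
        return sum(len(s) / len(rows) * entropy(s)
                   for s in (split(rows, axis, v) for v in values))

    # Lowest weighted entropy after the split is the same as highest information gain
    best = min(range(len(labels)), key=weighted_entropy)
    remaining = labels[:best] + labels[best + 1:]        # new list; caller's labels untouched
    return {labels[best]: {v: build_tree(split(rows, best, v), remaining)
                           for v in {row[best] for row in rows}}}


# Assumed sample rows: [no surfacing, flippers, class label]
sample = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
tree = build_tree(sample, labels)

expected = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(tree == expected)                          # True
print(labels == ['no surfacing', 'flippers'])    # True
```

Keeping the label list intact means the same list can be passed straight to a classifier afterwards.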

Testing

    def classify(inputTree, featLabels, testVec):
        firstStr = list(inputTree.keys())[0]   # dict views are not indexable in Python 3
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    # Internal node: keep descending
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    # Leaf: this is the predicted class
                    classLabel = secondDict[key]
        return classLabel

    >>> import trees
    >>> import treePlotter
    >>> myDat, labels = trees.createDataSet()
    >>> labels
    ['no surfacing', 'flippers']
    >>> myTree = treePlotter.retrieveTree(0)
    >>> myTree
    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    >>> trees.classify(myTree, labels, [1, 0])
    'no'
    >>> trees.classify(myTree, labels, [1, 1])
    'yes'
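
One thing the classifier does not handle is a feature value that never appeared in training: no branch matches and no label is assigned. The sketch below shows one possible way to deal with that, returning a caller-supplied default. The `default` parameter is an assumption added for illustration; it is not part of the project's `classify`.

```python
def classify(tree, feat_labels, test_vec, default=None):
    """Walk a nested-dict tree; return default if a feature value was never seen."""
    if not isinstance(tree, dict):
        return tree                      # reached a leaf
    feature = next(iter(tree))           # the single key naming the split feature
    branches = tree[feature]
    value = test_vec[feat_labels.index(feature)]
    if value not in branches:
        return default
    return classify(branches[value], feat_labels, test_vec, default)


tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing', 'flippers']

print(classify(tree, labels, [1, 0]))  # no
print(classify(tree, labels, [1, 1]))  # yes
print(classify(tree, labels, [2, 1]))  # None
```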

 Storing and reloading the tree

    import pickle

    def storeTree(inputTree, filename):
        # pickle writes bytes, so open the file in binary mode
        with open(filename, 'wb') as fw:
            pickle.dump(inputTree, fw)

    def grabTree(filename):
        with open(filename, 'rb') as fr:
            return pickle.load(fr)
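
A minimal round-trip sketch, assuming Python 3: it pickles a tree to a temporary file and reads it back. The temporary directory and the `.pkl` file name are choices made for this example only.

```python
import os
import pickle
import tempfile

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

# Write to a throwaway directory so the sketch leaves nothing behind
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'classifierStorage.pkl')
    with open(path, 'wb') as fw:      # binary mode: pickle produces bytes
        pickle.dump(tree, fw)
    with open(path, 'rb') as fr:
        restored = pickle.load(fr)

print(restored == tree)  # True
```

Only unpickle files you created yourself, since loading a pickle can execute arbitrary code.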

 Test

    #!/usr/bin/python
    import trees

    myDat, labels = trees.createDataSet()
    myTree = trees.createTree(myDat, labels)
    trees.storeTree(myTree, 'classifierStorage.txt')
    print(trees.grabTree('classifierStorage.txt'))

Plotting the tree structure

    #!/usr/bin/python
    import matplotlib.pyplot as plt

    # Box and arrow styles for the two node types
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    arrow_args = dict(arrowstyle="<-")

    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        # Draw the node text at centerPt with an arrow coming from parentPt
        createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords="axes fraction",
                                xytext=centerPt, textcoords="axes fraction",
                                va="center", ha="center", bbox=nodeType,
                                arrowprops=arrow_args)

Creating the nodes

    def createPlot():
        fig = plt.figure(1, facecolor="white")
        fig.clf()
        # The axes are stored as a function attribute so plotNode can reach them
        createPlot.ax1 = plt.subplot(111, frameon=False)
        plotNode("a decision node", (0.5, 0.1), (0.1, 0.5), decisionNode)
        plotNode("a leaf node", (0.8, 0.1), (0.3, 0.8), leafNode)
        plt.show()

Run it from the Python command line like this:

    import treePlotter
    treePlotter.createPlot()

  • Result: a figure with one decision node and one leaf node, each with an arrow drawn from its parent point.

    # Number of leaves: determines the width of the plot
    def getNumLeafs(myTree):
        numLeafs = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                numLeafs += getNumLeafs(secondDict[key])   # was misspelled getNumleafs
            else:
                numLeafs += 1
        return numLeafs

    # Number of levels: determines the height of the plot
    def getTreeDepth(myTree):
        maxDepth = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth

    # Canned trees for testing the plotting code
    def retrieveTree(i):
        listOfTrees = [
            {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
            {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
        ]
        return listOfTrees[i]

    # Write the branch value halfway between a node and its parent
    def plotMidText(cntrPt, parentPt, txtString):
        xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
        yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
        createPlot.ax1.text(xMid, yMid, txtString)

    def plotTree(myTree, parentPt, nodeTxt):
        numLeafs = getNumLeafs(myTree)
        depth = getTreeDepth(myTree)
        firstStr = list(myTree.keys())[0]
        # Centre this node above the leaves that belong to it
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
                  plotTree.yOff)
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD   # move down one level
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                plotTree(secondDict[key], cntrPt, str(key))
            else:
                plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                         cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD   # move back up

    def createPlot(inTree):
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
        plotTree.totalW = float(getNumLeafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))
        plotTree.xOff = -0.5 / plotTree.totalW
        plotTree.yOff = 1.0
        plotTree(inTree, (0.5, 1.0), '')
        plt.show()
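
The layout depends on two numbers, the leaf count (horizontal spacing of 1/totalW) and the depth (vertical spacing of 1/totalD). The sketch below recomputes both for the two canned trees without any plotting, so the values can be checked by hand. `count_leaves` and `tree_depth` are hypothetical names written for Python 3; they are meant to mirror the results of getNumLeafs and getTreeDepth, not to replace them.

```python
def count_leaves(tree):
    """Number of leaf nodes in a nested-dict tree."""
    if not isinstance(tree, dict):
        return 1
    branches = next(iter(tree.values()))
    return sum(count_leaves(child) for child in branches.values())


def tree_depth(tree):
    """Number of decision nodes on the longest root-to-leaf path."""
    if not isinstance(tree, dict):
        return 0
    branches = next(iter(tree.values()))
    return 1 + max(tree_depth(child) for child in branches.values())


tree0 = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
tree1 = {'no surfacing': {0: 'no', 1: {'flippers':
         {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}

print(count_leaves(tree0), tree_depth(tree0))  # 3 2
print(count_leaves(tree1), tree_depth(tree1))  # 4 3
```

With three leaves and two levels, the first tree places its leaves one third of the axes width apart and its levels half the axes height apart.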


Extended test: lens.py

Project Address: `https://github.com/TheOneAC/ML.git`
Dataset: `lens.txt in ML/ML_ation/tree`

    #!/usr/bin/python
    import trees
    import treePlotter

    # Each line of the file is one tab-separated sample
    with open("lenses.txt") as fr:
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = trees.createTree(lenses, lensesLabels)
    print(lensesTree)
    treePlotter.createPlot(lensesTree)
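
The parsing step can be tried without the data file. The sketch below feeds two made-up rows through the same strip-and-split logic, assuming the layout the script implies: four tab-separated feature columns followed by the class label. The row contents are illustrative only and are not copied from the dataset.

```python
import io

# Two made-up rows in the tab-separated layout the script expects
raw = "young\tmyope\tno\treduced\tno lenses\npre\thyper\tyes\tnormal\thard\n"

with io.StringIO(raw) as fr:          # stands in for open("lenses.txt")
    lenses = [line.strip().split('\t') for line in fr]

features = [row[:-1] for row in lenses]
classes = [row[-1] for row in lenses]
print(features[0])  # ['young', 'myope', 'no', 'reduced']
print(classes)      # ['no lenses', 'hard']
```

Because every value stays a string, the features are treated as nominal, which suits this tree-building approach.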


Reposted from: https://www.cnblogs.com/zeroArn/p/6691287.html
