决策树 算法原理及代码

   决策树可以使用不熟悉的数据集合,并从中提取出一系列的规则,这是机器根据数据集创建规则的过程,就是机器学习的过程。用一个小案例分析:

 

通过No surfacing  和 flippers判断该生物是否是鱼,No surfacing 是离开水面是否可以生存,flippers判断是否有脚蹼

引入信息增益和信息熵的概念:

信息熵:计算熵,我们需要计算所有类别所有可能值包含的信息期望值。

                                        p(x)是类别出现的概率

条件熵(表示在已知随机变量X的条件下随机变量Y的不确定性。):

                                       

信息增益(划分数据集前后的信息发生的变化,通俗的说,就是信息熵减去条件熵):

                                         

代码实现:

      加载数据:

def createDataSet():dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]labels = ['no surfacing','flippers']return dataSet,labels

计算原始熵:

def calcShannonEnt(dataSet):numEntries = len( dataSet)labelCounts = { }for featVec in dataSet:currentLabel = featVec[-1]if  currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts [currentLabel]+=1shannonEnt = 0.0 for key in labelCounts :prob = float (labelCounts[key])/numEntriesshannonEnt -=prob * log(prob,2)return shannonEnt 
划分数据集
def splitDataSet(dataSet,axis,value):  # 待划分的数据集  ,划分数据集的特征,需要返回的特征的值retDataSet=[]for featVec in dataSet:if featVec[axis] == value :reduceFeatVec=featVec[:axis]  #取不到axis这一行reduceFeatVec.extend(featVec[axis+1:])retDataSet.append(reduceFeatVec)return retDataSet

测试数据及结果:


(myDat,0,1)  myDat是数据集,0是第一次划分数据集,1是第一列为1的数据

计算出条件熵,然后求出信息增益,并找到最大的信息增益,最大的信息增益就是找到最好的划分数据集的特征

def chooseBestFeatureToSplit(dataSet):numFeatures=len(dataSet[0])-1#计算出原始的香农熵baseEntropy = calcShannonEnt(dataSet)bestInfoGain = 0.0; bestFeature =-1for i in range (numFeatures):#创建唯一的分类标签列表featList = [example[i] for example in dataSet]uniqueVals = set (featList)  #去重复#条件熵的初始化newEntropy = 0.0for value in uniqueVals :#划分   获得数据集subDataSet = splitDataSet(dataSet,i ,value)prob=len(subDataSet)/float(len(dataSet)) # 概率#条件熵的计算newEntropy += prob * calcShannonEnt (subDataSet)# 信息增益infoGain = baseEntropy -newEntropyif (infoGain >bestInfoGain):bestInfoGain = infoGain #找到最大的信息增益bestFeature =i  #找出最好的划分数据集的特征return bestFeature

测试数据:

dataSet,labels = createDataSet()
print(dataSet)
print(chooseBestFeatureToSplit(dataSet))

输入结果:

投票机制:


def majorityCnt(classList):classCount={}for vote in classList:if vote not in classCount.keys() :classCount[vote]=0sortedClassCount = sorted (classCount.iteritems(),key=operator.itemgetter(1),reverse=True)return sortedClassCount[0][0] 

创建树:

def createTree(dataSet,labels):classList = [example[-1] for example in dataSet]if classList.count(classList[0] )== len (classList) :return classList[0]if len(dataSet[0]) == 1:return majorityCnt(classList)bestFeat = chooseBestFeatureToSplit(dataSet)bestFeatLabel = labels[bestFeat]myTree={bestFeatLabel:{}}del(labels[bestFeat])featValues = [example[bestFeat] for example in dataSet]uniqueVals = set( featValues)for value in uniqueVals:subLabels =labels[:]myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)return myTree

结果:


该方法是用信息增益的方法来构建树,在查阅其他的博客得知:

    ID3算法主要是通过信息增益的大小来判定,最大信息增益的特征就是当前节点,这个算法存在许多的不足,第一,它解决不了过拟合问题,和缺失值的处理,第二,信息增益偏向取值较多的特征,第三,不能处理连续特征问题。

因此,引入C4.5算法,是利用信息增益率来代替信息增益。为了减少过度匹配问题,我们通过剪枝来处理冗余的数据,生成决策树时决定是否要剪枝叫预剪枝,生成树之后进行交叉验证的叫后剪枝。

还有一个是引入基尼指数来进行计算叫CART树,以后再做介绍。

绘制树形图:

decisionNode = dict(boxstyle = "sawtooth", fc="0.8")
leafNode = dict(boxstyle = "round4" ,fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt,centerPt,parentPt,nodeType):createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction' ,\xytext=centerPt ,textcoords='axes fraction',va="center" ,\ha ="center" ,bbox=nodeType,arrowprops = arrow_args)

def getNumLeafs(myTree):numLeafs= 0firstStr =list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':numLeafs +=getNumLeafs(secondDict[key])else :  numLeafs+=1return numLeafs
def getTreeDepth(myTree) :maxDepth=0firstStr =list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':thisDepth = 1+ getTreeDepth(secondDict[key])else : thisDepth = 1if thisDepth > maxDepth : maxDepth=thisDepthreturn maxDepth
def plotMidText(cntrPt , parentPt ,txtString) :xMid = (parentPt[0]-cntrPt[0])/2.0 +cntrPt[0]yMid = (parentPt[1]-cntrPt[1])/2.0 +cntrPt[1]createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):numLeafs = getNumLeafs(myTree)depth = getTreeDepth(myTree)firstStr = list(myTree.keys())[0]cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)plotMidText(cntrPt,parentPt,nodeTxt)plotNode(firstStr,cntrPt,parentPt,decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__=='dict':plotTree(secondDict[key],cntrPt,str(key))else:plotTree.xOff = plotTree.xOff +1.0 /plotTree.totalWplotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)plotMidText((plotTree.xOff,plotTree.yOff) ,cntrPt,str(key))plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree) :fig = plt.figure(1,facecolor = 'white')fig.clf()axprops =dict(xticks=[],yticks=[])createPlot.ax1 = plt.subplot(111,frameon = False ,**axprops)plotTree.totalW =float(getNumLeafs(inTree))plotTree.totalD=float(getTreeDepth(inTree))plotTree.xOff = -0.5/plotTree.totalW;plotTree.yOff = 1.0plotTree(inTree,(0.5,1.0),'')plt.show()
createPlot(myTree)





本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/466893.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度好文|面试官:进程和线程,我只问这19个问题

# 干了这碗鸡汤&#xff01;我急切地盼望着可以经历一场放纵的快乐&#xff0c;纵使巨大的悲哀将接踵而至&#xff0c;我也在所不惜。-- 太宰治 《人间失格》大家好&#xff0c;这里是周日凌晨4点&#xff0c;仍在笔耕不辍的程序喵大人。下面隆重推出我呕心沥血&#xff0c;耗时…

终于有人将进程间通信讲明白了

使用多进程协作来实现应用和系统是一种被广泛使用的开发方法。多进程协作主要有以下三点优势。将功能模块化&#xff0c;避免重复造轮子。增强模块间的隔离&#xff0c;提供更强的安全保障。提高应用的容错能力。进程间通信&#xff08;Inter-Process Communication&#xff0c…

神舟本本放心率

总得票8520 可以放心购买 22.0% 1942票 不太放心 64.0% 5510票 看情况 12.0% 1068票投票起止时间&#xff1a;2007-11-15 至2008-11-22转载于:https://www.cnblogs.com/badapple126/archive/2007/11/16/962020.html

梯度下降算法

在学习逻辑回归时&#xff0c;对梯度上升算法进行了应用&#xff0c;看到其他的博客讲解&#xff0c;梯度上升算法适合求最大值&#xff0c;梯度下降算法适合求最小值&#xff0c;这里有一个分析&#xff1a;梯度上升算法公式是学习率&#xff0c;是一个常数。这个是根据逻辑回…

花了一个深夜,才用C语言写了一个2048游戏雏形

12年我毕业的第二个月工资&#xff0c;我就买了一个IPAD&#xff0c;然后在IPAD上下了一个2048游戏&#xff0c;玩起来非常爽。然后这几天看到好几个公众号都发了自己写这个游戏的代码&#xff0c;然后我自己也想试试&#xff0c;所以就有了这篇文章&#xff0c;写代码还是很有…

向银行贷款20万, 分期三年买50万的车,个人借款40万, 贷款10年买200万的房子,再贷款120万分创业...

向银行贷款20万按1年期贷款利率为&#xff1a;6%&#xff0c;若按年还贷款&#xff0c;银行贷款利息为&#xff1a;200&#xff0c;000*6%12&#xff0c;000。连本带息&#xff1a;20*106%21.2万分期三年买50万的车 贷款总额30万 年利率按10%算&#xff0c;分三年还清&#xff…

集成算法——Adaboost代码

集成算法是我们将不同的分类器组合起来&#xff0c;而这种组合结果就被称为集成方法或者是元算法。使用集成方法时会有多种形式&#xff1a;可以是不同算法的集成&#xff0c;也可以是同意算法在不同设置下的集成&#xff0c;还可以是数据集不同部分分配给不同分类器之后的集成…

年终抽奖来了

时间很快&#xff0c;2020年已经到了12月份&#xff0c;我从2018年开始写公众号&#xff0c;经过了快两年是时间&#xff0c;我收获了4万的读者&#xff0c;非常开心。我自己是一个挺逗逼的人&#xff0c;而且我写公众号并不觉得我比别人厉害&#xff0c;技术上我真的就是一个很…

嵌入式 Linux下永久生效环境变量bashrc

作者&#xff1a;skdkjxy原文&#xff1a;http://blog.sina.com.cn/s/blog_8795b0970101f1f9.html.bashrc文件 在linux系统普通用户目录&#xff08;cd /home/xxx&#xff09;或root用户目录&#xff08;cd /root&#xff09;下&#xff0c;用指令ls -al可以看到4个隐藏文件&am…

回归分析——线性回归

机器学习中&#xff0c;对于离散的数据可以做分类问题&#xff0c;那对于连续的数据就是做回归问题&#xff0c;这里对一元线性回归和多元线性回归做一个简介&#xff0c;帮组理解。回归分析&#xff1a;从一组样本数据出发&#xff0c;确定变量之间的数学关系式&#xff0c;对…

编译原理(五)自底向上分析之算符优先分析法

自底向上分析之算符优先分析法 说明&#xff1a;以老师PPT为标准&#xff0c;借鉴部分教材内容&#xff0c;AlvinZH学习笔记。 基本过程 1. 一般方法&#xff1a;采用自左向右地扫描和分析输入串&#xff0c;从输入符号串开始&#xff0c;通过反复查找当前句型的句柄&#xff0…

做Android开发,要清楚init.rc里面的东西

init.rc 复习看这个之前&#xff0c;先看看大神总结的文章这篇文章总结的非常到位&#xff0c;但是因为代码不是最新的Android版本&#xff0c;对我们最新的Android版本不适用。http://gityuan.com/2016/02/05/android-init/#init rc文件拷贝拷贝其实也就是把文件放到机器的某个…

宏比较值,坑的一B

昨晚上&#xff0c;我准备睡觉&#xff0c;连总给我发了一段代码#include "stdio.h"#define MAX_MACRO(a, b) ((a) > (b) ? (a) : (b)) int MAX_FUNC(int a, int b) {return ((a) > (b) ? (a) : (b)); }int main() {unsigned int a 1;int b -1;printf(&quo…

Linux下Samba服务器搭建

linux文件共享之samba服务器 ——ubuntu 宗旨&#xff1a;技术的学习是有限的&#xff0c;分享的精神是无限的。 关闭LINUX防火墙命令&#xff1a; #ufwdisable 然后就在windows下ping一下linux的IP&#xff0c;如果能ping通&#xff0c;就可以继续下面的内容&#xff0c;如果p…

搞懂C++为什么难学,看这篇就够了!

学C能干什么&#xff1f; 往细了说&#xff0c;后端、客户端、游戏引擎开发以及人工智能领域都需要它。往大了说&#xff0c;构成一个工程师核心能力的东西&#xff0c;都在C里。跟面向对象型的语言相比&#xff0c;C是一门非常考验技术想象力的编程语言&#xff0c;因此学习起…

看图学源码之FutureTask

RunnableFuture 源码学习&#xff1a; 成员变量 任务的运行状态的转化 package java.util.concurrent; import java.util.concurrent.locks.LockSupport;/**可取消的异步计算。该类提供了Future的基本实现&#xff0c;包括启动和取消计算的方法&#xff0c;查询计算是否完成以…

单片机的引脚,你都清楚吗?

第1课&#xff1a;单片机简叙1.单片机可以做什么&#xff1f;目前单片机渗透到我们生活的各个领域&#xff0c;几乎很难找到哪个领域没有单片机的踪迹。小到电话&#xff0c;玩具&#xff0c;手机&#xff0c;各类刷卡机&#xff0c;电脑键盘&#xff0c;彩电&#xff0c;冰箱&…

Graphviz的安装及纠错

在Anaconda Prompt里边输入conda install graphviz 安装成功之后输入pip install graphviz 它会提示成功安装。 启动 Jupyter Notebook &#xff0c;在文件里边输入 import graphviz 测试&#xff0c;如果没有报错证明&#xff0c;模块安装成功&#xff0c;但是在运行程序…

sklearn——决策树

总结sklearn决策树的使用&#xff0c;方便以后查阅。1.分类决策树 &#xff08;基于CART树&#xff09; 原型&#xff1a;参数&#xff1a;2、回归分类树 原型&#xff1a;参数&#xff1a;3、export_graphviz 当训练完毕一颗决策树时&#xff0c;可以通过sklearn.tree.expor…

Linux下SVN服务器的搭建

Linux下SVN服务器的搭建 宗旨&#xff1a;技术的学习是有限的&#xff0c;分享的精神是无限的。 1、下载工具&#xff08;下载地址&#xff1a;&#xff09; subversion-1.6.1.tar.gz subversion-deps-1.6.1.tar.gz 2、解压两个包&#xff1a; a) tar -xzvf subvers…