MachineLearning(7)-决策树基础+sklearn.DecisionTreeClassifier简单实践

sklearn.DecisionTreeClassifier决策树简单使用

  • 1.决策树算法基础
  • 2.sklearn.DecisionTreeClassifier简单实践
    • 2.1 决策树类
    • 2.3 决策树构建
      • 2.3.1全数据集拟合,决策树可视化
      • 2.3.2交叉验证实验
      • 2.3.3超参数搜索
      • 2.3.4模型保存与导入
      • 2.3.5固定随机数种子
  • 参考资料

1.决策树算法基础

决策树模型可以用来做 回归/分类 任务。

每次选择一个属性/特征,依据特征的阈值,将特征空间划分为 与 坐标轴平行的一些决策区域。如果是分类问题,每个决策区域的类别为该该区域中多数样本的类别;如果为回归问题,每个决策区域的回归值为该区域中所有样本值的均值。

决策树复杂程度 依赖于 特征空间的几何形状。根节点->叶子节点的一条路径产生一条决策规则。

决策树最大优点:可解释性强
决策树最大缺点:不是分类正确率最高的模型

决策树的学习是一个NP-Complete问题,所以实际中使用启发性的规则来构建决策树。
step1:选最好的特征来划分数据集
step2:对上一步划分的子集重复步骤1,直至停止条件(节点纯度/分裂增益/树深度)

不同的特征衡量标准,产生了不同的决策树生成算法:

算法最优特征选择标准
ID3信息增益:Gain(A)=H(D)−H(D∥A)Gain(A)=H(D)-H(D\|A)Gain(A)=H(D)H(DA)
C4.5信息增益率:GainRatio(A)=Gain(A)/Split(A)GainRatio(A)=Gain(A)/Split(A)GainRatio(A)=Gain(A)/Split(A)
CARTgini指数增益:Gini(D)−Gini(D∥A)Gini(D)-Gini(D\|A)Gini(D)Gini(DA)

k个类别,类别分布的gini 指数如下,gini指数越大,样本的不确定性越大:
Gini(D)=∑k=1Kpk(1−pk)=1−∑k=1Kpk2Gini(D) =\sum_{k=1}^Kp_k(1-p_k)=1-\sum_{k=1}^Kp_k^2Gini(D)=k=1Kpk(1pk)=1k=1Kpk2

CART – Classification and Regression Trees 的缩写1984年提出的一个特征选择算法,对特征进行是/否判断,生成一棵二叉树。且每次选择完特征后不对特征进行剔除操作,所有同一条决策规则上可能出现重复特征的情况。

2.sklearn.DecisionTreeClassifier简单实践

Scikit-learn(sklearn)是机器学习中常用的第三方模块,其建立在NumPy、Scipy、MatPlotLib之上,包括了回归,降维,分类,聚类方法。

sklearn 通过以下两个类实现了 决策分类树决策回归树

sklearn 实现了ID3和Cart 算法,criterion默认为"gini"系数,对应为CART算法。还可设置为"entropy",对应为ID3。(计算机最擅长做的事:规则重复计算,sklearn通过对每个特征的每个切分点计算信息增益/gini增益,得到当前数据集合最优的特征及最优划分点)

2.1 决策树类

sklearn.tree.DecisionTreeClassifier(criterion=’gini’*,splitter=’best’, max_depth=None, 
min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,
max_features=None, random_state=None, max_leaf_nodes=None, 
min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort=False)
DecisionTreeRegressor(criterion=’mse’, splitter=’best’, 
max_depth=None, min_samples_split=2, min_samples_leaf=1, 
min_weight_fraction_leaf=0.0, max_features=None, random_state=None, 
max_leaf_nodes=None, min_impurity_decrease=0.0, 
min_impurity_split=None, presort=False)
Criterion选择属性的准则–gini–cart算法
splitter特征划分点的选择策略:best 特征的所有划分点中找最优
random 部分划分点中找最优
max_depth决策树的最大深度,none/int 限制/不限制决策树的深度
min_samples_split节点 继续划分需要的最小样本数,如果少于这个数,节点将不再划分
min_samples_leaf限制叶子节点的最少样本数量,如果叶子节点的样本数量过少会被剪枝
min_weight_fraction_leaf叶子节点的剪枝规则
max_features选取用于分类的特征的数量
random_state随机数生成的一些规则、
max_leaf_nodes限制叶子节点的数量,防止过拟合
min_impurity_decrease表示结点减少的最小不纯度,控制节点的继续分割规律
min_impurity_split表示结点划分的最小不纯度,控制节点的继续分割规律
class_weight设置各个类别的权重,针对类别不均衡的数据集使用
不适用于决策树回归
presort控制决策树划分的速度

2.3 决策树构建

采用sklearn内置数据集鸢尾花数据集做实验。

导入第三方库

from sklearn import tree
from sklearn.tree import DecisionTreeClassifier 
from sklearn.datasets import load_iris
import graphviz
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score
import joblib
plt.switch_backend('agg')

2.3.1全数据集拟合,决策树可视化

def demo1():# 全数据集拟合,决策树可视化iris = load_iris()x, y = load_iris(return_X_y = True)                     # x[list]-feature,y[]-label clf = tree.DecisionTreeClassifier()                     # 实例化了一个类,可以指定类参数,定制决策树模型clf = clf.fit(x,y)                                      # 训练模型print("feature name ", iris.feature_names)              # 特征列表, 自己的数据可视化时,构建一个特征列表即可print("label name ",iris.target_names)                  # 类别列表dot_data = tree.export_graphviz(clf, out_file = None, feature_names = iris.feature_names, class_names = iris.target_names )    graph = graphviz.Source(dot_data)                        # 能绘制树节点的一个接口graph.render("iris")                                     # 存成pdf图
tree.export_graphviz 参数
feature_names特征列表list,和训练时的特征列表排列顺序对其即可
class_names类别l列表ist,和训练时的label列表排列顺序对其即可
filledFalse/True,会依据criterion的纯度将节点显示成不同的颜色

value中的值显示的是各个类别样本的数量(二分类就是[负样本数,正样本数])

在这里插入图片描述

2.3.2交叉验证实验

def demo2():# n-折实验iris = load_iris()iris_feature = iris.data                                # 与demo1中的x,y是同样的数据iris_target = iris.target# 数据集合划分参数:train_x, test_x, train_y, test_y = train_test_split(iris_feature,iris_target,test_size = 0.2, random_state = 1)dt_model = DecisionTreeClassifier()dt_model.fit(train_x, train_y)                          # 模型训练predict_y = dt_model.predict(test_x)                    # 模型预测输出# score = dt_model.score(test_x,test_y)                 # 模型测试性能: 输入:feature_test,target_test , 输出acc# print(score)                                          # 性能指标print("label: \n{0}".format(test_y[:5]))                # 输出前5个labelprint("predict: \n{0}".format(predict_y[:5]))           # 输出前5个label# sklearn 内置acc, recall, precision统计接口print("test acc: %.3f"%(accuracy_score(test_y, predict_y)))# print("test recall: %.3f"%(recall_score(test_y, predict_y)))  # 多类别统计召回率需要指定平均方式# print("test precision: %.3f"%(precision_score(test_y, predict_y))) # 多类别统计准确率需要指定平均方式

2.3.3超参数搜索

def model_search(feas,labels):# 模型参数选择,全数据5折交叉验证,出结果min_impurity_de_entropy = np.linspace(0, 0.01, 10)      # 纯度增益下界,划分后降低量少于这个值,将不进行分裂min_impurity_split_entropy = np.linspace(0, 0.4, 10)    # 当前节点纯度小于这个值将不分裂,较高版本中已经取消这个参数max_depth_entropy = np.arange(1,11)                     # 决策树的深度# param_grid = {"criterion" : ["entropy"], "min_impurity_decrease" : min_impurity_de_entropy,"max_depth" : max_depth_entropy,"min_impurity_split" :  min_impurity_split_entropy }param_grid = {"criterion" : ["entropy"], "max_depth" : max_depth_entropy, "min_impurity_split" :  min_impurity_split_entropy }clf = GridSearchCV(DecisionTreeClassifier(), param_grid, cv = 5)  # 遍历以上超参, 通过多次五折交叉验证得出最优的参数选择clf.fit(feas, label)                                    print("best param:", clf.best_params_)                  # 输出最优参数选择print("best score:", clf.best_score_)

2.3.4模型保存与导入

模型保存

joblib.dump(clf,"./dtc_model.pkl")

模型导入

model_path = “./dtc_model.pkl”
clf = joblib.load(model_path)

2.3.5固定随机数种子

1.五折交叉验证,数据集划分随机数设置 random_state

train_test_split(feas, labels, test_size = 0.2, random_state = 1 )

2.模型随机数设置 andom_state

DecisionTreeClassifier(random_state = 1)

参考资料

1.官网类接口说明:
https://scikit-learn.org/dev/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier

可视化接口说明https://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html

2.决策树超参数调参技巧:https://www.jianshu.com/p/230be18b08c2

3.Sklearn.metrics 简介及应用示例:https://blog.csdn.net/Yqq19950707/article/details/90169913

4.sklearn的train_test_split()各函数参数含义解释(非常全):https://www.cnblogs.com/Yanjy-OnlyOne/p/11288098.html

5.sklearn.tree.DecisionTreeClassifier 详细说明:https://www.jianshu.com/p/8f3f1e706f11

6.使用scikit-learn中的metrics以及DecisionTreeClassifier重做《机器学习实战》中的隐形眼镜分类问题:http://keyblog.cn/article-235.html

7.决策树算法:https://www.cnblogs.com/yanqiang/p/11600569.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/444969.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode72 编辑距离

给定两个单词 word1 和 word2,计算出将 word1 转换成 word2 所使用的最少操作数 。 你可以对一个单词进行如下三种操作: 插入一个字符 删除一个字符 替换一个字符 示例 1: 输入: word1 "horse", word2 "ros" 输出: 3 解释: ho…

即时通讯系统架构

有过几款IM系统开发经历,目前有一款还在线上跑着。准备简单地介绍一下大型商业应用的IM系统的架构。设计这种架构比较重要的一点是低耦合,把整个系统设计成多个相互分离的子系统。我把整个系统分成下面几个部分:(1)状态…

leetcode303 区域和检索

给定一个整数数组 nums,求出数组从索引 i 到 j (i ≤ j) 范围内元素的总和,包含 i, j 两点。 示例: 给定 nums [-2, 0, 3, -5, 2, -1],求和函数为 sumRange() sumRange(0, 2) -> 1 sumRange(2, 5) -> -1 sumRange(0,…

网络游戏的客户端同步问题 .

有关位置同步的方案实际上已经比较成熟,网上也有比较多的资料可供参考。在《带宽限制下的视觉实体属性传播》一文中,作者也简单提到了位置同步方案的构造过程,但涉及到细节的地方没有深入,这里专门针对这一主题做些回顾。 最直接的…

leetcode319 灯泡的开关

初始时有 n 个灯泡关闭。 第 1 轮,你打开所有的灯泡。 第 2 轮,每两个灯泡你关闭一次。 第 3 轮,每三个灯泡切换一次开关(如果关闭则开启,如果开启则关闭)。第 i 轮,每 i 个灯泡切换一次开关。 …

网游服务器端设计思考:心跳设计

网络游戏服务器的主要作用是模拟整个游戏世界,客户端用过网络连接把一些信息数据发给服务器,在操作合法的情况下,更新服务器上该客户端对应的player实体、所在场景等,并把这些操作及其影响广播出去。让别的客户端能显示这些操作。…

算法(25)-括号

各种括号1.LeetCode-22 括号生成--各种括号排列组合2.LeetCode-20 有效括号(是否)--堆栈3.LeetCode-32 最长有效括号(长度)--dp4.LeetCode-301删除无效括号 --多种删除方式1.LeetCode-22 括号生成–各种括号排列组合 数字 n 代表生成括号的对数,请你设计一个函数&a…

leetcode542 01矩阵

给定一个由 0 和 1 组成的矩阵,找出每个元素到最近的 0 的距离。 两个相邻元素间的距离为 1 。 示例 1: 输入: 0 0 0 0 1 0 0 0 0 输出: 0 0 0 0 1 0 0 0 0 示例 2: 输入: 0 0 0 0 1 0 1 1 1 输出: 0 0 0 0 1 0 1 2 1 注意: 给定矩阵的元素个数不超过 10000。…

RPC、RMI与MOM与组播 通信原理 .

远程过程调用(RPC): 即对远程站点机上的过程进行调用。当站点机A上的一个进程调用另一个站点机上的过程时,A上的调用进程挂起,B上的被调用过程执行,并将结果返回给调用进程,使调用进程继续执行【…

一个简单的游戏服务器框架 .

最近一段时间不是很忙,就写了一个自己的游戏服务器框架雏形,很多地方还不够完善,但是基本上也算是能够跑起来了。我先从上层结构说起,一直到实现细节吧,想起什么就写什么。 第一部分 服务器逻辑 服务器这边简单的分为三…

leetcode97 交错字符串

给定三个字符串 s1, s2, s3, 验证 s3 是否是由 s1 和 s2 交错组成的。 示例 1: 输入: s1 "aabcc", s2 "dbbca", s3 "aadbbcbcac" 输出: true 示例 2: 输入: s1 "aabcc", s2 "dbbca", s3 "aadbbbaccc" 输…

leetcode 33 搜索旋转排序数组 到处是细节的好题

这个题想了想就会做,只是细节真的能卡死人,找了好久的bug。甚至我怀疑我现在的代码可能还有错,只是没例子测出来。 假设按照升序排序的数组在预先未知的某个点上进行了旋转。 ( 例如,数组 [0,1,2,4,5,6,7] 可能变为 [4,5,6,7,0,1…

MachineLearning(8)-PCA,LDA基础+sklearn 简单实践

PCA,LDA基础sklearn 简单实践1.PCAsklearn.decomposition.PCA1.PCA理论基础2.sklearn.decomposition.PCA简单实践2.LDAsklearn.discriminant_analysis.LinearDiscriminantAnalysis2.1 LDA理论基础2.2 sklearn LDA简单实践1.PCAsklearn.decomposition.PCA 1.PCA理论基础 PCA:&…

leetcode198 打家劫舍

你是一个专业的小偷,计划偷窃沿街的房屋。每间房内都藏有一定的现金,影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统,如果两间相邻的房屋在同一晚上被小偷闯入,系统会自动报警。 给定一个代表每个房屋存放金额的…

linux下的RPC

一、概述 在传统的编程概念中,过程是由程序员在本地编译完成,并只能局限在本地运行的一段代码,也即其主程序和过程之间的运行关系是本地调用关系。因此这种结构在网络日益发展的今天已无法适应实际需求。总而言之,传统过程调用模式…

算法(28)--矩阵搜索系列

矩阵搜索1.leetcode-200. 岛屿数量2.leetcode-695. 岛屿的最大面积3.leetcode-463. 岛屿的周长4.剑指 Offer 12. 矩阵中的路径5.leetcode-329. 矩阵中的最长递增路径6.leetcode-1091. 二进制矩阵中的最短路径1.leetcode-200. 岛屿数量 给你一个由 ‘1’(陆地&#…

leetcode213 打家劫舍II

你是一个专业的小偷,计划偷窃沿街的房屋,每间房内都藏有一定的现金。这个地方所有的房屋都围成一圈,这意味着第一个房屋和最后一个房屋是紧挨着的。同时,相邻的房屋装有相互连通的防盗系统,如果两间相邻的房屋在同一晚…

PaperNotes(4)-高质量图像生成-CGAN-StackGAN-Lapgan-Cyclegan-Pix2pixgan

cgan,stackgan,lapgan,cyclegan,pix2pixgan1.Conditional GAN1.1简介1.2网络结构与训练1.3特点与用途2.Stack GAN2.1简介2.2网络结构与训练2.3特点与用途3.Lap GAN3.1简介3.2网络结构与训练3.3特点与用途4.Pix2pix GAN4.1 简介4.2 网络结构和训练4.3 特点和用途5.Patch GAN6.Cy…

leetcode206 反转链表

反转一个单链表。 示例: 输入: 1->2->3->4->5->NULL 输出: 5->4->3->2->1->NULL 进阶: 你可以迭代或递归地反转链表。你能否用两种方法解决这道题? 经典题不解释 /*** Definition for singly-linked list.* public class ListNode…

leetcode 152 乘积最大子序列

给定一个整数数组 nums ,找出一个序列中乘积最大的连续子序列(该序列至少包含一个数)。 示例 1: 输入: [2,3,-2,4] 输出: 6 解释: 子数组 [2,3] 有最大乘积 6。 示例 2: 输入: [-2,0,-1] 输出: 0 解释: 结果不能为 2, 因为 [-2,-1] 不是子…