机器学习实战2决策树算法

文章目录

    • 决策树算法核心是要解决两个的关键问题
    • sklearn中的决策树模型
    • sklearn建模步骤
    • 分类树
      • Criterion
      • random_state && splitter
      • 剪枝参数
      • max_depth
      • min_samples_leaf&&min_samples_split
      • max_features&&min_impurity_decrease
      • 确认最优剪枝参数
      • 目标权重参数
      • 重要属性和接口
    • 回归树
      • 参数、属性、接口
        • Criterion
      • 交叉验证

决策树算法核心是要解决两个的关键问题

1、如何从数据表中照出最佳节点和最佳分支
2、如何让决策树停止生长防止过拟合

就是说假如我有一张数据表,数据表中有成千上万个特征,我要把他们都提问完吗?

sklearn中的决策树模型

6.png
本文主要是学习分类树和回归树

sklearn建模步骤

7.png

from sklearn import tree
clf = tree.DecisionTreeClassifier()
clf = clf.fit(x_train, y_train)
result = clf.score(x_test, y_test)

分类树

Criterion

为了将表格转化为一棵树,决策树需要找到最佳节点和最佳分支方法,对于分类数来说衡量最佳的方法是叫做不纯度,通常来说不纯度越低,决策树对训练集的拟合效果越好,所有的决策树算法都是将和不纯度相关的某个属性最优化,不管我们用那个算法,都是追求的与不纯度相关的指标最优化
不纯度基于节点来计算,树中的每个节点都会有一个不纯度,并且子节点的不纯度一定低于父节点的不纯度,在一棵决策树上叶子节点的不纯度一定是最低的

Criterion这个参数就是用来决定不纯度计算方法的,sklearn中提供了两种方法
一种是输入“entropy”,使用信息熵
一种是输入"gini",使用基尼系数

8.png
我们无法干扰信息熵和基尼系数的计算,所以这里我们知道怎么算的即可,sklearn中的方法我们是无法干扰的
相比于基尼系数来说,信息熵对于不纯度更加敏感,对不纯度的更强,但在实际使用中两者的效果差不多,信息熵的计算相比于基尼系数会慢一点,因为基尼系数没有对数运算,因为信息熵对于不纯度更加敏感,所以信息熵在计算决策树时候会更加仔细,所以对于高维数据或者噪音很多的数据来说很容易过拟合
关于参数如何选择
9.png
10.png
我们可以看到我们的决策树中并没有用到我们所给的所有属性
11.png

clf = tree.DecisionTreeClassifier(criterion = "entropy", random_state=30)
clf = clf.fit(Xtrain, Ytrain)
score = clf.score(Xtest, Ytest) # 返回预测的准确度accuracy
score

random_state && splitter

用来设置分支中的随机模式的参数,默认为None,在高维度时随机性会表现更明显,低维度数据几乎不会显现,我们任意给random_state一个数值可以让模型稳定下来
决策树是随机的
splitter也是用来控制决策树中随机选项的可以输入best,决策树虽然分支时会随机但会有限选择更重要的特征进行分支,输入random分支时会更加随机,树会更深,拟合将会降低,这也是防止过拟合的一种方法

clf = tree.DecisionTreeClassifier(criterion = "entropy", splitter="best")
clf = clf.fit(Xtrain, Ytrain)
score = clf.score(Xtest, Ytest) # 返回预测的准确度accurac
score

剪枝参数

在不加限制的情况下,一颗决策树会生长到衡量不纯度的指标最优,或没有更多的特征可用停止,这样的决策树往往会过拟合,也就是说他会在训练集上表现很好,在测试集上表现却很糟糕
12.png

score = clf.score(Xtrain, Ytrain)
score

13.png

我们要分清楚过拟合的概念,过拟合,也就是说他会在训练集上表现很好,在测试集上表现却很糟糕,但如果我们在训练集和测试集上的表现效果都很好的话不能称为过拟合

剪枝策略对于决策树的影响巨大,正确的剪枝策略是决策树优化的核心

max_depth

限制树的最大深度,超过设定深度的树枝全部剪掉
这是用的最广泛的剪枝参数,在高维度低样本量时非常有效,决策树多生长一层对样本的需求量就会增加一倍,所以限制决策树的深度能够特别有效的限制过拟合,在集成算法中也非常常用,在实际使用过程中,建议我们从3开始尝试,看看拟合的效果再决定是否增加深度

min_samples_leaf&&min_samples_split

这两个是用来限制叶子节点的参数,
min_samples_leaf建议从5开始使用
min_samples_split:一个节点至少包含min_samples_split个样本才被允许进行分支

max_features&&min_impurity_decrease

14.png

确认最优剪枝参数

使用确认超参数的曲线

import matplotlib.pyplot as plt
test = []
for i in range(10):clf = tree.DecisionTreeClassifier(criterion = "entropy", random_state=30, splitter="random", max_depth = i + 1)clf = clf.fit(Xtrain, Ytrain)score = clf.score(Xtest, Ytest) # 返回预测的准确度accuracytest.append(score)
plt.plot(range(1,11), test,color = "red", label="max_depth")
plt.legend()
plt.show()

15.png
16.png

目标权重参数

17.png

重要属性和接口

fit
score
apply
predict
18.png
19.png

回归树

参数、属性、接口

Criterion

回归树衡量分枝质量的指标,支持的有三种
1、mse使用均方误差
2、friedman_mse误差费尔德曼均方误差
3、mae绝对均方误差
这里面也有许多数学原理,但是我们在使用sklearn时不用关心,因为这些因素我们并无法干预
20.png
属性依然是feature_importance_
接口中依然是
fit
score
apply
predict
是最核心
21.png

交叉验证

交叉验证是用来验证模型稳定性的一种方法,我们将数据划分为n份,依次使用其中一份作为测试集,其他n-1份作为训练集,多次计算模型的精确性来评估模型的平均准确程度
22.png

23.png
导入所需要的库

from sklearn.datasets import load_diabetes
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

导入数据集

diabetes = load_diabetes()

实例化并交叉验证

regressor = DecisionTreeRegressor(random_state = 0) #实例化
cross_val_score(regressor, diabetes.data, diabetes.target, cv = 10, scoring = "neg_mean_squared_error") #交叉验证

参数解读
1、第一个参数可以是回归也可以是分类,这里的模型不止可以是决策树,可以是其他的支持向量机、随机森林等等模型,可以是任何我们实例化后的算法模型
2、第二个参数是完整的不需要分测试集和训练集的特征矩阵,交叉验证会自己帮我们划分测试集和数据集
3、第三个参数是数据的标签(完整的数据标签)
4、cv = 10是将数据集分成10分,每次用其中的一份作为测试集,通常我们将这个数设为5
5、scoring用后面的neg_mean_squared_error衡量我们交叉测试的结果,当默认时会返回R2可能为负数,但是我们在做回归时最常用的是均方误差,neg_mean_squared_error是负均方误差,R2越接近1越好

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/30313.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VR全景智慧文旅,用科技助力旅游业振兴

引言: 近年来,科技的迅猛发展将我们带入一个全新的数字化时代,而虚拟现实(Virtual Reality,简称VR)技术则以其令人惊叹的全新方式,影响着各个领域。其中,旅游业作为人们探索世界、体…

Camunda 7.x 系列【12】创建流程引擎

有道无术,术尚可求,有术无道,止于术。 本系列Spring Boot 版本 2.7.9 本系列Camunda 版本 7.19.0 源码地址:https://gitee.com/pearl-organization/camunda-study-demo 文章目录 1. ProcessEngine2. 创建流程引擎2.1 Java API2.2 XML 配置2.3 Spring2.4 Spring Boot1. Pr…

系统架构设计师笔记第35期:表现层框架设计

表现层框架设计是指在软件系统中,将用户界面(UI)和用户交互逻辑与后端业务逻辑分离,使用特定的框架来组织和管理表现层的功能和结构。下面是表现层框架设计的一般步骤和常用技术: 确定需求和功能:首先&…

【2.1】Java微服务:详解Hystrix

✅作者简介:大家好,我是 Meteors., 向往着更加简洁高效的代码写法与编程方式,持续分享Java技术内容。 🍎个人主页:Meteors.的博客 💞当前专栏: Java微服务 ✨特色专栏: 知识分享 &am…

用C++实现的RTS游戏的路径查找算法(A*、JPS、Wall-tracing)

在实时策略(RTS)游戏中,路径查找是一个关键的问题。游戏中的单位需要能够找到从一个地方到另一个地方的最佳路径。这个问题在计算机科学中被广泛研究,有许多已经存在的算法可以解决这个问题。在本文中,我们将探讨三种在…

NeRF基础代码解析

embedders 对position和view direction做embedding。 class FreqEmbedder(nn.Module):def __init__(self, in_dim3, multi_res10, use_log_bandsTrue, include_inputTrue):super().__init__()self.in_dim in_dimself.num_freqs multi_resself.max_freq_log2 multi_resself…

php如何爬取天猫和淘宝商品数据

这篇文章主要介绍了php如何爬取天猫和淘宝商品数据,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。 一、思路 最近做了一个网站用到了从网址爬取天猫和淘宝的商…

row_number()分页返回结果顺序不确定

之前通过row_number()实现分页查询时: select top [PageSize] * from (select row_number() over (order by id desc) as RowNum,*from table ) as A where RowNum > (PageIndex - 1) * PageSize发现查询出来的结果顺序是不确定的,查询官方文档&am…

基于遗传算法改进的支持向量机多分类仿真,基于GA-SVM的多分类预测,支持相机的详细原理

目录 背影 支持向量机SVM的详细原理 SVM的定义 SVM理论 遗传算法的原理及步骤 SVM应用实例,基于遗传算法优化SVM的多分类预测 完整代码包括SVM工具箱:https://download.csdn.net/download/abc991835105/88175549 代码 结果分析 展望 背影 多分类预测对现代智能化社会拥有重…

VGG16模型详解

VGG16模型详解 0、VGG16介绍 VGG16是一种深度卷积神经网络,由牛津大学的研究团队于2014年开发。 VGG16在2014年的ImageNet Large Scale Visual Recognition Challenge (ILSVRC) 竞赛中取得了显著的成绩。它在图像分类任务中获得了当年的第二名,其准确…

matplotlib 笔记 plt.grid

用于添加网格线 主要参数 visible 布尔值,True表示画网格 which表示要显示的刻度线类型,可以是 major(主刻度)或 minor(次刻度),或者同时显示(both)alpha 透明度 …

音视频--视频数据传输

参考文献 H264码流RTP封装方式详解:https://blog.csdn.net/water1209/article/details/126019272H264视频传输、编解码----RTP协议对H264数据帧拆包、打包、解包过程: https://blog.csdn.net/wujian946110509/article/details/79129338H264之NALU解析&a…

【Redis】初学Redis

目录 使用Redisyum安装redis启动redis操作redis设置远程连接 Redis路线Redis 使用Redis yum安装redis 使用命令,直接将Redis安装到linux服务器: yum -y install redis启动redis redis-server /etc/redis.conf &操作redis redis-cli设置远程连接…

Shopee虾皮买家号注册时需要注意什么问题

虾皮是一家在线购物平台,如果您打算在虾皮上注册一个买家账号,以下是一些需要注意的问题: 账号安全:确保您选择一个安全的密码,并定期更改密码,以保护您的账号免受未经授权的访问。 个人信息:…

网页版Java(Spring/Spring Boot/Spring MVC)五子棋项目(四)对战模块

网页版Java(Spring/Spring Boot/Spring MVC)五子棋项目(四)对战模块 一、约定前后端交互接口1. 建立连接接口2. 针对落子的请求和响应 二、实现前端页面三、实现后端1. 当用户进入房间,更新用户状态 OnlineUserManager…

Linux mysql5.7开启 binlog

查看 mysql是否开启 binlog。 查看命令: show variables like %log_bin%; log_bin OFF 是关闭的状态。 编辑my.cnf配置文件 vim /etc/my.cnf 默认的配置文件内容: 增加下面内容 server_id 1 binlog_format ROW log-bin mysql_log_bin 重启mysq…

Chromium内核浏览器编译记(三)116版本内核UI定制

转载请注明出处:https://blog.csdn.net/kong_gu_you_lan/article/details/132180843?spm1001.2014.3001.5501 本文出自 容华谢后的博客 往期回顾: Chromium内核浏览器编译记(一)踩坑实录 Chromium内核浏览器编译记(…

木马免杀(篇一)基础知识学习

木马免杀(篇一)基础知识学习 ———— 简单的木马就是一个 exe 文件,比如今年hw流传的一张图:某可疑 exe 文件正在加载。当然木马还可能伪造成各式各样的文件,dll动态链接库文件、lnk快捷方式文件等,也可能…

MySQL单表查询

单表查询 素材: 表名:worker-- 表中字段均为中文,比如 部门号 工资 职工号 参加工作 等 CREATE TABLE worker ( 部门号 int(11) NOT NULL, 职工号 int(11) NOT NULL, 工作时间 date NOT NULL, 工资 float(8,2) NOT NULL, 政治面貌 varch…

Spring MVC项目概述及创建

Spring MVC项目概述及创建 1.什么是Spring MVC Spring MVC是基于SevletAPI的原始Web框架。Spring MVC项目也叫做SpringWeb项目。 它是在springboot项目中引入了web框架,原本的spring项目不具备网络通信能力,而spring mvc允许http响应,当用…