机器学习——决策树与随机森林

机器学习——决策树与随机森林

文章目录

  • 前言
  • 一、决策树
    • 1.1. 原理
    • 1.2. 代码实现
    • 1.3. 网格搜索
    • 1.4. 可视化决策树
  • 二、随机森林算法
    • 2.1. 原理
    • 2.2. 代码实现
  • 三、补充(过拟合与欠拟合)
  • 总结


前言

决策树和随机森林都是常见的机器学习算法,用于分类和回归任务,本文将对这两种算法进行介绍。

在这里插入图片描述


一、决策树

1.1. 原理

决策树算法是一种基于树结构的分类和回归算法。它通过对数据集进行递归地二分,选择最佳的特征进行划分,直到达到终止条件。
决策树的每个内部节点表示一个特征,根据测试结果进行分类,每个叶子节点表示一个类别或一个回归值。
决策树的构建可以通过以下几个步骤来实现:

  1. 特征选择:根据某个评价指标(如信息增益、基尼不纯度等),选择最佳的特征作为当前节点的划分特征。(即哪个特征带来最多的信息变化幅度,就选择哪一个特征来分类)

  2. 划分数据集:根据选择的特征,将数据集划分成多个子集,每个子集对应一个分支。对于离散特征,可以根据特征值的不同进行划分;对于连续特征,可以选择一个阈值进行划分。

  3. 递归构建子树:对每个子集递归地构建子树,直到所有子集被正确分类或满足终止条件。常见的终止条件有:达到最大深度、样本数量小于阈值、节点中的样本属于同一类别等。

  4. 避免过拟合:对决策树进行剪枝处理。剪枝可以分为前剪枝和后剪枝: 前剪枝是在构建树的过程中进行剪枝,通过设定一个阈值,信息熵减小的数量小于这个值则停止创建分支;后剪枝则是在决策树构建完成后,对节点检查其信息熵的增益来判断是否进行剪枝。
    还可以通过控制决策树的最大深度(max_depth)

1.2. 代码实现

import numpy as np
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier#生成数据集
np.random.seed(41)
raw_data = make_moons(n_samples=2000, noise=0.25, random_state=41)
data = raw_data[0]
target = raw_data[1]# 训练决策树分类模型
x_train, x_test, y_train, y_test = train_test_split(data, target)
classifer = DecisionTreeClassifier()
classifer.fit(x_train, y_train)
#计算测试数据集在决策树模型上的准确率得分
print(classifer.score(x_test, y_test))
0.916# max_depth 树的最大深度,默认为None
classifer = DecisionTreeClassifier(max_depth=6)
classifer.fit(x_train, y_train)
print(classifer.score(x_test, y_test))
0.934# min_samples_leaf 叶节点所需的最小样本数,默认为1
classifer = DecisionTreeClassifier(max_depth=6, min_samples_leaf=6)
classifer.fit(x_train, y_train)
print(classifer.score(x_test, y_test))
0.938# min_impurity_decrease 划分节点时的最小信息增益
def m_score(value):model = DecisionTreeClassifier(min_impurity_decrease=value)model.fit(x_train, y_train)train_score = model.score(x_train, y_train)test_score = model.score(x_test, y_test)return train_score, test_score
values = np.linspace(0,0.01,50)
score = [m_score(value) for value in values ]
train_s = [s[0] for s in score]
test_s = [s[1] for s in score]
best_index = np.argmax(test_s)
print(test_s[best_index])
print(values[best_index])
plt.plot(train_s,label = "train_s")
plt.plot(test_s,label = "test_s")
plt.legend()
plt.show()

在这里插入图片描述

从以上代码中可以看出在不同参数的选择情况下,准确率(分类器预测正确的样本数量与总样本数量的比例)得分是不同的,越接近1表示模型的预测性能越好

1.3. 网格搜索

可以使用网格搜索获得最优的模型参数:


# 使用网格搜索获得最优的模型参数
from sklearn.model_selection import GridSearchCV
classifer = DecisionTreeClassifier()
params = {"max_depth": np.arange(1, 10),"min_samples_leaf": np.arange(1, 20),"min_impurity_decrease": np.linspace(0,0.4,50),"criterion" : ("gini","entropy")
}
grid_searchcv = GridSearchCV(classifer, param_grid=params, scoring="accuracy",cv=5)  # scoring指定模型评估指标,例如:'accuracy'表示使用准确率作为评估指标。
grid_searchcv.fit(x_train, y_train)
print(grid_searchcv.best_params_)
print(grid_searchcv.best_score_)
#print(grid_searchcv.cv_results_)
print(grid_searchcv.best_index_)
print(grid_searchcv.best_estimator_)
best_clf = grid_searchcv.best_estimator_
best_clf.fit(x_train,y_train)
print(best_clf.score(x_test,y_test))
#结果:
{'criterion': 'entropy', 'max_depth': 8, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 16}
0.9286666666666668
15215
DecisionTreeClassifier(criterion='entropy', max_depth=8, min_samples_leaf=16)
0.94

1.4. 可视化决策树

得到可视化决策树文件:

df = pd.DataFrame(data = data,columns=["x1","x2"])
from sklearn.tree import export_graphviz
from graphviz import Source
dot_data = export_graphviz(best_clf, out_file=None, feature_names=df.columns)
graph = Source(dot_data)
graph.format = 'png'file
graph.render(filename='file_image', view=True)

在这里插入图片描述

二、随机森林算法

2.1. 原理

随机森林是一种集成学习方法,它通过构建多个决策树来进行分类或回归,
随机森林的基本原理:

  1. 随机采样:从原始训练集中随机选择一定数量的样本,作为每个决策树的训练集。

  2. 随机特征选择:对于每个决策树的每个节点,从所有特征中随机选择一部分特征进行评估,选择最佳的特征进行划分。

  3. 构建决策树:根据随机采样和随机特征选择的方式,构建多个决策树。

  4. 预测:对于分类问题,通过投票或取平均值的方式,将每个决策树的预测结果进行集成;对于回归问题,将每个决策树的预测结果取平均值。

随机森林函数中的超参数:

  1. n_estimators:它表示随机森林中决策树的个数。

  2. min_samples_split:内部节点分裂所需的最小样本数

  3. min_samples_leaf:叶节点所需的最小样本数

  4. max_features:每个决策树考虑的最大特征数量

  5. n_jobs :表示允许使用处理器的数量

  6. criterion :gini 或者entropy (default = gini)

  7. random_state:随机种子

2.2. 代码实现

import numpy as np
import pandas as pd
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
#训练随机森林分类模型
np.random.seed(42)
raw_data = make_moons(n_samples= 2000,noise= 0.25,random_state=42)
data,target = raw_data[0],raw_data[1]x_train, x_test, y_train, y_test = train_test_split(data, target,random_state=42)
classfier = RandomForestClassifier(random_state= 42)
classfier.fit(x_train,y_train)
score = classfier.score(x_test,y_test)
print(score)
0.93#网格搜索获取最优参数
from sklearn.model_selection import GridSearchCV
param_grids = {"criterion": ["gini","entropy"],"max_depth":np.arange(1,10),"min_samples_leaf":np.arange(1,10),"max_features": np.arange(1,3)
}
grid_search = GridSearchCV(RandomForestClassifier(),param_grid=param_grids,n_jobs= 1,scoring="accuracy",cv=5)
grid_search.fit(x_train,y_train)
print(grid_search.best_params_)  #最优的参数
print(grid_search.best_score_)	#最好的得分
best_clf = grid_search.best_estimator_   #最优的模型
print(best_clf)
best_clf.fit(x_train,y_train)
print(best_clf.score(x_test,y_test))  #查看测试集在最优模型上的得分
#结果:
{'criterion': 'entropy', 'max_depth': 9, 'max_features': 1, 'min_samples_leaf': 7}
0.9506666666666665
RandomForestClassifier(criterion='entropy', max_depth=9, max_features=1,min_samples_leaf=7)
0.932

三、补充(过拟合与欠拟合)

过拟合指的是模型在训练集上表现得很好,但在测试集或新数据上表现不佳的情况。
过拟合通常发生在模型过于复杂或训练数据过少的情况下

欠拟合指的是模型无法很好地拟合训练集,导致在训练集和测试集上的误差都很高。
欠拟合通常发生在模型过于简单或训练数据过于复杂的情况下


总结

总之,决策树和随机森林都是基于树结构的机器学习算法,具有可解释性和特征选择的能力。随机森林是多个决策树的集成模型,引入了随机性并通过投票或平均来得出最终预测结果,可以有效降低噪声干扰,提高模型的准确性与稳定性,但是增加了计算量。

锦帽貂裘,千骑卷平冈

–2023-9-1 筑基篇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/66664.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

牛客网刷题

牛客网刷题-C&C 2023年9月3日15:58:392023年9月3日16:37:01 2023年9月3日15:58:39 2023年9月3日16:37:01 整型常量和实型常量的区别

华为静态路由配置实验(超详细讲解+详细命令行)

系列文章目录 华为数通学习(7) 前言 一,静态路由配置 二,网络地址配置 AR1的配置: AR2的配置: AR3的配置: 三,测试是否连通 AR1的配置: 讲解: AR2的配置&#…

CentOS 7.6源码安装gdb 12.1

参考文章:《GDB调试-从安装到使用》 gdb --version看一下当前gdb的版本,可以看到是7.6.1-120.el7。 https://www.sourceware.org/gdb/download/可以下载gdb源码。 sudo nohup wget https://sourceware.org/pub/gdb/releases/gdb-12.1.tar.xz &下…

跨站请求伪造(CSRF)攻击与防御原理

跨站请求伪造(CSRF) 1.1 CSRF原理 1.1.1 基本概念 跨站请求伪造(Cross Site Request Forgery,CSRF)是一种攻击,它强制浏览器客户端用户在当前对其进行身份验证后的Web 应用程序上执行非本意操作的攻击&a…

什么是malloxx勒索病毒,服务器中malloxx勒索病毒了怎么办?

Malloxx勒索病毒是一种新型的电脑病毒,它通过加密用户电脑中的重要文件数据来威胁用户,并以此勒索钱财。这种病毒并不是让用户的电脑瘫痪,而是以非常独特的方式进行攻击。在感染了Malloxx勒索病毒后,它会加密用户服务器中的数据&a…

深入探讨Java虚拟机(JVM):执行流程、内存管理和垃圾回收机制

目录 什么是JVM? JVM 执行流程 JVM 运行时数据区 堆(线程共享) Java虚拟机栈(线程私有) 什么是线程私有? 程序计数器(线程私有) 方法区(线程共享) JDK 1.8 元空…

【LeetCode-面试经典150题-day18】

目录 17.电话号码的字母组合 77.组合 46.全排列 52.N皇后Ⅱ 17.电话号码的字母组合 题意: 给定一个仅包含数字 2-9 的字符串,返回所有它能表示的字母组合。答案可以按 任意顺序 返回。 给出数字到字母的映射如下(与电话按键相同&#xf…

mysql:[Some non-transactional changed tables couldn‘t be rolled back]不支持事务

1. mysql创建表时默认引擎MyIsam,因此不支持事务的操作; 2. 修改mysql的默认引擎,可以使用show engine命令查看支持的引擎: 【my.conf详情说明】my.cnf配置文件注释详解_xiaolin01999的博客-CSDN博客 3. 原来使用MyIsam创建的表…

Linux系统中驱动面试分享

​ 1、驱动程序分为几类? 字符设备驱动 块设备驱动 网络设备驱动 2、字符设备驱动需要实现的接口通常有哪些 open、close、read、write、ioctl等接口。 3、主设备号与次设备号的作用 主设备号和次设备号是用来标识系统中的设备的,主设备号用来标识…

postgresql并行查询(高级特性)

######################## 并行查询 postgresql和Oracle一样支持并行查询的,比如select、update、delete大事无开启并行功能后,能够利用多核cpu,从而充分发挥硬件性能,提升大事物的处理效率。 pg在9.6的版本之前是不支持的并行查询的,从9.6开始支持并行查询,但是功能非常…

OpenCV(十六):高斯图像金字塔

目录 1.高斯图像金字塔原理 2.高斯图像金字塔实现 1.高斯图像金字塔原理 高斯图像金字塔是一种用于多尺度图像表示和处理的重要技术。它通过对图像进行多次高斯模糊和下采样操作来生成不同分辨率的图像层级,每个层级都是原始图像的模糊和降采样版本。 以下是高斯…

count(1)与count(*)的区别、ROUND函数

部分问题 1. count(1)与count(*)的区别2. ROUND函数3. SQL19 分组过滤练习题4. Mysql bigdecimal 与 float的区别5. 隐式内连接与显示内连接 (INNER可省略) 1. count(1)与count(*)的区别 COUNT(*)和COUNT(1)有什么区别? count(*)包括了所有…

图表背后的故事:数据可视化的威力与影响

数据可视化现在在市场上重不重要?这已经不再是一个简单的问题,而是一个不可忽视的现实。随着信息时代的来临,数据已经成为企业和组织的核心资产,而数据可视化则成为释放数据价值的重要工具。 在当今竞争激烈的商业环境中&#xf…

小赢科技,寻找金融科技核心价

如果说金融是经济的晴雨表,是通过改善供给质量以提高经济质量的切入口,那么金融科技公司,就是这一切行动的推手。上半年,社会经济活跃程度提高背后,金融科技公司既是奉献者,也是受益者。 8月29日&#xff0…

数据艺术:精通数据可视化的关键步骤

数据可视化是将复杂数据转化为易于理解的图表和图形的过程,帮助我们发现趋势、关联和模式。同时数据可视化也是数字孪生的基础,本文小编带大家用最简单的话语为大家讲解怎么制作一个数据可视化大屏,接下来跟随小编的思路走起来~ 1.数据收集和…

华为云Stack的学习(四)

五、Service OM资源管理 1.Service OM简介 1.1 Service OM介绍 在华为云Stack解决方案中,Service OM是FusionSphere OpenStack的操作管理界面,是资源池(计算、存储、网络)以及基础云服务的管理工具。 1.2 Service OM定位 Serv…

Apifox(1)比postman更优秀的接口自动化测试平台

Apifox介绍 Apifox 是 API 文档、API 调试、API Mock、API 自动化测试一体化协作平台,定位 Postman Swagger Mock JMeter。通过一套系统、一份数据,解决多个系统之间的数据同步问题。只要定义好 API 文档,API 调试、API 数据 Mock、API 自…

Mysql数据库(3)—架构和日志

Mysql的架构设计 Mysql分为Server层和存储引擎层: Server层 主要包括连接器、查询缓存、分析器、优化器、执行器等,涵盖 MySQL 的大多数核心服务功能,以及所有的内置函数(如日期、时间、数学和加密函数等)&#xff…

【GitHub 个人主页】适应于初学者的自定义个人主页设置

▚ 00 自定义GitHub主页的教程 🍁 【保姆级教程】手把手教你用github制作学术个人主页(学者必备) ▚ 01 优秀案例 1.1 添加Stats 🎃 网址为:Stats & Most Used Langs

【一等奖方案】大规模金融图数据中异常风险行为模式挖掘赛题「NUFE」解题思路

第十届CCF大数据与计算智能大赛(2022 CCF BDCI)已圆满结束,大赛官方竞赛平台DataFountain(简称DF平台)正在陆续释出各赛题获奖队伍的方案思路,欢迎广大数据科学家交流讨论。 本方案为【大规模金融图数据中…