机器学习算法应用——CART决策树

CART决策树(4-2)

CART(Classification and Regression Trees)决策树是一种常用的机器学习算法,它既可以用于分类问题,也可以用于回归问题。CART决策树的主要原理是通过递归地将数据集划分为两个子集来构建决策树。在分类问题中,CART决策树通过选择一个能够最大化分裂后各个子集纯度提升的特征进行分裂,从而将数据划分为不同的类别。

CART决策树的构建过程包括以下几个步骤:

  1. 特征选择:从数据集中选择一个最优特征,用于划分数据集。最优特征的选择基于某种准则,如基尼指数(Gini Index)或信息增益(Information Gain)。
  2. 决策树生成:根据选定的最优特征,将数据集划分为两个子集,并递归地在每个子集上重复上述过程,直到满足停止条件(如子集大小小于某个阈值、所有样本属于同一类别等)。
  3. 剪枝:为了避免过拟合,可以对生成的决策树进行剪枝操作,即删除一些子树或叶子节点,以提高模型的泛化能力。

CART决策树的优点包括:

  1. 计算简单,易于理解,可解释性强。
  2. 不需要预处理,不需要提前归一化,可以处理缺失值和异常值。
  3. 既可以处理离散值也可以处理连续值。
  4. 既可以用于分类问题,也可以用于回归问题。

然而,CART决策树也存在一些缺点:

  1. 不支持在线学习,当有新样本产生时,需要重新构建决策树模型。
  2. 容易出现过拟合现象,生成的决策树可能对训练数据有很好的分类能力,但对未知的测试数据却未必有很好的分类能力。
  3. 对于一些复杂的关系,如异或关系,CART决策树可能难以学习。

CART决策树在许多领域都有广泛的应用,如推荐系统中的商品推荐模型、金融风控中的信用评分和欺诈检测、医疗诊断中的疾病预测等。此外,CART决策树还可以用于社交媒体情感分析等领域。

  1. 数据

使用Universal Bank数据集。

示例:

        

IDAgeExperienceIncomeZIP CodeFamilyCCAvgEducationMortgagePersonal LoanSecurities AccountCD AccountOnlineCreditCard
1251499110741.61001000
24519349008931.51001000
339151194720111000000
43591009411212.72000000
53584591330412000001
63713299212140.4215500010
75327729171121.52000010
85024229394310.33000001
93510819008930.6210400010
103491809302318.93010000
1165391059471042.43000000
12295459027730.12000010
1348231149310623.83001000
145932409492042.52000010
15674111291741121001000
166030229505411.53000011
1738141309501044.7313410000
184218819430542.41000000
1946211939160428.13010000
205528219472010.52001001
215631259401540.9211100010
2257276390095323000010
23295629027711.2126000010
244418439132020.7116301000
2536111529552123.9115900001
264319299430530.519700010
274016839506440.23000000
2846201589006412.41000011
295630489453912.23000011
3038131199410413.32010111
315935359310611.2312200010
3240162994117122000010
335328419480120.6319300000
34306189133030.93000000
35315509403541.83000010
364824819264730.71000000
3759351219472012.91000001
385125719581411.4319800000
39421814194114353011110
403813809411540.7328500010
415732849267231.63001000
42349609412232.31000000
433271329001941.1241210010
443915459561610.71000010
4546201049406515.71000011
465731529472042.51000001
473914439501430.7215300010
4837121949138040.2321111111
495626819574724.53000001
504016499237311.81000001
5132889209340.72001010
5261371319472012.91000010
53306729400510.1120700000
5450261909024532.1324010010
55295449581910.23000010
56411713994022281000010
575530299400530.12001110
5856311319561621.23010000
59282939406520.21000000
603151889132024.5145500000
614924399040431.72001010
6247211259340715.7111201000
6342182290089111000000
6442173294523402000010
6547231059002423.31000000
6659351319136013.81000011
6762361059567022.8133600000
685323459512342313201000
694721609340732.11000011
705329209004540.21000010
7142181159133513.51000001
7253296993907412000010
73442013092007151000001
7441168594606143000011
752831359461123.31000001
763171359490143.82010111

注意:数据集中的编号(ID)和邮政编码(ZIP CODE)特征因为在分类模型中无意义,所以在数据预处理阶段将它们删除。

  1. 使用CART决策树对数据进行分类
  1. 使用留出法划分数据集,训练集:测试集为7:3。
# 使用留出法划分数据集,训练集:测试集为7:3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
  1. 使用CART决策树对训练集进行训练
# 使用CART决策树对训练集进行训练,深度限制为10层
model = DecisionTreeClassifier(max_depth=10)
model.fit(X_train, y_train)

决策树的深度限制为10层,max_depth=10。

  1. 使用训练好的模型对测试集进行预测并输出预测结果模型准确度
# 使用训练好的模型对测试集进行预测
y_pred = model.predict(X_test)# 输出预测结果和模型准确度
accuracy = accuracy_score(y_test, y_pred)
print("模型准确度:", accuracy)
  1. 可视化训练好的CART决策树模型
# 可视化训练好的CART决策树模型
dot_data = export_graphviz(model, out_file=None,feature_names=X.columns,class_names=['0', '1'],filled=True, rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("Universal_Bank_CART")  # 保存为PDF文件
  1. 安装graphviz模块

首先在windows系统中安装graphviz模块

32位系统使用windows_10_cmake_Release_graphviz-install-10.0.1-win32.exe

64位系统使用windows_10_cmake_Release_graphviz-install-10.0.1-win64.exe

注意:安装时使用下图中圈出的选项

安装完成后使用pip install graphviz指令在python环境中安装graphviz库。

  1. 使用graphviz模块可视化模型
# 可视化训练好的CART决策树模型
dot_data = export_graphviz(model, out_file=None,feature_names=X.columns,class_names=['0', '1'],filled=True, rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("Universal_Bank_CART")  # 保存为PDF文件

完整代码:

# 导入所需的库
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.tree import export_graphviz
import graphviz# 读取数据集
data = pd.read_csv("universalbank.csv")# 数据预处理:删除无意义特征
data = data.drop(columns=['ID', 'ZIP Code'])# 划分特征和标签
X = data.drop(columns=['Personal Loan'])
y = data['Personal Loan']# 使用留出法划分数据集,训练集:测试集为7:3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 使用CART决策树对训练集进行训练,深度限制为10层
model = DecisionTreeClassifier(max_depth=10)
model.fit(X_train, y_train)# 使用训练好的模型对测试集进行预测
y_pred = model.predict(X_test)# 输出预测结果和模型准确度
accuracy = accuracy_score(y_test, y_pred)
print("模型准确度:", accuracy)# 可视化训练好的CART决策树模型
dot_data = export_graphviz(model, out_file=None,feature_names=X.columns,class_names=['0', '1'],filled=True, rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("Universal_Bank_CART6")  # 保存为PDF文件

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/9862.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

力扣 256. 粉刷房子 LCR 091. 粉刷房子 python AC

动态规划 class Solution:def minCost(self, costs):row, col len(costs), 3dp [[0] * col for _ in range(row 1)]for i in range(1, row 1):for j in range(col):dp[i][j] costs[i - 1][j - 1]if j 0:dp[i][j] min(dp[i - 1][1], dp[i - 1][2])elif j 1:dp[i][j] m…

【QT教程】QT6硬件高级编程实战案例 QT硬件高级编程

QT6硬件高级编程实战案例 使用AI技术辅助生成 QT界面美化视频课程 QT性能优化视频课程 QT原理与源码分析视频课程 QT QML C扩展开发视频课程 免费QT视频课程 您可以看免费1000个QT技术视频 免费QT视频课程 QT统计图和QT数据可视化视频免费看 免费QT视频课程 QT性能优化视频免…

【GoLang基础】通道(channel)是什么?

问题引出: Go语言中的通道(channel)是什么? 解答: 通道(channel)是 Go 语言中用于协程(goroutine)之间通信和同步的机制。通道提供了一种安全、简单且高效的方式&#x…

idea运行SpringBoot项目爆红提示出现:Java HotSpot(TM) 64-Bit Server VM warning...让我来看看~

在运行SpringBoot项目的时候,发现总有这个警告提示出现,有点强迫症真的每次运行项目都很难受啊!那么今天便来解决这个问题! 先来看一下提示内容:Java HotSpot(TM) 64-Bit Server VM warning: Options -Xverify:none an…

FreeRTOS标准库例程代码

1.设备STM32F103C8T6 2.工程模板 单片机: 部分单片机的程序例程 - Gitee.comhttps://gitee.com/lovefoolnotme/singlechip/tree/master/STM32_FREERTOS/1.%E5%B7%A5%E7%A8%8B%E6%A8%A1%E6%9D%BF 3.代码 1-FreeRTOS移植模板 #include "system.h" #include "…

C语言编程中布尔设置位掩码示例

在C语言编程中,当你想使用整数(通常是unsigned int或uint8_t, uint16_t, uint32_t等)的位来存储多个布尔设置时,你会使用位掩码。每个设置对应于整数中的一个位,你可以通过位操作(如按位与&、按位或|、…

Rust:用 Warp 库实现 Restful API 的简单示例

直接上代码: 1、源文件 Cargo.toml [package] name "xcalc" version "0.1.0" edition "2021"# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html[dependencies] warp "…

uniap之微信公众号支付

近来用uniapp开发H5的时候,需要接入支付,原来都是基于后端框架来做的,所以可谓是一路坑中过,今天整理下大致流程分享给大家。 先封装util.js,便于后面调用 const isWechat function(){return String(navigator.userA…

队列的实现(使用C语言)

完整代码链接:DataStructure: 基本数据结构的实现。 (gitee.com) 目录 一、队列的概念: 二、队列的实现: 使用链表实现队列: 1.结构体设计: 2.初始化: 3.销毁: 4.入队: 5.…

OC foudation框架(下)的学习

OCfoudation框架(下) 前面学习了有关OCfoudation框架的部分内容,我们现在对于后面的内容继续学习。 文章目录 OCfoudation框架(下)数组(NSArray和NSMutableArray)对集合元素整体调用方法排序使用…

会赚钱的人都在做这件事:你了解吗?

在我们日常生活的点滴中,以及在各种场合的交互中,利他思维始终扮演着不可或缺的角色。当我们追求合作与共赢时,单方面的自我立场显然是不够的,真正的关键在于换位思考,寻找并满足对方的需求。 互利互赢的核心理念正是利…

设置docker容器时区

设置docker容器时区 查看当前系统时间 1.1 查看当前系统版本 cat /etc/issue1.2 查看当前系统时间 date查看镜像默认时间 2.1 alpine镜像 sudo docker run -it --rm alpine date2.2 ubuntu镜像 sudo docker run -it --rm ubuntu date2.3 centos镜像 sudo docker run -it --rm …

虚拟知识付费系统源码推荐,在线教育双十一怎么做活动?

又是一年光棍节,啊不是,剁手节。小伙伴们早就摩拳擦掌准备剁手了,这个时候,几乎所有线上平台都行动起来了,而在线教育行业也没有闲着。如今,双十一已经成为了各大在线教育公司用来变现的一个大杀器&#xf…

ruoyi-vue-pro 使用记录(4)

ruoyi-vue-pro 使用记录(4) CRM数据库线索客户商机合同回款产品其他 CRM 文档 主要分为 6 个核心模块:线索、客户、商机、合同、回款、产品。 线索管理以 crm_clue 作为核心表客户管理以 crm_customer 作为核心表商机管理以 crm_business 作…

JavaScript数组(Array)方法 - toReversed、toSorted、toSpliced

最近发现几个数组方法,是一些常规方法的升级版,比较有意思,分享给大家 文章目录 一、温故二、知新toReversedtoSortedtoSpliced 一、温故 我们先来回顾几个比较常用的方法:reverse,sort,splice众所周知&a…

luceda ipkiss教程 69:导出器件或者线路的三维模型

ipkiss 3.12版加入write_obj函数,可以直接输出器件的三维模型。 如,输出自定义的mmi的三维模型: 代码如下: from si_fab import all as pdk from ipkiss3 import all as i3class MMI1x2(i3.PCell):"""MMI with …

kaldi学习参考

HMM模型 https://www.cnblogs.com/baixf-xyz/p/16777438.htmlhttps://www.cnblogs.com/baixf-xyz/p/16777438.htmlGMM-HMM 基于GMM-HMM的语音识别系统https://www.cnblogs.com/baixf-xyz/p/16777439.html https://www.cnblogs.com/baixf-xyz/p/16777426.htmlhttps://www.cnbl…

全栈开发之路——前端篇(6)生命周期和自定义hooks

全栈开发一条龙——前端篇 第一篇:框架确定、ide设置与项目创建 第二篇:介绍项目文件意义、组件结构与导入以及setup的引入。 第三篇:setup语法,设置响应式数据。 第四篇:数据绑定、计算属性和watch监视 第五篇 : 组件…

码一点网站

Linux命令查询网站 https://www.lzltool.com/LinuxCommand/Index 小林 x 图解计算机基础 https://xiaolincoding.com/ 代码随想录 https://programmercarl.com/ 可用于爬虫 https://books.toscrape.com/ 数据结构可视化 https://www.cs.usfca.edu/~galles/visualization/ …

fastText-文本分类

fastText介绍 fastText是一个快速文本分类算法,与基于神经网络的分类算法相比有两大优点: 1、fastText在保持高精度的情况下加快了训练速度和测试速度 2、fastText不需要预训练好的词向量,fastText会自己训练词向量 3、fastText两个重要的优化:Hierarchical Softmax、N-gr…