如何用sklearn对随机森林调参

文章目录

  • 一、概述
  • 二、实操
    • 1、导入相关包
    • 2、导入乳腺癌数据集,建立模型
    • 3、调参
  • 三、总结

Link:https://zhuanlan.zhihu.com/p/126288078
Author:陈罐头

一、概述

sklearn是目前python中十分流行的用来实现机器学习的第三方包,其中包含了多种常见算法如:决策树,逻辑回归、集成算法(如随机森林)等等。

本文将使用sklearn自带的乳腺癌数据集,建立随机森林,并基于 泛化误差(Genelization Error)模型复杂度的关系来对模型进行调参,从而使模型获得更高的得分。

泛化误差是机器学习中,用来衡量模型在未知数据上的准确率的指标,其与模型复杂度的关系如下图所示:
在这里插入图片描述
当模型复杂度不足时,机器学习不足,会出现欠拟合现象,泛化误差变大;当复杂度逐渐提高到最佳模型复杂度时,泛化误差会达到最低点(即最高准确度);若复杂度仍在提高,泛化误差从最小值开始逐渐增大,出现过拟合现象。

因此,我们的目的,是通过不断调参来不断调整模型复杂度,尽可能地接近泛化误差最低点

二、实操

1、导入相关包

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

2、导入乳腺癌数据集,建立模型

由于sklearn自带的数据集已经很工整了,所以无需做预处理,直接使用。

# 导入乳腺癌数据集
data = load_breast_cancer()# 建立随机森林
rfc = RandomForestClassifier(n_estimators=100, random_state=90)用交叉验证计算得分
score_pre = cross_val_score(rfc, data.data, data.target, cv=10).mean()
score_pre

在这里插入图片描述

初始得分

3、调参

随机森林主要的参数有n_estimators(子树的数量)、max_depth(树的最大生长深度)、min_samples_leaf(叶子的最小样本数量)、min_samples_split(分支节点的最小样本数量)、max_features(最大选择特征数)。

它们对随机森林模型复杂度的影响如下图所示:
在这里插入图片描述

可以看到,n_estimators是影响程度最大的参数,我们先对其进行调整:

# 调参,绘制学习曲线来调参n_estimators(对随机森林影响最大)
score_lt = []
# 每隔10步建立一个随机森林,获得不同n_estimators的得分
for i in range(0,200,10):rfc = RandomForestClassifier(n_estimators=i+1, random_state=90)score = cross_val_score(rfc, data.data, data.target, cv=10).mean()score_lt.append(score)
score_max = max(score_lt)
print('最大得分:{}'.format(score_max),'子树数量为:{}'.format(score_lt.index(score_max)*10+1))
# 绘制学习曲线
x = np.arange(1,201,10)
plt.subplot(111)
plt.plot(x, score_lt, 'r-')
plt.show()

在这里插入图片描述
如图所示,当n_estimators从0开始增大至21时,模型准确度有肉眼可见的提升。这也符合随机森林的特点:在一定范围内,子树数量越多,模型效果越好。而当子树数量越来越大时,准确率会发生波动,当取值为41时,获得最大得分。

接下来,我们在将取值范围缩小至41左右,以获得更好的取值。

# 在41附近缩小n_estimators的范围为30-49
score_lt = []
for i in range(30,50):rfc = RandomForestClassifier(n_estimators=i,random_state=90)score = cross_val_score(rfc, data.data, data.target, cv=10).mean()score_lt.append(score)
score_max = max(score_lt)
print('最大得分:{}'.format(score_max),'子树数量为:{}'.format(score_lt.index(score_max)+30))# 绘制学习曲线
x = np.arange(30,50)
plt.subplot(111)
plt.plot(x, score_lt,'o-')
plt.show()

在这里插入图片描述

如图所示,当n_estimators=45时,获得最大得分score_max=0.9719,相较于score_pre提升0.005
在这里插入图片描述

由此我们发现:当n_estimators100减小至45时(模型复杂度由大到小),模型准确度提升了(泛化误差减小),说明在泛化误差图中,模型往左移动了!

因此,接下来的调参方向是使模型复杂度减小的方向,从而接近泛化误差最低点。我们使用能使模型复杂度减小,并且影响程度排第二的max_depth

# 建立n_estimators为45的随机森林
rfc = RandomForestClassifier(n_estimators=45, random_state=90)# 用网格搜索调整max_depth
param_grid = {'max_depth':np.arange(1,20)}
GS = GridSearchCV(rfc, param_grid, cv=10)
GS.fit(data.data, data.target)best_param = GS.best_params_
best_score = GS.best_score_
print(best_param, best_score)

在这里插入图片描述

如图所示,最佳深度为11,最大得分为0.9718,竟然比不调整深度的得分0.9719还低,难道我们刚才就已经十分接近最低泛化误差了吗?

本着严谨的态度,我们再进行调整。调整max_depth使模型复杂度减小,却获得了更低的得分,因此接下来我们需要朝着复杂度增大的方向调整。我们在n_estimators=45max_depth=11的情况下,对唯一能够增加模型复杂度的参数max_features进行调整:
在这里插入图片描述

查看数据集大小,发现一共有30列特征,由于max_features默认取值特征数量的开平方值,因此我们从5开始调整:

# 用网格搜索调整max_features
param_grid = {'max_features':np.arange(5,31)}rfc = RandomForestClassifier(n_estimators=45,random_state=90,max_depth=11)
GS = GridSearchCV(rfc, param_grid, cv=10)
GS.fit(data.data, data.target)
best_param = GS.best_params_
best_score = GS.best_score_
print(best_param, best_score)     

在这里插入图片描述

输出结果为5,和默认值一样。得分为0.9718,仍然小于0.9719。因此,仅需n_estimators=45就能使模型的准确率达到最高0.9719,相较于初始得分0.9667,提升0.005,最接近最小泛化误差,调参工作到此结束。

三、总结

总结一下在sklearn中调参的思路:

① 基于泛化误差模型复杂度的关系来进行调参;

② 根据对模型的影响程度,由大到小对参数排序,并确定哪些参数会使模型复杂度减小,哪些会增大;

③ 依次选择合适的参数,通过绘制学习曲线或网格搜索的方法调参,直到找到最大准确得分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/139215.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

凯美瑞 vs 太空船:Web3 游戏生长的两条路径

撰文:Teng Yan(0xPrismatic),Delphi Digital 研究员 编译:TinTinLand 来源:https://0xprismatic.substack.com/p/my-short-web3-gaming-thesis 经常有人问我关于 Web3 游戏的看法,所以我想以这…

什么是数据库事务、事务的ACID、怎么设置/禁止自动提交?

数据库事务及ACID 数据库事务是指作为单个逻辑工作单元执行的一组操作。这组操作要么全部成功地执行,要么全部不执行,不允许出现部分执行的情况。数据库事务通常需要满足ACID属性,即原子性(Atomicity)、一致性&#x…

某城高速综合管控大数据大屏可视化【可视化项目案例-04】

🎉🎊🎉 你的技术旅程将在这里启航! 🚀🚀 本文选自专栏:可视化技术专栏100例 可视化技术专栏100例,包括但不限于大屏可视化、图表可视化等等。订阅专栏用户在文章底部可下载对应案例源码以供大家深入的学习研究。 🎓 每一个案例都会提供完整代码和详细的讲解,不…

大厂真题:【哈希表】美团2023秋招-小美的排列询问

题目描述与示例 题目描述 小美拿到了一个排列。她想知道在这个排列中,x和y是否是相邻的。你能帮帮她吗? 排列是指一个长度为n的数组,其中 1 到 n 每个元素恰好出现一次。 输入描述 第一行输入一个正整数n,代表排列的长度。 …

【复杂网络建模】——基于关联矩阵构建超图网络

目录 一、复杂网络介绍 二、常规的构建方法 三、基于关联矩阵构建超图 一、复杂网络介绍 复杂网络是指由大量相互连接的元素或节点构成的网络,这些节点之间的连接关系通常是非常复杂和多样化的。这种网络结构通常用图论来表示,其中节点表示网络中的个体或元素,边表示它们…

Python开源项目PGDiff——人脸重建(Face Restoration),模糊清晰、划痕修复及黑白上色的实践

python ansconda 等的下载、安装等请参阅: Python开源项目CodeFormer——人脸重建(Face Restoration),模糊清晰、划痕修复及黑白上色的实践https://blog.csdn.net/beijinghorn/article/details/134334021 友情提示: …

CSS3 过度效果、动画、多列

一、CSS3过度&#xff1a; CSS3过渡是元素从一种样式逐渐改变为另一种的效果。要实现这一点&#xff0c;必须规定两相内容&#xff1a;指定要添加效果的CSS属性&#xff1b;指定效果的持续时间。如果为指定持续时间&#xff0c;transition将没有任何效果。 <style> div…

Ubuntu(WSL2) 安装最新版的 cmake

Ubuntu(WSL) 安装最新版的 cmake 具体流程如下&#xff1a; 步骤一&#xff1a;卸载原本的 cmake sudo apt-get remove cmake 步骤二&#xff1a; sudo apt-get update sudo apt-get install apt-transport-https ca-certificates gnupg software-properties-common wget 步…

nsd的资料

nsd是一款开源的DNS服务器应用。 近期参与项目过程中&#xff0c;涉及到DNS业务&#xff0c;结果被打的满头包。 虽然在校学习时就知道DNS协议&#xff0c;但从业这么多年&#xff0c;对于DNS协议的理解其实一直处于一知半解的状态。 当前处理问题时&#xff0c;接触到了nsd&am…

Clickhouse 学习笔记(6)—— ClickHouse 分片集群

前置知识&#xff1a; Clickhouse学习笔记&#xff08;5&#xff09;—— ClickHouse 副本-CSDN博客 与副本对比&#xff1a; 副本虽然能够提高数据的可用性&#xff0c;降低丢失风险&#xff0c;但是每台服务器实际上必须容纳全量数据&#xff0c;对数据的横向扩容没有解决 …

Redis学习笔记8:基于springboot的Lettuce redis客户端connectTimeout、timeout、shutdownTimeout

一个对springboot redis框架进行重写&#xff0c;支持lettuce、jedis、连接池、同时连接多个集群、多个redis数据库、开发自定义属性配置的开源SDK <dependency><groupId>io.github.mingyang66</groupId><artifactId>emily-spring-boot-redis</art…

基于工业智能网关的汽车充电桩安全监测方案

近年来&#xff0c;我国新能源汽车产业得到快速发展&#xff0c;电动车产量和销量都在持续增长&#xff0c;不仅国内市场竞争激烈&#xff0c;而且也远销海外&#xff0c;成为新的经济增长点。但与此同时&#xff0c;充电设施的运营却面临着安全和效率的双重挑战。 当前的充电桩…

Linux开发工具之编辑器vim

文章目录 1.vim是啥?1.1问问度娘1.2自己总结 2.vim的初步了解2.1进入和退出2.2vim的模式1.介绍2.使用 3.vim的配置3.1自己配置3.2下载插件3.3安装大佬配置好的文件 4.程序的翻译 1.vim是啥? 1.1问问度娘 1.2自己总结 vi/vim都是多模式编辑器&#xff0c;vim是vi的升级版本&a…

【Excel】补全单元格值变成固定长度

我们知道股票代码都为6位数字&#xff0c;但深圳中小板代码前面以0开头&#xff0c;数字格式时前面的0会自动省略&#xff0c;现在需要在Excel表格补全它。如下图&#xff1a; 这时我们需要用到特殊的函数&#xff1a;TEXT或者RIGHT TEXT函数是Excel中一个非常有用的函数。TEX…

UnRaid安装安装仓库管理系统GreaterWMS

文章目录 0、前言1、安装流程1.1、克隆GreaterWMS项目到UnRaid本地目录1.2、修改项目前后端端口1.3、修改baseurl1.4、修改Nginx.conf配置文件1.5、安装依赖插件1.5.1、Docker Compose Manager插件1.5.2、Python3环境 1.6、创建GreaterWMS容器1.6.1、为前后端启动脚本赋执行权限…

【数据结构】归并排序

#include<iostream>using namespace std;void Merge(int* arr,int left,int right,int mid, int*& tmparr) {int begin1 left, end1 mid;int begin2 mid 1, end2 right;int tmpi left;//下面合并两个数组为一个有序数组&#xff08;升序&#xff09;&#xff1…

Protobuf简介

Protobuf 定义&#xff1a;可序列化的数据交换格式 用途&#xff1a;用于通信协议&#xff08;数据&#xff09;&#xff0c;数据存储等 特点&#xff1a;语言无关&#xff0c;平台无关&#xff0c;高效&#xff0c;扩展性好。 相近产品&#xff1a;XML/JSON 优点&#xff1a;…

Amazon EC2 Serial Console 现已在其他亚马逊云科技区域推出

即日起&#xff0c;交互式 EC2 Serial Console 现也在以下区域推出&#xff1a;中东&#xff08;巴林&#xff09;、亚太地区&#xff08;雅加达&#xff09;、非洲&#xff08;开普敦&#xff09;、中东&#xff08;阿联酋&#xff09;、亚太地区&#xff08;香港&#xff09;…

AI系统ChatGPT源码+详细搭建部署教程+AI绘画系统+支持GPT4.0+Midjourney绘画+已支持OpenAI GPT全模型+国内AI全模型

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统&#xff0c;支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美&#xff0c;可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如…

JDBC(二)

第4章 操作BLOB类型字段 4.1 MySQL BLOB类型 MySQL中&#xff0c;BLOB是一个二进制大型对象&#xff0c;是一个可以存储大量数据的容器&#xff0c;它能容纳不同大小的数据。 插入BLOB类型的数据必须使用PreparedStatement&#xff0c;因为BLOB类型的数据无法使用字符串拼接写…