随机森林超参数的网格优化(机器学习的精华--调参)

随机森林超参数的网格优化(机器学习的精华–调参)

随机森林各个参数对算法的影响

影响力参数
⭐⭐⭐⭐⭐
几乎总是具有巨大影响力
n_estimators(整体学习能力)
max_depth(粗剪枝)
max_features(随机性)
⭐⭐⭐⭐
大部分时候具有影响力
max_samples(随机性)
class_weight(样本均衡)
⭐⭐
可能有大影响力
大部分时候影响力不明显
min_samples_split(精剪枝)
min_impurity_decrease(精剪枝)
max_leaf_nodes(精剪枝)
criterion(分枝敏感度)

当数据量足够大时,几乎无影响
random_state
ccp_alpha(结构风险)

探索的第一步(学习曲线)

先看看n_estimators的学习曲线(当然我们也可以看一下其他参数的学习曲线)

#参数潜在取值,由于现在我们只调整一个参数,因此参数的范围可以取大一些、取值也可以更密集
Option = [1,*range(5,101,5)]#生成保存模型结果的arrays
trainRMSE = np.array([])
testRMSE = np.array([])
trainSTD = np.array([])
testSTD = np.array([])#在参数取值中进行循环
for n_estimators in Option:#按照当下的参数,实例化模型reg_f = RFR(n_estimators=n_estimators,random_state=83)#实例化交叉验证方式,输出交叉验证结果cv = KFold(n_splits=5,shuffle=True,random_state=83)result_f = cross_validate(reg_f,X,y,cv=cv,scoring="neg_mean_squared_error",return_train_score=True,n_jobs=-1)#根据输出的MSE进行RMSE计算train = abs(result_f["train_score"])**0.5test = abs(result_f["test_score"])**0.5#将本次交叉验证中RMSE的均值、标准差添加到arrays中进行保存trainRMSE = np.append(trainRMSE,train.mean()) #效果越好testRMSE = np.append(testRMSE,test.mean())trainSTD = np.append(trainSTD,train.std()) #模型越稳定testSTD = np.append(testSTD,test.std())

定义画图函数

def plotCVresult(Option,trainRMSE,testRMSE,trainSTD,testSTD):#一次交叉验证下,RMSE的均值与std的绘图xaxis = Optionplt.figure(figsize=(8,6),dpi=80)#RMSEplt.plot(xaxis,trainRMSE,color="k",label = "RandomForestTrain")plt.plot(xaxis,testRMSE,color="red",label = "RandomForestTest")#标准差 - 围绕在RMSE旁形成一个区间plt.plot(xaxis,trainRMSE+trainSTD,color="k",linestyle="dotted")plt.plot(xaxis,trainRMSE-trainSTD,color="k",linestyle="dotted")plt.plot(xaxis,testRMSE+testSTD,color="red",linestyle="dotted")plt.plot(xaxis,testRMSE-testSTD,color="red",linestyle="dotted")plt.xticks([*xaxis])plt.legend(loc=1)plt.show()plotCVresult(Option,trainRMSE,testRMSE,trainSTD,testSTD)

在这里插入图片描述

当绘制学习曲线时,我们可以很容易找到泛化误差开始上升、或转变为平稳趋势的转折点。因此我们可以选择转折点或转折点附近的n_estimators取值,例如20。然而,n_estimators会受到其他参数的影响,例如:

  • 单棵决策树的结构更简单时(依赖剪枝时),可能需要更多的树
  • 单棵决策树训练的数据更简单时(依赖随机性时),可能需要更多的树

因此n_estimators的参数空间可以被确定为range(20,100,5),如果你比较保守,甚至可以确认为是range(15,25,5)。

探索第二步(随机森林中的每一棵决策树)

属性.estimators_,查看森林中所有的树

参数参数含义对应属性属性含义
n_estimators树的数量reg.estimators_森林中所有树对象
max_depth允许的最大深度.tree_.max_depth0号树实际的深度
max_leaf_nodes允许的最大
叶子节点量
.tree_.node_count0号树实际的总节点量
min_sample_split分枝所需最小
样本量
.tree_.n_node_samples0号树每片叶子上实际的样本量
min_weight_fraction_leaf分枝所需最小
样本权重
tree_.weighted_n_node_samples0号树每片叶子上实际的样本权重
min_impurity_decrease分枝所需最小
不纯度下降量
.tree_.impurity
.tree_.threshold
0号树每片叶子上的实际不纯度
0号树每个节点分枝后不纯度下降量

可以通过对上述属性的调用查看当前模型每一棵树的各个属性,对我们对于参数范围的选择给予帮助。

正戏开始(网格搜索)

import numpy as np
import pandas as pd
import sklearn
import matplotlib as mlp
import matplotlib.pyplot as plt
import time #计时模块time
from sklearn.ensemble import RandomForestRegressor as RFR
from sklearn.model_selection import cross_validate, KFold, GridSearchCV
# 定义RMSE函数
def RMSE(cvresult,key):return (abs(cvresult[key])**0.5).mean()
#导入波士顿房价数据
data = pd.read_csv(r"D:\Pythonwork\datasets\House Price\train_encode.csv",index_col=0)
#查看数据
data.head()
X = data.iloc[:,:-1]
y = data.iloc[:,-1]

在这里插入图片描述

Step 1.建立benchmark

# 定义回归器
reg = RFR(random_state=83)
# 进行5折交叉验证
cv = KFold(n_splits=5,shuffle=True,random_state=83)result_pre_adjusted = cross_validate(reg,X,y,cv=cv,scoring="neg_mean_squared_error",return_train_score=True,verbose=True,n_jobs=-1)
#分别查看训练集和测试集在调参之前的RMSE
RMSE(result_pre_adjusted,"train_score")
RMSE(result_pre_adjusted,"test_score")

结果分别是
11177.272008319653
30571.26665524217

Step 2.创建参数空间

param_grid_simple = {"criterion": ["squared_error","poisson"], 'n_estimators': [*range(20,100,5)], 'max_depth': [*range(10,25,2)], "max_features": ["log2","sqrt",16,32,64,"auto"], "min_impurity_decrease": [*np.arange(0,5,10)]}

Step 3.实例化用于搜索的评估器、交叉验证评估器与网格搜索评估器

#n_jobs=4/8,verbose=True
reg = RFR(random_state=1412,verbose=True,n_jobs=-1)
cv = KFold(n_splits=5,shuffle=True,random_state=83)
search = GridSearchCV(estimator=reg,param_grid=param_grid_simple,scoring = "neg_mean_squared_error",verbose = True,cv = cv,n_jobs=-1)

Step 4.训练网格搜索评估器

#=====【TIME WARNING: 15mins】=====#   当然博主的电脑比较慢
start = time.time()
search.fit(X,y)
print(time.time() - start)

Step 5.查看结果

search.best_estimator_# 直接使用最优参数进行建模
ad_reg = RFR(n_estimators=85, max_depth=23, max_features=16, random_state=83)cv = KFold(n_splits=5,shuffle=True,random_state=83)
result_post_adjusted = cross_validate(ad_reg,X,y,cv=cv,scoring="neg_mean_squared_error",return_train_score=True,verbose=True,n_jobs=-1)
#查看调参后的结果
RMSE(result_post_adjusted,"train_score")
RMSE(result_post_adjusted,"test_score")

得出结果
11000.81099038192
28572.070208366855

调参前后对比

#默认值下随机森林的RMSE
xaxis = range(1,6)
plt.figure(figsize=(8,6),dpi=80)
#RMSE
plt.plot(xaxis,abs(result_pre_adjusted["train_score"])**0.5,color="green",label = "RF_pre_ad_Train")
plt.plot(xaxis,abs(result_pre_adjusted["test_score"])**0.5,color="green",linestyle="--",label = "RF_pre_ad_Test")
plt.plot(xaxis,abs(result_post_adjusted["train_score"])**0.5,color="orange",label = "RF_post_ad_Train")
plt.plot(xaxis,abs(result_post_adjusted["test_score"])**0.5,color="orange",linestyle="--",label = "RF_post_ad_Test")
plt.xticks([1,2,3,4,5])
plt.xlabel("CVcounts",fontsize=16)
plt.ylabel("RMSE",fontsize=16)
plt.legend()
plt.show()

在这里插入图片描述
不难发现,网格搜索之后的模型过拟合程度减轻,且在训练集与测试集上的结果都有提高,可以说从根本上提升了模型的基础能力。我们还可以根据网格的结果继续尝试进行其他调整,来进一步降低模型在测试集上的RMSE。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/668340.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ASP.NET Core 预防开放式重定向攻击

写在前面 为预防钓鱼网站的常用套路,在进行 Web 应用程序的开发时,原则上应该将所有由用户提交的数据视为不可信。如果应用程序中包含了基于 URL 内容重定向的功能,需要确保这种类型的重定向操作只能在应用本地完成,或者明确判断…

[技术杂谈]如何下载vscode历史版本

网站模板: https://code.visualstudio.com/updates/v1_85 如果你想下载1.84系列可以访问https://code.visualstudio.com/updates/v1_84​​​​​​ 然后看到: 选择对应版本下载即可,我是windows x64系统选择x64即可开始下载

MQTT在linux下服务端和客户端的应用

MQTT(Message Queuing Telemetry Transport)是一种轻量级、开放标准的消息传输协议,设计用于受限设备和低带宽、不稳定网络的通信。 MQTT的一些关键特点和概念: 发布/订阅模型: MQTT采用发布/订阅(Publ…

QCustomplot实现灰度曲线图

从 QCustomplot官网 https://www.qcustomplot.com/index.php/download 下载支持文件。首页有些demo可以进行参考学习。 新建一个Qt工程,将下载得到的qcustomplot.h和qcustomplot.cpp文件加入到当前工程。pro文件中加上 printsupport 在ui界面中,添加一…

云服务器总结

1.服务器重装系后远程连接报错 Host key for xxx.xxx.xxx.xxx has changed and you have requested strict checking. 问题原因 ssh会把你每个你访问过计算机的公钥(public key)都记录在~/.ssh/known_hosts。当下次访问相同计算机时,OpenSSH会核对公钥。如果公钥不…

【Linux】信号-上

欢迎来到Cefler的博客😁 🕌博客主页:折纸花满衣 🏠个人专栏:题目解析 🌎推荐文章:【LeetCode】winter vacation training 目录 👉🏻信号的概念与产生jobs命令普通信号和实…

BridgeTower:融合视觉和文本信息的多层语义信息,主打复杂视觉-语言任务

BridgeTower 核心思想子问题1:双塔架构的局限性子问题2:不同层次的语义信息未被充分利用子问题3:模型扩展性和泛化能力 核心思想 论文:https://arxiv.org/pdf/2206.08657.pdf 代码:https://github.com/microsoft/Bri…

基于Java SSM框架实现网上租车系统项目【项目源码+论文说明】计算机毕业设计

基于java的SSM框架实现网上租车系统演示 摘要 随着现在网络的快速发展,网上管理系统也逐渐快速发展起来,网上管理模式很快融入到了许多商家的之中,随之就产生了“网上租车系统”,这样就让网上租车系统更加方便简单。 对于本网上…

JAVA Web 学习(五)Nginx、RPC、JWT

十二、反向代理服务器——Nginx 支持热部署,几乎可以做到 7 * 24 小时不间断运行,即使运行几个月也不需要重新启动,还能在不间断服务的情况下对软件版本进行热更新。性能是 Nginx 最重要的考量,其占用内存少、并发能力强、能支持…

arcpy高德爬取路况信息数据json转shp

最近工作上遇到爬取的高德路况信息数据需要在地图上展示出来,由于json数据不具备直接可视化的能力,又联想到前两个月学习了一点点arcpy的知识,就花了一些时间去写了个代码,毕竟手动处理要了老命了。 1、json文件解读 json文件显…

Log360,引入全新安全与风险管理功能,助力企业积极抵御网络威胁

ManageEngine在其SIEM解决方案中推出了安全与风险管理新功能,企业现在能够更主动地减轻内部攻击和防范入侵。 SIEM 这项新功能为Log360引入了安全与风险管理仪表板,Log360是ManageEngine的统一安全信息与事件管理(SIEM)解决方案…

【新书推荐】6.1 if语句

第六章 分支结构 计算机语言和人类语言类似,人类语言是为了解决人与人之间交流的问题,而计算机语言是为了解决程序员与计算机之间交流的问题。程序员编写的程序就是计算机的控制指令,控制计算机的运行。借助于编译工具,可以将各种…

Java代码实现基数排序算法(附带源码)

基数排序是一种非比较型整数排序算法,其原理是将整数按位数切割成不同的数字,然后按每个位数分别比较。由于整数也可以表达字符串(比如名字或日期)和特定格式的浮点数,所以基数排序也不是只能使用于整数。 1. 基数排序…

大白话介绍循环神经网络

循环神经网络实质为递归式的网络,它在处理时序任务表现出优良的效果,毕竟递归本来就是一步套一步的向下进行,而自然语言处理任务中涉及的文本天然满足这种时序性,比如我们写字就是从左到右一步步来的鸭,刚接触深度学习…

基于hadoop+spark的大规模日志的一种处理方案

概述: CDN服务平台上有为客户提供访问日志下载的功能,主要是为了满足在给CDN客户提供服务的过程中,要对所有的记录访问日志,按照客户定制的格式化需求以小时为粒度(或者其他任意时间粒度)进行排序、压缩、打包,供客户进行下载,以便进行后续的核对和分析的诉求。而且CDN…

xlsx xlsx-style 使用和坑记录

1 安装之后报错 npm install xlsx --savenpm install xlsx-style --save Umi运行会报错 自己代码 import XLSX from "xlsx"; import XLSXStyle from "xlsx-style";const data [["demo1","demo2","demo3","demo4&quo…

好烦,怎么输入拼音的过程也会触发input事件!!!

说在前面 🎈input输入框大家应该都很熟悉了吧,不知道大家有没有遇到过这样的一种情况:如上图,在中文输入过程中,输入的拼音也会触发input框的input事件,有些时候我们并不希望在中文输入的过程中拼音触发inp…

基于OpenCV灰度图像转GCode的双向扫描实现

基于OpenCV灰度图像转GCode的双向扫描实现 引言激光雕刻简介OpenCV简介实现步骤 1.导入必要的库2. 读取灰度图像3. 图像预处理4. 生成GCode 1. 简化版的双向扫描2. 优化版的双向扫描 5. 保存生成的GCode6. 灰度图像双向扫描代码示例 总结 系列文章 ⭐深入理解G0和G1指令&…

Matomo 访问图形显示异常

近期我们的把 PHP 系统完全升级后,访问 Matomo 的站点有关访问的曲线无法显示。 出现的情况如下图: 我们可以看到图片中有关的访问曲线无法显示。 如果具体直接访问链接的话,会有下面的错误信息。 问题和解决 出现上面问题的原因是缺少 ph…

我的创作128纪念日

机缘 起初我写博客是为了记录自己的学习过程,现在也是如此 实战项目中的经验分享日常学习过程中的记录通过文章进行技术交流通过文章加深学习和复习 收获 在创作过程中 获得了400多位粉丝的关注感谢大家的支持阅读数量也达到了3w在博客上认识仲秋大佬,感谢大佬对我的指导,我…