模型优化_XGBOOST学习曲线及改进,泛化误差

代码

from xgboost import XGBRegressor as XGBR
from sklearn.ensemble import RandomForestRegressor as RFR
from sklearn.linear_model import LinearRegression as LR
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split,cross_val_score as CV,KFold
from sklearn.metrics import mean_squared_error as MSE
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from time import time
import datetime#加载数据
data=load_boston()
X=data.data
y=data.target#划分数据集
Xtrain,Xtest,ytrain,ytest=train_test_split(X,y,test_size=0.3,random_state=420)#定位模型,进行fit
reg=XGBR(n_estimators=100).fit(Xtrain,ytrain)#进行预测
reg.predict(Xtest)
reg.score(Xtest,ytest)#返回的是R平方
MSE(ytest,reg.predict(Xtest))
reg.feature_importances_
#查看SKLEARN中所有的模型评估指标
import sklearn
sorted(sklearn.metrics.SCORERS.keys())# ======================================
#交叉验证,与线性回归随机森林进行结果比对
reg=XGBR(n_estimators=100)from sklearn.model_selection import train_test_split,cross_val_score
cross_val_score(reg,Xtrain,ytrain,cv=5).mean()##交叉验证既可以解决数据集的数据量不够大问题,也可以解决参数调优的问题。
#这块主要有三种方式:简单交叉验证(HoldOut检验)、k折交叉验证(k-fold交叉验证)
cross_val_score(reg,Xtrain,ytrain,cv=5,scoring="neg_mean_squared_error").mean()#绘制学习曲线
def plot_learning_curve(estimator,title,X,y,ax=None,#选择子图ylim=None,#设置纵坐标的取值范围cv=None,#交叉验证n_jobs=None):from sklearn.model_selection import learning_curvetrain_sizes,train_scores,test_scores=learning_curve(estimator,X,y,shuffle=True,cv=cv,random_state=420,n_jobs=n_jobs)if ax==None:ax=plt.gca()else:ax=plt.figure()ax.set_title(title)if ylim is not None:ax.set_ylim(*ylim)ax.set_xlabel("Traing example")ax.set_ylabel("Score")ax.grid()#绘制网格ax.plot(train_sizes,np.mean(train_scores,axis=1),"o-",color="r",label="traing score")ax.plot(train_sizes,np.mean(test_scores,axis=1),"o-",color="g",label="test.py score")ax.legend(loc="best")return ax#学习曲线的绘制
cv=KFold(n_splits=5,shuffle=True,random_state=42)
plot_learning_curve(XGBR(n_estimators=100,random_state=420),"XGB",Xtrain,ytrain,ax=None,cv=cv)

#绘制学习曲线,查看n_estimators对模型的影响 

#绘制学习曲线,查看n_estimators对模型的影响
axis=range(10,50,1)
rs=[]
for i in axis:reg=XGBR(n_estimators=i)cv1=cross_val_score(reg,Xtrain,ytrain,cv=5).mean()rs.append(cv1)
print(axis[rs.index(max(rs))],max(rs))
plt.figure(figsize=(20,5))
plt.plot(axis,rs,c='red',label="XGB")
plt.legend()
plt.show()

泛化误差:用来衡量模型在未知数据集上的准确率

#绘制学习曲线,查看n_estimators对模型的影响
axis=range(10,50,1)
rs=[]#偏差,衡量的是准确率
var=[]#方差,衡量的是稳定性
ge=[]#泛化误差的可控部门
for i in axis:reg=XGBR(n_estimators=i)cv1=cross_val_score(reg,Xtrain,ytrain,cv=5)rs.append(cv1.mean())#记录偏差,返回的R平方就是偏差部门,衡量的是准确率var.append(cv1.var())ge.append((1-cv1.mean())**2+cv1.var())
print(axis[rs.index(max(rs))],max(rs),var[rs.index(max(rs))])
print(axis[var.index(min(var))],rs[var.index(min(var))],min(var))
print(axis[ge.index(min(ge))],rs[ge.index(min(ge))],var[ge.index(min(ge))],min(ge))
plt.figure(figsize=(20,5))
plt.plot(axis,rs,c='red',label="XGB")
plt.legend()
plt.show()

 

#绘制学习曲线,查看n_estimators对模型的影响

#绘制学习曲线,查看n_estimators对模型的影响
axis=range(10,30,1)
rs=[]#偏差,衡量的是准确率
var=[]#方差,衡量的是稳定性
ge=[]#泛化误差的可控部门
for i in axis:reg=XGBR(n_estimators=i)cv1=cross_val_score(reg,Xtrain,ytrain,cv=5)rs.append(cv1.mean())#记录偏差,返回的R平方就是偏差部门,衡量的是准确率var.append(cv1.var())ge.append((1-cv1.mean())**2+cv1.var())
print(axis[rs.index(max(rs))],max(rs),var[rs.index(max(rs))])
print(axis[var.index(min(var))],rs[var.index(min(var))],min(var))
print(axis[ge.index(min(ge))],rs[ge.index(min(ge))],var[ge.index(min(ge))],min(ge))
#添加方差线条
rs=np.array(rs)
var=np.array(var)#源代码这里*0.01
plt.figure(figsize=(20,5))
plt.plot(axis,rs,c='red',label="XGB")
plt.plot(axis,rs+var,c="black",linestyle="-.")
plt.plot(axis,rs-var,c="black",linestyle="-.")
plt.legend()
plt.show()

#看看泛化误差的可控部分如何

plt.figure(figsize=(20,5))
plt.plot(axis,ge,c='red',label="XGB")
plt.legend()
plt.show()

从这个过程中观察n_estimators参数对模型的影响,我们可以得出以下结论:
首先,XGB中的树的数量决定了模型的学习能力,树的数量越多,模型的学习能力越强。只要XGB中树的数量足够
了,即便只有很少的数据, 模型也能够学到训练数据100%的信息,所以XGB也是天生过拟合的模型。但在这种情况
下,模型会变得非常不稳定。
第二,XGB中树的数量很少的时候,对模型的影响较大,当树的数量已经很多的时候,对模型的影响比较小,只能有
微弱的变化。当数据本身就处于过拟合的时候,再使用过多的树能达到的效果甚微,反而浪费计算资源。当唯一指标
或者准确率给出的n_estimators看起来不太可靠的时候,我们可以改造学习曲线来帮助我们。
第三,树的数量提升对模型的影响有极限,最开始,模型的表现会随着XGB的树的数量一起提升,但到达某个点之
后,树的数量越多,模型的效果会逐步下降,这也说明了暴力增加n_estimators不一定有效果。
这些都和随机森林中的参数n_estimators表现出一致的状态。在随机森林中我们总是先调整n_estimators,当
n_estimators的极限已达到,我们才考虑其他参数,但XGB中的状况明显更加复杂,当数据集不太寻常的时候会更加
复杂。这是我们要给出的第一个超参数,因此还是建议优先调整n_estimators,一般都不会建议一个太大的数目,
300以下为佳。

参考:

XGBOOST学习曲线及改进,泛化误差-CSDN博客

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/709423.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux: GDB 调试工具

概念: Linux GDB(GNU Debugger)是一个功能强大的调试工具,用于调试C、C等编程语言的程序。它可以帮助开发人员定位和修复程序中的错误。 GDB 的使用 : 激活和进入工作模式: gdb 需要调试的文件 进入 …

【Java设计模式】三、

文章目录 0、案例:咖啡屋1、简单工厂模式 静态工厂(不属于23种之列)2、工厂方法模式3、抽象工厂模式4、建造者模式5、原型设计模式 0、案例:咖啡屋 模拟咖啡店点餐。咖啡有多种,抽象类,子类为各种咖啡。咖…

2.29作业

T课上实现通信代码总结&#xff1a; 程序代码&#xff1a; TCPSER.c #include<myhead.h> #define SER_IP "192.168.244.140" //服务器IP #define SER_PORT 9999 //服务器端口号 int main(int argc, const char *argv[]) {//1.创建用于监…

为什么猫咪挑食不吃猫粮?适口性好、普口性价的主食冻干推荐

现代养猫人士往往把自家的小猫看作是生活中的小宝贝&#xff0c;十分宠爱。最令人头疼的就是猫咪挑食不吃猫粮&#xff0c;为什么猫咪挑食不吃猫粮&#xff1f;猫咪挑食应该怎么办&#xff1f;今天为大家分享一个既不让咱宝贝猫咪受罪又可以改善猫咪挑食的方法。 一、为什么猫咪…

深入理解nginx的https sni机制

目录 1. 概述2. 初识sni3. nginx的ssl证书配置指令3.1 ssl_certificate3.2 ssl_certificate_key3.3 ssl_password_file4. nginx源码分析4.1 给ssl上下文的初始化4.2 连接初始化4.3 处理sni回调4.2 动态证书的加载5. 总结阅读姊妹篇: 深入理解nginx的https alpn机制 1. 概述 SN…

Vue+SpringBoot打造音乐偏好度推荐系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、系统设计2.1 功能模块设计2.1.1 音乐档案模块2.1.2 我的喜好模块2.1.3 每日推荐模块2.1.4 通知公告模块 2.2 用例图设计2.3 实体类设计2.4 数据库设计 三、系统展示3.1 登录注册3.2 音乐档案模块3.3 音乐每日推荐模块3.4 通知公告模…

外包干了6个月,技术退步明显。。。。。

先说一下自己的情况&#xff0c;本科生&#xff0c;2019年我通过校招踏入了重庆一家软件公司&#xff0c;开始了我的职业生涯。那时的我&#xff0c;满怀热血和憧憬&#xff0c;期待着在这个行业中闯出一片天地。然而&#xff0c;随着时间的推移&#xff0c;我发现自己逐渐陷入…

Dockerfile(1) - FROM 指令详解

FROM 指明当前的镜像基于哪个镜像构建dockerfile 必须以 FROM 开头&#xff0c;除了 ARG 命令可以在 FROM 前面 FROM [--platform<platform>] <image> [AS <name>]FROM [--platform<platform>] <image>[:<tag>] [AS <name>]FROM […

Java+SpringBoot+Vue+MySQL:员工健康管理技术新组合

✍✍计算机毕业编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java、…

TCP的三次握手和四次挥手 | 查看网络状态

三次握手和四次挥手是在计算机网络中用于建立和终止TCP连接的协议。这两个过程是TCP协议的重要组成部分&#xff0c;确保数据的可靠传输。 三次握手指的是在客户端和服务器之间建立连接时的步骤。具体流程如下&#xff1a; 客户端向服务器发送一个连接请求报文段&#xff08;…

Git教程-Git的基本使用

Git是一个强大的分布式版本控制系统&#xff0c;它不仅用于跟踪代码的变化&#xff0c;还能够协调多个开发者之间的工作。在软件开发过程中&#xff0c;Git被广泛应用于协作开发、版本管理和代码追踪等方面。以下是一个详细的Git教程&#xff0c;我们将深入探讨Git的基本概念和…

数据结构——lesson4带头双向循环链表实现

前言✨✨ &#x1f4a5;个人主页&#xff1a;大耳朵土土垚-CSDN博客 &#x1f4a5; 所属专栏&#xff1a;数据结构学习笔记​​​​​​ &#x1f4a5;双链表与单链表的区分&#xff1a;单链表介绍与实现 &#x1f4a5;对于malloc函数有疑问的:动态内存函数介绍 感谢大家的观看…

tomcat安装步骤流程

安装tomcat是基于安装java的基础上的 JAVA 举例说明&#xff1a; 关闭防火墙 下载java [rootlocalhost ~]#yum install java -y rootlocalhost ~]#yum install epel-release.noarch -y [rootlocalhost ~]#yum provides */javac [rootlocalhost data]#yum install java-1.8.0-o…

6.1 deeplabv3+的pth模型转换为rknn模型

和yolov5的pth模型转换为rknn模型类似&#xff0c;deeplabv3的pth模型转为rknn模型的步骤是&#xff1a; pth------>onnx-------->rknn 1.pth转为onnx 代码如下&#xff1a; #!/usr/bin/env python3 # -*- coding: utf-8 -*- # by [jackhanyuan](https://github.com/…

实现一个线程安全的单例模式

单例模式 单例模式能保证某个类在程序中只存在唯⼀⼀份实例,⽽不会创建出多个实例 某个类,在一个类,只应该创建出一个实例,使用单例模式,就可以对咱们的代码进行一个更严格的校验和检查 单例模式具体的实现⽅式有很多.最常⻅的是"饿汉"和"懒汉"两种单例…

DevEco Studio下载与安装(Windows)

下载地址&#xff1a; HUAWEI DevEco Studio和SDK下载和升级 | HarmonyOS开发者 安装时直接点击 next 即可。 运⾏已安装的DevEco Studio&#xff0c;⾸次使⽤&#xff0c;请选择Do not import settings&#xff0c;单击OK。 1.安装Node.js 如果本地有下载&#xff0c;可以…

前端JS 时间复杂度和空间复杂度

时间复杂度 BigO 算法的时间复杂度通常用大 O 符号表述&#xff0c;定义为 T(n) O(f(n)) 实际就是计算当一个一个问题量级&#xff08;n&#xff09;增加的时候&#xff0c;时间T增加的一个趋势 T(n)&#xff1a;时间的复杂度&#xff0c;也就相当于所消耗的时长 O&#xff1…

乐吾乐Web可视化RTSP播放

背景 乐吾乐致力于物联网和智能制造等场景的Web可视化平台和解决方案&#xff0c;其中摄像头播放必不可少。 当前国内摄像头都以RTSP协议为主&#xff0c;而HTML不能直接读取RTSP协议&#xff0c;因此需要一个转流服务。乐吾乐Web可视化播放RTSP也是如此&#xff1a; RTSP协…

理解计算着色器中glsl语言的内置变量

概要 本文通过示例的方式&#xff0c;着重解释以下几个内置变量&#xff1a; gl_WorkGroupSizegl_NumWorkGroupsgl_LocalInvocationIDgl_WorkGroupIDgl_GlobalInvocationID 基本概念 局部工作组与工作项 一个3x2x1的局部工作组示例如下&#xff0c;每个小篮格子表示一个工作项…

Vulnhub靶机:basic_pentesting_1

一、介绍 运行环境&#xff1a;Virtualbox 攻击机&#xff1a;kali&#xff08;10.0.2.4&#xff09; 靶机&#xff1a;basic_pentesting_1&#xff08;10.0.2.6&#xff09; 目标&#xff1a;获取靶机root权限和flag 靶机下载地址&#xff1a;https://www.vulnhub.com/en…