第83步 时间序列建模实战:Catboost回归建模

基于WIN10的64位系统演示

一、写在前面

这一期,我们介绍Catboost回归。

同样,这里使用这个数据:

《PLoS One》2015年一篇题目为《Comparison of Two Hybrid Models for Forecasting the Incidence of Hemorrhagic Fever with Renal Syndrome in Jiangsu Province, China》文章的公开数据做演示。数据为江苏省2004年1月至2012年12月肾综合症出血热月发病率。运用2004年1月至2011年12月的数据预测2012年12个月的发病率数据。

二、Catboost回归

(1)参数解读

无论是回归还是分类,CatBoost的大部分参数都是通用的,但任务的不同性质意味着一些参数可能只在一个任务中有意义。

以下是一些关键参数的简要概述:

(a)通用参数:

learning_rate: 学习率,决定了模型每一步的步长。常用的值为0.01, 0.03, 0.1等。

iterations: 树的数量。

depth: 树的深度。

l2_leaf_reg: L2正则化项的系数。

cat_features: 分类特征的列索引列表。

loss_function: 损失函数。对于分类,常见的是Logloss(二分类)或MultiClass(多分类)。对于回归,常见的是RMSE。

border_count: 用于数值特征的分箱数量。较高的值可能会导致过拟合,较低的值可能会导致欠拟合。

verbose: 显示的训练日志的详细程度。

(b)专用于分类的参数:

classes_count: 在多分类任务中,类别的数量。

class_weights: 各类的权重,用于不平衡分类任务。

auto_class_weights: 用于处理类不平衡的自动权重计算方法。

(c)专用于回归的参数:

scale_pos_weight: 用于不平衡的回归任务。

(d)异同点:

相同点: 大部分参数(如learning_rate, depth, l2_leaf_reg等)在回归和分类任务中都是相同的,并且它们的含义和效果也是一致的。

不同点: 损失函数loss_function是根据任务(回归或分类)来确定的。此外,某些参数(如classes_count和class_weights)仅在分类任务中有意义,而scale_pos_weight更倾向于回归任务。

此外,在使用CatBoost时,建议始终查阅其官方文档,因为该库可能会经常更新,新的参数或功能可能会被添加进来。网址如下:

https://catboost.ai/docs/

(2)单步滚动预测

import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
from catboost import CatBoostRegressor
from sklearn.model_selection import GridSearchCV# 读取数据
data = pd.read_csv('data.csv')# 将时间列转换为日期格式
data['time'] = pd.to_datetime(data['time'], format='%b-%y')# 创建滞后期特征
lag_period = 6
for i in range(lag_period, 0, -1):data[f'lag_{i}'] = data['incidence'].shift(lag_period - i + 1)# 删除包含 NaN 的行
data = data.dropna().reset_index(drop=True)# 划分训练集和验证集
train_data = data[(data['time'] >= '2004-01-01') & (data['time'] <= '2011-12-31')]
validation_data = data[(data['time'] >= '2012-01-01') & (data['time'] <= '2012-12-31')]# 定义特征和目标变量
X_train = train_data[['lag_1', 'lag_2', 'lag_3', 'lag_4', 'lag_5', 'lag_6']]
y_train = train_data['incidence']
X_validation = validation_data[['lag_1', 'lag_2', 'lag_3', 'lag_4', 'lag_5', 'lag_6']]
y_validation = validation_data['incidence']# 初始化 CatBoostRegressor 模型
catboost_model = CatBoostRegressor(verbose=0)# 定义参数网格
param_grid = {'iterations': [50, 100, 150],'learning_rate': [0.01, 0.05, 0.1, 0.5, 1],'depth': [4, 6, 8],'loss_function': ['RMSE']
}# 初始化网格搜索
grid_search = GridSearchCV(catboost_model, param_grid, cv=5, scoring='neg_mean_squared_error')# 进行网格搜索
grid_search.fit(X_train, y_train)# 获取最佳参数
best_params = grid_search.best_params_# 使用最佳参数初始化 CatBoostRegressor 模型
best_catboost_model = CatBoostRegressor(**best_params, verbose=0)# 在训练集上训练模型
best_catboost_model.fit(X_train, y_train)# 对于验证集,我们需要迭代地预测每一个数据点
y_validation_pred = []for i in range(len(X_validation)):if i == 0:pred = best_catboost_model.predict([X_validation.iloc[0]])else:new_features = list(X_validation.iloc[i, 1:]) + [pred[0]]pred = best_catboost_model.predict([new_features])y_validation_pred.append(pred[0])y_validation_pred = np.array(y_validation_pred)# 计算验证集上的MAE, MAPE, MSE 和 RMSE
mae_validation = mean_absolute_error(y_validation, y_validation_pred)
mape_validation = np.mean(np.abs((y_validation - y_validation_pred) / y_validation))
mse_validation = mean_squared_error(y_validation, y_validation_pred)
rmse_validation = np.sqrt(mse_validation)# 计算训练集上的MAE, MAPE, MSE 和 RMSE
y_train_pred = best_catboost_model.predict(X_train)
mae_train = mean_absolute_error(y_train, y_train_pred)
mape_train = np.mean(np.abs((y_train - y_train_pred) / y_train))
mse_train = mean_squared_error(y_train, y_train_pred)
rmse_train = np.sqrt(mse_train)print("Train Metrics:", mae_train, mape_train, mse_train, rmse_train)
print("Validation Metrics:", mae_validation, mape_validation, mse_validation, rmse_validation)

看结果:

(3)多步滚动预测-vol. 1

对于Catboost回归,目标变量y_train不能是多列的DataFrame,所以你们懂的。

(4)多步滚动预测-vol. 2

同上。

(5)多步滚动预测-vol. 3

import pandas as pd
import numpy as np
from catboost import CatBoostRegressor  # 导入CatBoostRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import mean_absolute_error, mean_squared_error# 数据读取和预处理
data = pd.read_csv('data.csv')
data_y = pd.read_csv('data.csv')
data['time'] = pd.to_datetime(data['time'], format='%b-%y')
data_y['time'] = pd.to_datetime(data_y['time'], format='%b-%y')n = 6for i in range(n, 0, -1):data[f'lag_{i}'] = data['incidence'].shift(n - i + 1)data = data.dropna().reset_index(drop=True)
train_data = data[(data['time'] >= '2004-01-01') & (data['time'] <= '2011-12-31')]
X_train = train_data[[f'lag_{i}' for i in range(1, n+1)]]
m = 3X_train_list = []
y_train_list = []for i in range(m):X_temp = X_trainy_temp = data_y['incidence'].iloc[n + i:len(data_y) - m + 1 + i]X_train_list.append(X_temp)y_train_list.append(y_temp)for i in range(m):X_train_list[i] = X_train_list[i].iloc[:-(m-1)]y_train_list[i] = y_train_list[i].iloc[:len(X_train_list[i])]# 模型训练
param_grid = {'iterations': [50, 100, 150],'learning_rate': [0.01, 0.05, 0.1, 0.5, 1],'depth': [4, 6, 8]
}best_catboost_models = []for i in range(m):grid_search = GridSearchCV(CatBoostRegressor(verbose=0), param_grid, cv=5, scoring='neg_mean_squared_error')  # 使用CatBoostRegressorgrid_search.fit(X_train_list[i], y_train_list[i])best_catboost_model = CatBoostRegressor(**grid_search.best_params_, verbose=0)best_catboost_model.fit(X_train_list[i], y_train_list[i])best_catboost_models.append(best_catboost_model)validation_start_time = train_data['time'].iloc[-1] + pd.DateOffset(months=1)
validation_data = data[data['time'] >= validation_start_time]X_validation = validation_data[[f'lag_{i}' for i in range(1, n+1)]]
y_validation_pred_list = [model.predict(X_validation) for model in best_catboost_models]
y_train_pred_list = [model.predict(X_train_list[i]) for i, model in enumerate(best_catboost_models)]def concatenate_predictions(pred_list):concatenated = []for j in range(len(pred_list[0])):for i in range(m):concatenated.append(pred_list[i][j])return concatenatedy_validation_pred = np.array(concatenate_predictions(y_validation_pred_list))[:len(validation_data['incidence'])]
y_train_pred = np.array(concatenate_predictions(y_train_pred_list))[:len(train_data['incidence']) - m + 1]mae_validation = mean_absolute_error(validation_data['incidence'], y_validation_pred)
mape_validation = np.mean(np.abs((validation_data['incidence'] - y_validation_pred) / validation_data['incidence']))
mse_validation = mean_squared_error(validation_data['incidence'], y_validation_pred)
rmse_validation = np.sqrt(mse_validation)
print("验证集:", mae_validation, mape_validation, mse_validation, rmse_validation)mae_train = mean_absolute_error(train_data['incidence'][:-(m-1)], y_train_pred)
mape_train = np.mean(np.abs((train_data['incidence'][:-(m-1)] - y_train_pred) / train_data['incidence'][:-(m-1)]))
mse_train = mean_squared_error(train_data['incidence'][:-(m-1)], y_train_pred)
rmse_train = np.sqrt(mse_train)
print("训练集:", mae_train, mape_train, mse_train, rmse_train)

结果:

三、数据

链接:https://pan.baidu.com/s/1EFaWfHoG14h15KCEhn1STg?pwd=q41n

提取码:q41n

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/100866.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android Fragment中使用Arouter跳转到Activity后返回Fragment不回调onActivityResult

Fragment中通过路由跳转到Activity 跳转传递参数 通过Arouter跳转 Postcard postcard ARouter.getInstance().build(RouterConstant.ACTION_TRANSMANAGERACTIVITY1);Bundle bundle new Bundle();bundle.putInt("code", 404);postcard.with(bundle); //设置bundlef…

如何查看端口占用(windows,linux,mac)

如何查看端口占用&#xff0c;各平台 一、背景 如何查看端口占用&#xff1f;网上很多&#xff0c;但大多直接丢出命令&#xff0c;没有任何解释关于如何查看命令的输出 所谓 “查端口占用”&#xff0c;即查看某个端口是否被某个程序占用&#xff0c;如果有&#xff0c;被哪…

Vuex的使用,详细易懂

目录 一.前言 二.Vuex的简介 三.vuex的使用 3.1 安装Vuex 3.2 使用Vuex的步骤&#xff1a; 四.vuex的存值取值&#xff08;改变值&#xff09; 五.vuex的异步请求 好啦&#xff0c;今天的分享就到这啦&#xff01;&#xff01;&#xff01; 一.前言 今天我们继续前面的E…

导引服务机器人 通用技术条件

声明 本文是学习GB-T 42831-2023 导引服务机器人 通用技术条件. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 6 检验规则 6.1 检验项目 检验分为型式检验和出厂检验。检验项目见表2。 表 2 检验项目 序号 检验项目 技术要求 检验方法 出厂检验 型…

天锐绿盾加密软件——企业数据防泄密-CAD图纸、文档、源代码加密管理系统@德人合科技

天锐绿盾是一款专门为企业提供数据防泄密和文档加密管理的软件。该软件通过加密技术保护企业的核心数据&#xff0c;防止数据泄露和侵权行为&#xff0c;同时提供了全方位的文档加密管理系统&#xff0c;实现了对企业数据的安全保障和有效管理。 PC访问地址&#xff1a; isite…

睿趣科技:抖音店铺怎么取名受欢迎

抖音作为国内最大的短视频平台&#xff0c;其商业价值不容忽视。许多商家和创作者都在抖音上开设了自己的店铺&#xff0c;而一个富有创意和吸引力的店铺名字&#xff0c;往往能带来更多的客流量。那么&#xff0c;如何为抖音店铺取个好名字呢?以下是一些有用的建议。 明确定位…

【MATLAB源码-第44期】基于matlab的2*2MIMO-LDPC系统的误码率仿真。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 2x2 MIMO&#xff08;多输入多输出&#xff09;和LDPC&#xff08;低密度奇偶校验码&#xff09;编码是在通信系统中常用的技术&#xff0c;它们通常用于提高无线通信系统的性能和可靠性。 1. 2x2 MIMO&#xff1a; 2x2 MIMO…

【RabbitMQ 实战】09 客户端连接集群生产和消费消息

一、部署一个三节点集群 下面的链接是最快最简单的一种集群部署方法 3分钟部署一个RabbitMQ集群 上的的例子中&#xff0c;没有映射端口&#xff0c;所以没法从宿主机外部连接容器&#xff0c;下面的yml文件中&#xff0c;暴露了端口。 每个容器应用都映射了宿主机的端口&…

Vscode进行远程开发

之前用的是pycharm&#xff0c;但是同事说pycharm太重了&#xff0c;连接远程服务器的时候给远程服务器的压力比较大&#xff0c;有时候远程服务器可能都扛不住&#xff0c;所以换成了vscode。 参考博客 手把手教你配置VS Code远程开发工具&#xff0c;工作效率提升N倍 - 知…

词云图大揭秘:如何从文本中挖掘热点词汇?

随着互联网的普及&#xff0c;大量的文本信息在网络上被产生和传播。如何从这些海量的文本中提取出有价值的信息&#xff0c;成为了人们关注的焦点。在这个信息爆炸的时代&#xff0c;词云图作为一种直观、形象的数据可视化手段&#xff0c;越来越受到人们的喜爱。本文手把手教…

设计模式 - 七大软件设计原则

目录 一、设计模式 1.1、软件设计原则 1.1.1、开闭原则 1.2.2、单一职责原则 1.2.3、里氏替换原则 1.2.4、迪米特原则 1.2.5、接口隔离原则 1.2.6、依赖倒转原则 1.2.7、合成/聚合复用原则 一、设计模式 1.1、软件设计原则 1.1.1、开闭原则 开闭原则&#xff1a;对扩…

双周赛114(模拟、枚举 + 哈希、DFS)

文章目录 双周赛114[2869. 收集元素的最少操作次数](https://leetcode.cn/problems/minimum-operations-to-collect-elements/)模拟 [2870. 使数组为空的最少操作次数](https://leetcode.cn/problems/minimum-number-of-operations-to-make-array-empty/)哈希 枚举 [2871. 将数…

Docker--harbor私有仓库部署与管理

目录 一、Harbor简介 1、什么是Harbor 2、Harbor的特性 3、Haebor的构成 二、搭建本地私有仓库 1、本地私有仓库创建 2、将镜像上传至本地私有仓库 三、搭建Harbor仓库 1. 部署 Docker-Compose 服务 2、部署 Harbor 服务 3、启动Harbor 4、创建一个新项目 5、在其他…

并发、并行、同步、异步、阻塞、非阻塞

一、多核、多cpu &#xff08;一&#xff09;多核 Multicore 核是CPU最重要的部分。负责运算。核包括控制单元、运算单元、寄存器等单元。 多核就是指单个CPU中有多个核。 &#xff08;二&#xff09;多cpu Multiprocessor 多cpu就是一个系统拥有多个CPU。每个CPU可能有单个核…

北京股票开户的佣金手续费是多少?北京股票开户选择哪家券商?

北京股票开户的佣金手续费是多少?北京股票开户选择哪家券商? 股票注册开户是非常简单的&#xff0c;在2015年前也就是互联网还不发达的时候&#xff0c;投资者只能去券商的营业部柜台办理&#xff0c;而自从各大券商都可以网上开户后&#xff0c;更多的投资者会选择网上开户…

【运维】一些团队开发相关的软件安装。

gitlab 安装步骤 (1) 下载镜像&#xff0c;并且上传到服务器 https://mirrors.tuna.tsinghua.edu.cn/gitlab-ce/yum/el7/gitlab-ce-16.2.8-ce.0.el7.x86_64.rpm &#xff08;2&#xff09;rpm -i gitlab-ce-16.2.8-ce.0.el7.x86_64.rpm &#xff08;3&#xff09;安装成功后…

vue elementui <el-date-picker>日期选择框限制只能选择90天内的日期(包括今天)

之前也写过其他限制日期的语句&#xff0c;感觉用dayjs()的subtract()和add()也挺方便易懂的&#xff0c;以此记录 安装dayjs npm install dayjs --save dayjs().add(value : Number, unit : String); dayjs().add(7, day); //在当前的基础上加7天dayjs().subtract(value : N…

在硅云上主机搭建wordpress并使用Astra主题和avada主题

目录 前言 准备 操作 DNS解析域名 云主机绑定域名 安装wordpress网站程序 网站内Astra主题设计操作 安装主题 网站内avada主题安装 上传插件 上传主题 选择网站主题 前言 一开始以为云虚拟主机和云服务器是一个东西&#xff0c;只不过前者是虚拟的后者是不是虚拟的…

练[GYCTF2020]EasyThinking

[GYCTF2020]EasyThinking 文章目录 [GYCTF2020]EasyThinking掌握知识解题思路还得靠大佬正式开始 关键paylaod 掌握知识 ​ thinkphpV6任意文件操作漏洞&#xff0c;代码分析写入session文件的参数&#xff0c;源码泄露&#xff0c;使用蚁剑插件disable_functions绕过终端无回…

尚硅谷CSS学习笔记

什么是css css&#xff08;层叠样式表&#xff09; 它是一种标记语言&#xff0c;用于给HTML结构设置样式。简单理解css可以美化html&#xff0c;实现结构与样式的分离。 <link rel"shortcut icon" href"favicon.ico" type"image/x-icon"&g…