【PyCaret】使用PyCaret创建机器学习Pipeline进行多分类任务

  发现一个好东西,PyCaret机器学习Pipeline,记录一下用其进行多分类任务的使用方法。


1、简介

  PyCaret是一个开源的、不用写很多代码的Python机器学习库,可以自动化机器学习工作流程,是一个端到端的机器学习和模型管理工具,可以成倍地加快实验周期,提高工作效率。
  PyCaret本质上是几个机器学习库和框架的封装,比如scikit-learn、XGBoost、LightGBM、CatBoost、spaCy、Optuna、Hyperopt、Ray等等。
  一字诗:棒~


2、安装PyCaret

安装命令:

pip install pycaret

安装后测试:

import pycaret
pycaret.__version__
'3.3.0'

3、PyCaret建模

  PyCaret中一个典型的工作流程由以下5个步骤组成:
  Setup ➡️ Compare Models ➡️ Analyze Model ➡️ Prediction ➡️ Save Model

首先,从pycaret数据集模块加载样本数据集(鸢尾花)

from pycaret.datasets import get_data
data = get_data('iris')

非常不幸,在第一步就夭折了…

报错: requests.exceptions.ConnectionError: HTTPSConnectionPool(host=‘raw.githubusercontent.com’, port=443): Max retries exceeded with url: /pycaret/datasets/main/data/common/iris.csv (Caused by NewConnectionError(‘<urllib3.connection.HTTPSConnection object at 0x00000224EF2D0C40>: Failed to establish a new connection: [Errno 11004] getaddrinfo failed’))

原因: https://raw.githubusercontent.com/pycaret/datasets/main/ 这个网址打不开,咋办呢,没有条件创造条件也要上…

解决: 发现 get_data(‘iris’) 加载的数据集应该是如下的 dataframe 形式,一般情况下自己的数据集应该也是这样子的,因此我们把 sklearn.datasets 的鸢尾花数据集重建为 dataframe 形式就可以啦~

在这里插入图片描述

数据集构建代码:

from sklearn.datasets import load_iris
import pandas as pd
target = load_iris().target
target_names = load_iris().target_names
mapping = {'0': target_names[0], '1': target_names[1], '2': target_names[2]}
df_data = pd.DataFrame(load_iris().data, columns=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'])
target_str = pd.DataFrame([mapping[str(num)] for num in target], columns=['species'])
data = pd.concat([df_data, target_str], axis=1)

检查一下数据格式:

在这里插入图片描述

完美,可以继续啦~

3.1 Setup

  Setup函数初始化训练环境并创建transformation pipeline。Setup函数必须在执行PyCaret中的任何其他函数之前调用,只有两个必需的参数,data和target,其他参数均为可选参数。

from pycaret.classification import *
s = setup(data, target = 'species', session_id = 123)

Setup成功执行后,会显示以下实验信息:

在这里插入图片描述

信息说明:
  (1)Session id:随机数种子;
  (2)Target type:自动检测目标类型,二分类、多分类还是回归;
  (3)Target mapping:标签编码,字符串映射为0、1;
  (4)Original data shape:原始数据大小;
  (5)Transformed train set shape:训练集大小;
  (6)Transformed test set shape:测试集大小;
  (7)Numeric features:数字特征的数量;

3.2 Compare Models

  compare_models函数使用交叉验证训练和评估模型库中可用模型的性能,其输出是平均交叉验证分数。

比较基线模型:

best = compare_models()

输出默认按ACC排序:

在这里插入图片描述
打印最优模型:

print(best)

输出为最优模型的参数:

LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,intercept_scaling=1, l1_ratio=None, max_iter=1000,multi_class='auto', n_jobs=None, penalty='l2',random_state=123, solver='lbfgs', tol=0.0001, verbose=0,warm_start=False)

我靠,发没发现,又出问题了,AUC怎么不显示啊啊啊啊啊…学习的路上总是充满坎坷…

调查了一下这个问题,发现这是一个很新的问题,但似乎并没有被解决,大家可以去看看,似乎是我的 PyCaret == 3.3.0 和 scikit-learn==1.4.1.post1 不太匹配的问题:
https://github.com/pycaret/pycaret/pull/3935
https://github.com/pycaret/pycaret/issues/3932

倔强的我,在linux环境中重新配了PyCaret == 3.2.0, scikit-learn==1.0.2,这下可以显示AUC了,舒服了~

在这里插入图片描述

后面哪位朋友解决了3.3.0的AUC不显示问题,记得踢我一下喔~

3.3 Analyze Model

(1)画混淆矩阵

plot_model(best, plot = 'confusion_matrix')

在这里插入图片描述
(2)画AUC曲线

plot_model(best, plot = 'auc')

在这里插入图片描述

这时候AUC又行了…显着你了…估计前面是哪传参数有问题…

(3)画特征重要性

plot_model(best, plot = 'feature')

在这里插入图片描述

3.4 Prediction

  predict_model函数返回 prediction_label 和 prediction_score(预测类的概率)作为数据表中新的列。当data为None(默认)时,它使用测试集(在setup函数期间创建)进行评分。

holdout_pred = predict_model(best)

指标结果:

在这里插入图片描述
返回的dataframe:

在这里插入图片描述

3.5 Save Model

  使用pycaret的save_model函数将整个Pipeline进行保存

save_model(best, 'iris_pipeline')

保存后是一个pkl文件:

在这里插入图片描述

保存后的模型再加载:

loaded_best_pipeline = load_model('iris_pipeline')

4、代码整合

from sklearn.datasets import load_iris
import pandas as pd
from pycaret.classification import *# 数据集加载
target = load_iris().target
target_names = load_iris().target_names
mapping = {'0': target_names[0], '1': target_names[1], '2': target_names[2]}
df_data = pd.DataFrame(load_iris().data, columns=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'])
target_str = pd.DataFrame([mapping[str(num)] for num in target], columns=['species'])
data = pd.concat([df_data, target_str], axis=1)# Setup
s = setup(data, target='species', session_id=123)# Compare Models
best = compare_models()
print(best)# Analyze Model
plot_model(best, plot = 'confusion_matrix')
plot_model(best, plot = 'auc')
plot_model(best, plot = 'feature')# Prediction
holdout_pred = predict_model(best)# Save Model
save_model(best, 'iris_pipeline')# Load Model
loaded_best_pipeline = load_model('iris_pipeline')

代码是非常简洁明了的,但封装的太好了,有些想改的也不好改了~


参考资料:PyCaret Multiclass Classification Tutorial
更多学习:用PyCaret创建整个机器学习管道
PyCaret的github仓库:https://github.com/pycaret/pycaret/tree/master


最后说一句,PyCaret的Pipeline还是用JupyterLab运行最舒服,Spyter运行不显示,Pycharm运行不好看…

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/761658.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

登录与注册功能(简单版)(3)登录时使用Cookie增加记住我功能

目录 1、实现分析 2、步骤 1&#xff09;新建login.jsp 2&#xff09;修改LoginServlet&#xff1a; 3&#xff09;启动访问&#xff1a; 3、安全性考虑 4、最佳实践思路 1&#xff09;选择安全的认证机制 2&#xff09;强化会话管理 3&#xff09;安全地存储用户凭证…

Unity 粒子在UI中使用时需要注意的地方

最近项目中要在UI中挂载粒子特效,美术给过来的粒子直接放到UI中会有一些问题,查询一些资料后,总结了一下 一: 粒子的大小发生变化,与在预制件编辑中设计的大小不同 在预制件编辑模式下,大小正常 实际使用的时候特别大或者特别小 经过检查,发现预制件编辑模式下,默认画布的Rend…

上位机图像处理和嵌入式模块部署(qmacvisual点线测量)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 上面一篇文章&#xff0c;我们完成了直线的拟合操作。在实际场景中&#xff0c;拟合之后更多地是需要进行长度的测量。既然是测量&#xff0c;那么…

【JavaScript 漫游】【040】Blob 对象

文章简介 本篇文章为【JavaScript 漫游】专栏的第 040 篇文章&#xff0c;对浏览器模型中的 Blob 对象进行了总结。 概述 Blob 对象表示一个二进制文件的数据内容&#xff0c;比如一个图片文件的内容就可以通过 Blob 对象读写。它通常用来读写文件&#xff0c;它的名字是 Bi…

哪些企业适合构建企业新媒体矩阵?

⭐关注矩阵通服务号&#xff0c;探索企业新媒体矩阵搭建与营销策略 新媒体矩阵就是在某个平台或多个平台开设、联动多个账号&#xff0c;组建有关系的不同账号集群。 在数字化转型的浪潮下&#xff0c;矩阵已然成为企业实现品牌塑造、市场开拓与用户互动的重要阵地。 然而&…

2024Android-目前最稳定和高效的UI适配方案!你头秃都没想到还能这样吧!

但是这个方案有一个致命的缺陷&#xff0c;那就是需要精准命中才能适配&#xff0c;比如1920x1080的手机就一定要找到1920x1080的限定符&#xff0c;否则就只能用统一的默认的dimens文件了。而使用默认的尺寸的话&#xff0c;UI就很可能变形&#xff0c;简单说&#xff0c;就是…

记一次由于buff/cache导致服务器内存爆满的问题

目录 前言 复现 登录服务器查看占用内存进程排行 先了解一下什么是buff/cache&#xff1f; 尝试释放buffer/cache /proc/sys/vm/drop_caches dirty_ratio dirty_background_ratio dirty_writeback_centisecs dirty_expire_centisecs drop_caches page-cluster swap…

ideaSSM 人才引进管理系统bootstrap开发mysql数据库web结构java编程计算机网页源码maven项目

一、源码特点 idea 开发 SSM 人才引进管理系统是一套完善的信息管理系统&#xff0c;结合SSM框架和bootstrap完成本系统&#xff0c;对理解JSP java编程开发语言有帮助系统采用SSM框架&#xff08;MVC模式开发&#xff09;&#xff0c;系统具有完整的源代码和数据库&#xff…

分布式链上随机数和keyless account

1. 引言 相关论文见&#xff1a; Aptos团队2024年论文 Distributed Randomness using Weighted VRFs 相关代码实现见&#xff1a; https://github.com/aptos-labs/aptos-core&#xff08;Rust&#xff09; 在链中生成和集成共享随机数&#xff0c;以扩展应用和强化安全。该…

G - Find a way

题目分析 1.双重bfs,遍历两个起点求最短路再计算总和即可 2.唯一的坑点在于对于一个KFC&#xff0c;两人中可能有一个到不了&#xff0c;所以还要对到不了的点距离做处理 #include <bits/stdc.h> using namespace std; using ll long long; const int N 220;struct pos…

交通事故档案管理系统|基于JSP技术+ Mysql+Java+Tomcat的交通事故档案管理系统设计与实现(可运行源码+数据库+设计文档)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 2024年56套包含java&#xff0c;ssm&#xff0c;springboot的平台设计与实现项目系统开发资源&#xff08;可…

docker 修改日志存储路径

docker 日志默认存放在 /var/lib/docker/ 下 docker info修改步骤&#xff1a; 1、停止docker服务 systemctl stop docker 2、新建配置文件 vi /etc/docker/daemon.json添加如下内容 {"data-root": "/data/docker" }3、然后把之前的数据全部复制到新目…

十、C#基数排序算法

简介 基数排序是一种非比较性排序算法&#xff0c;它通过将待排序的数据拆分成多个数字位进行排序。 实现原理 首先找出待排序数组中的最大值&#xff0c;并确定排序的位数。 从最低位&#xff08;个位&#xff09;开始&#xff0c;按照个位数的大小进行桶排序&#xff0c;将…

将OpenCV与gdb驱动的IDE结合使用

返回&#xff1a;OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;OpenCV4.9.0开源计算机视觉库在 Linux 中安装 下一篇&#xff1a;将OpenCV与gcc和CMake结合使用 ​ 能力 这个漂亮的打印机可以显示元素类型、、标志is_continuous和is_subm…

【Java常用API】简单爬虫练习题

&#x1f36c; 博主介绍&#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 hacker-routing &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【应急响应】 【Java】 【VulnHub靶场复现】【面试分析】 &#x1f389;点赞➕评论➕收藏 …

0基础学习VR全景平台篇第146篇:为什么需要3D元宇宙编辑器?

一.什么是3D元宇宙编辑器&#xff1f; 3D元宇宙编辑器是全新3DVR交互渲染创作工具&#xff0c;集3D建模、虚拟展厅、AI数字人等能力&#xff0c;渲染和虚拟现实技术于一身的生产力工具。 具有跨平台和随时随地编辑等特点&#xff0c;可广泛应用于展会、展厅、博物馆、可视化园…

uniapp_微信小程序客服

一、调用api 二、代码 <button open-type"contact">客服</button> 三、小程序后台添加客服人员就行

Ubuntu学习笔记之Shell与APT下载工具

基本都是摘抄正点原子的文章&#xff1a;<领航者 ZYNQ 之嵌入式Linux 开发指南 V3.2.pdf&#xff0c;因初次学习&#xff0c;仅作学习摘录之用&#xff0c;有不懂之处后续会继续更新~ 一、Ubuntu Shell操作 简单的说Shell 就是敲命令。国内把 Linux 下通过命令行输入命令叫…

CSS隐藏video标签中各种控件

1.edio标签加上controls会出现视频控件&#xff0c;如播放按钮、进度条、全屏、观看的当前时间、剩余时间、音量按钮、音量的控制条等等 <video type"video/mp4" src"" autoplay"" style"width: 400px; height: 300px;" id"e…

idea 2023 spring initializr 没有JDK1.8选项的解决方法

在升级最新版本的IDEA后,新建项目里面的 spring initializr的选项里面已经没有了JDK1.8的选项了,原因是spring官方的initializr https://start.spring.io/ 现在主推3.x版本这个最低要求是JDK17, 解决方法: 将IDEA默认的 Initializr的URL https://start.spring.io/换成第三方…