人工智能|机器学习——强大的 Scikit-learn 可视化让模型说话

一、显示 API 简介

使用 utils.discovery.all_displays 查找可用的 API。

Sklearn 的utils.discovery.all_displays可以让你看到哪些类可以使用。

from sklearn.utils.discovery import all_displays
displays = all_displays()
displays

Scikit-learn (sklearn) 总是会在新版本中添加 "Display "API,因此这里可以了解你的版本中有哪些可用的 API 。例如,在我的 Scikit-learn 1.4.0 中,就有这些类:

[('CalibrationDisplay', sklearn.calibration.CalibrationDisplay),('ConfusionMatrixDisplay',sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay),('DecisionBoundaryDisplay',sklearn.inspection._plot.decision_boundary.DecisionBoundaryDisplay),('DetCurveDisplay', sklearn.metrics._plot.det_curve.DetCurveDisplay),('LearningCurveDisplay', sklearn.model_selection._plot.LearningCurveDisplay),('PartialDependenceDisplay',sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay),('PrecisionRecallDisplay',sklearn.metrics._plot.precision_recall_curve.PrecisionRecallDisplay),('PredictionErrorDisplay',sklearn.metrics._plot.regression.PredictionErrorDisplay),('RocCurveDisplay', sklearn.metrics._plot.roc_curve.RocCurveDisplay),('ValidationCurveDisplay',sklearn.model_selection._plot.ValidationCurveDisplay)]

二、显示决策边界

使用 inspection.DecisionBoundaryDisplay 显示决策边界

如果使用 Matplotlib 来绘制,会很麻烦:

  • 使用 np.linspace 设置坐标范围;

  • 使用 plt.meshgrid 计算网格;

  • 使用 plt.contourf 绘制决策边界填充;

  • 然后使用 plt.scatter 绘制数据点。

现在,使用 inspection.DecisionBoundaryDisplay 可以简化这一过程:

from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as pltiris = load_iris(as_frame=True)
X = iris.data[['petal length (cm)', 'petal width (cm)']]
y = iris.targetsvc_clf = make_pipeline(StandardScaler(), SVC(kernel='linear', C=1))
svc_clf.fit(X, y)display = DecisionBoundaryDisplay.from_estimator(svc_clf, X, grid_resolution=1000,xlabel="Petal length (cm)",ylabel="Petal width (cm)")
plt.scatter(X.iloc[:, 0], X.iloc[:, 1], c=y, edgecolors='w')
plt.title("Decision Boundary")
plt.show()

使用 DecisionBoundaryDisplay 绘制三重分类模型。

请记住,Display 只能绘制二维数据,因此请确保数据只有两个特征或更小的维度。

三、概率校准

要比较分类模型,使用 calibration.CalibrationDisplay 进行概率校准,概率校准曲线可以显示模型预测的可信度。

CalibrationDisplay使用的是模型的 predict_proba。如果使用支持向量机,需要将 probability 设为 True:

from sklearn.calibration import CalibrationDisplay
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifierX, y = make_classification(n_samples=1000,n_classes=2, n_features=5,random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
proba_clf = make_pipeline(StandardScaler(), SVC(kernel="rbf", gamma="auto", C=10, probability=True))
proba_clf.fit(X_train, y_train)CalibrationDisplay.from_estimator(proba_clf, X_test, y_test)hist_clf = HistGradientBoostingClassifier()
hist_clf.fit(X_train, y_train)ax = plt.gca()
CalibrationDisplay.from_estimator(hist_clf,X_test, y_test,ax=ax)
plt.show()

CalibrationDisplay.

四、显示混淆矩阵

在评估分类模型和处理不平衡数据时,需要查看精确度和召回率。使用 metrics.ConfusionMatrixDisplay绘制混淆矩阵(TP、FP、TN 和 FN)。

from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import ConfusionMatrixDisplaydigits = fetch_openml('mnist_784', version=1)
X, y = digits.data, digits.target
rf_clf = RandomForestClassifier(max_depth=5, random_state=42)
rf_clf.fit(X, y)ConfusionMatrixDisplay.from_estimator(rf_clf, X, y)
plt.show()

五、Roc 和 Det 曲线

因为经常并列评估Roc 和 Det 曲线,因此把metrics.RocCurveDisplay 和 metrics.DetCurveDisplay两个图表放在一起。

  • RocCurveDisplay比较模型的 TPR 和 FPR。对于二分类,希望 FPR 低而 TPR 高,因此左上角是最佳位置。Roc 曲线向这个角弯曲。

由于 Roc 曲线停留在左上角附近,右下角是空的,因此很难看到模型差异。

  • 使用 DetCurveDisplay 绘制一条带有 FNR 和 FPR 的 Det 曲线。它使用了更多空间,比 Roc 曲线更清晰。Det 曲线的最佳点是左下角。

from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import DetCurveDisplayX, y = make_classification(n_samples=10_000, n_features=5,n_classes=2, n_informative=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42,stratify=y)classifiers = {"SVC": make_pipeline(StandardScaler(), SVC(kernel="linear", C=0.1, random_state=42)),"Random Forest": RandomForestClassifier(max_depth=5, random_state=42)
}fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(10, 4))
for name, clf in classifiers.items():clf.fit(X_train, y_train)RocCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_roc, name=name)DetCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_det, name=name)

六、调整阈值

在数据不平衡的情况下,希望调整召回率和精确度。可以使用使用 metrics.PrecisionRecallDisplay 调整阈值

  • 对于电子邮件欺诈,需要高精确度。

  • 而对于疾病筛查,则需要高召回率来捕获更多病例。

那么可以调整阈值,但调整多少才合适呢?因此可以使用metrics.PrecisionRecallDisplay 来绘制相关图表。

from xgboost import XGBClassifier
from sklearn.datasets import load_wine
from sklearn.metrics import PrecisionRecallDisplaywine = load_wine()
X, y = wine.data[wine.target<=1], wine.target[wine.target<=1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,stratify=y, random_state=42)xgb_clf = XGBClassifier()
xgb_clf.fit(X_train, y_train)PrecisionRecallDisplay.from_estimator(xgb_clf, X_test, y_test)
plt.show()

这表明可以按照 Scikit-learn 的设计绘制模型,就像这里的 xgboost

七、回归模型评估

Scikit-learn 的 metrics.PredictionErrorDisplay 绘制残差图可以帮助评估回归模型。

from sklearn.svm import SVR
from sklearn.metrics import PredictionErrorDisplayrng = np.random.default_rng(42)
X = rng.random(size=(200, 2)) * 10
y = X[:, 0]**2 + 5 * X[:, 1] + 10 + rng.normal(loc=0.0, scale=0.1, size=(200,))reg = make_pipeline(StandardScaler(), SVR(kernel='linear', C=10))
reg.fit(X, y)fig, axes = plt.subplots(1, 2, figsize=(8, 4))
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[0], kind="actual_vs_predicted")
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[1], kind="residual_vs_predicted")
plt.show()

图表展示预测值与实际值的比较,左图适合线性回归。然而,并非所有数据都是完全线性的,因此,请参考右图。右图展示了实际值与预测值的差异,即残差图。残差图的香蕉形状暗示我们的数据可能不适合线性回归。考虑将核函数从"线性" 转换为 "rbf" ,残差图会更好。

reg = make_pipeline(StandardScaler(), SVR(kernel='rbf', C=10))

八、绘制学习曲线

学习曲线主要研究模型的泛化效果和训练测试数据之间的差异或偏差。接下来,使用 model_selection.LearningCurveDisplay 绘制学习曲线,并比较了决策树分类器和梯度提升分类器在不同训练数据下的表现。

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import LearningCurveDisplayX, y = make_classification(n_samples=1000, n_classes=2, n_features=10,n_informative=2, n_redundant=0, n_repeated=0)tree_clf = DecisionTreeClassifier(max_depth=3, random_state=42)
gb_clf = GradientBoostingClassifier(n_estimators=50, max_depth=3, tol=1e-3)train_sizes = np.linspace(0.4, 1.0, 10)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
LearningCurveDisplay.from_estimator(tree_clf, X, y,train_sizes=train_sizes,ax=axes[0],scoring='accuracy')
axes[0].set_title('DecisionTreeClassifier')
LearningCurveDisplay.from_estimator(gb_clf, X, y,train_sizes=train_sizes,ax=axes[1],scoring='accuracy')
axes[1].set_title('GradientBoostingClassifier')
plt.show()

从图中可以看出,虽然基于树的 GradientBoostingClassifier 在训练数据上保持了良好的准确性,但其在测试数据上的泛化能力与 DecisionTreeClassifier 相比并无明显优势。

九、可视化参数调整

为了改善泛化效果差的模型,可以尝试通过调整正则化参数来提高性能。传统的方法是使用 "GridSearchCV" 或 "Optuna" 等工具来实现模型调整,然而这些方法只能找出整体表现最佳的模型,且调整过程并不直观。如果需要调整特定参数以测试其对模型的影响,建议使用 model_selection.ValidationCurveDisplay 来直观地观察模型在参数变化时的表现。

from sklearn.model_selection import ValidationCurveDisplay
from sklearn.linear_model import LogisticRegressionparam_name, param_range = "C", np.logspace(-8, 3, 10)
lr_clf = LogisticRegression()ValidationCurveDisplay.from_estimator(lr_clf, X, y,param_name=param_name,param_range=param_range,scoring='f1_weighted',cv=5, n_jobs=-1)
plt.show()

十、讨论

尝试过所有这些显示后,我必须承认一些遗憾:

  • 最大的遗憾是这些 API 大多数缺乏详细的教程,这可能也是与 Scikit-learn 的详尽文档相比不为人知的原因。

  • 这些应用程序接口散布在不同的软件包中,因此很难从一个地方引用它们。

  • 代码仍然非常基础。通常需要将其与 Matplotlib 的 API 搭配使用才能完成工作。一个典型的例子是 "DecisionBoundaryDisplay",在绘制决策边界后,还需要使用 Matplotlib 来绘制数据分布。

  • 它们很难扩展。除了一些验证参数的方法外,很难用工具或方法来简化模型的可视化过程;最终需要重写了很多东西。

这些 API 希望得到更多关注,并且随着版本升级,可视化 API 也能更易用。

在机器学习中,用可视化方式解释模型与训练模型同样重要。

本文介绍了当前版本 scikit-learn 中的各种绘图 API,利用这些 API,可以简化一些 Matplotlib 代码,缓解学习曲线,并简化模型评估过程。由于篇幅有限,未对每个 API 进行详细介绍。如果有兴趣,可以查看 [官方文档:https://scikit-learn.org/stable/visualizations.html?ref=dataleadsfuture.com] 了解更多详情。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/833357.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无人零售,重塑购物新纪元

在这个快节奏的时代&#xff0c;科技的每一次跃进都在悄无声息地改变着我们的生活方式。而今&#xff0c;无人零售正以雷霆之势&#xff0c;颠覆传统购物模式&#xff0c;为我们带来前所未有的便捷与智能体验。想知道无人零售如何彻底改变我们的购物方式吗&#xff1f;跟随我&a…

市场营销的酒店营销策略研究意义

在市场经济条件下&#xff0c;市场营销策略已成为企业经营管理中最重要的组成部分&#xff0c;其在企业管理中的地位日益显现出来。 然而&#xff0c;由于酒店营销环境的特殊性&#xff0c;酒店营销策略研究一直是咱们从业者研究的热点之一。 对于酒店营销策略的研究&#xf…

uts插件开发-继uniapp原生插件nativeplugins,uts插件开发可直接操作原生安卓sdk等,支持uniappx,支持源码授权价格等等

1.创建uts项目 2.创建uts插件cf-takepic 3.在index.uts中编写原生安卓代码&#xff0c;首先定义一个函数方法&#xff0c;在页面中看是否可引用成功 uts函数代码 /*** 拍照函数*/ export const takepicfunction():void{console.log("11111111") } index.vue代码 …

详细分析Mybatis与MybatisPlus中分页查询的差异(附Demo)

目录 前言1. Mybatis2. MybatisPlus3. 实战 前言 更多的知识点推荐阅读&#xff1a; 【Java项目】实战CRUD的功能整理&#xff08;持续更新&#xff09;java框架 零基础从入门到精通的学习路线 附开源项目面经等&#xff08;超全&#xff09; 本章节主要以Demo为例&#xff…

C# winform 连接mysql数据库(navicat)

1.解决方案资源管理器->右键->管理NuGet程序包->搜索&#xff0c; 安装Mysql.Data 2.解决方案资源管理器->右键->添加->引用->浏览-> C:\Program Files (x86)\MySQL\MySQL Installer for Windows ->选择->MySql.Data.dll 3.解决方案资源管理器…

深入剖析Tomcat(七) 日志记录器

在看原书第六章之前&#xff0c;一直觉得Tomcat记日志的架构可能是个“有点东西”的东西。在看了第六章之后呢&#xff0c;额… 就这&#xff1f;不甘心的我又翻了翻logback与新版tomcat的源码&#xff0c;额…&#xff0c;日志架构原来也没那么神秘。本篇文章先过一遍原书内容…

Github 2024-05-07 开源项目日报 Tp10

根据Github Trendings的统计,今日(2024-05-07统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量TypeScript项目4Jupyter Notebook项目2Python项目1Batchfile项目1非开发语言项目1Java项目1HTML项目1C#项目1从零开始构建你喜爱的技术 创建周期…

k8s 资源文件参数介绍

Kubernetes资源文件yaml参数介绍 yaml 介绍 yaml 是一个类似 XML、JSON 的标记性语言。它强调以数据为中心&#xff0c;并不是以标识语言为重点例如 SpringBoot 的配置文件 application.yml 也是一个 yaml 格式的文件 语法格式 通过缩进表示层级关系不能使用tab进行缩进&am…

金三银四面试题(二十五):策略模式知多少?

什么是策略模式 策略模式&#xff08;Strategy Pattern&#xff09;是一种行为型设计模式&#xff0c;旨在定义一系列算法&#xff0c;将每个算法封装到一个独立的类中&#xff0c;使它们可以互换。策略模式让算法的变化独立于使用它们的客户端&#xff0c;使得客户端可以根据…

二叉搜索树相关

二叉搜索树 定义&#xff1a;对二叉搜索树的一些操作基本结构Insert操作Find操作Erase操作 InOrder遍历二叉树操作模拟字典模拟统计次数 定义&#xff1a; 二叉搜索树又称二叉排序树&#xff0c;它或者是一棵空树&#xff0c;或者是具有以下性质的二叉树:若它的左子树不为空&a…

MacOS快速安装FFmpeg,并使用FFmpeg转换视频

前言&#xff1a;目前正在接入flv视频流&#xff0c;但是没有一个合适的flv视频流地址。网上提供的flv也都不是H264AAC&#xff08;一种视频和音频编解码器组合&#xff09;&#xff0c;所以想通过fmpeg来将flv文件转换为H264AAC。 一、MacOS环境 博主的MacOS环境&#xff08;…

初始C++(一)

目录 前言&#xff1a; 命名空间&#xff1a; 总结&#xff1a; 前言&#xff1a; C语言学好了&#xff0c;现在当然要进阶了&#xff0c;那么就是从C开始。 C兼容C&#xff0c;支持其中90%的语法。可能有很多同学听说过C#&#xff0c;C#和C没有关系&#xff0c;是微软研究出…

Mintegral引领短剧行业新风尚:广告变现策略助力出海应用高效增长

短剧&#xff0c;一颗正在冉冉升起的新星&#xff0c;如今成为了影视行业的新风口。《2023短剧行业研究报告》显示&#xff0c;2023年短剧市场规模达到373.9亿元&#xff0c;同比增长267.65%&#xff0c;预计2024年将超过500亿元。近年来&#xff0c;政策出台、供需促进及应用出…

什么是Facebook付费广告营销?

Facebook作为全球最大的社交平台之一&#xff0c;成为了跨境卖家不可或缺的营销阵地。它不仅拥有庞大的用户基数&#xff0c;还提供了丰富的广告工具和社群互动功能&#xff0c;让商家能够精准触达目标市场&#xff0c;提升品牌影响力。云衔科技通过Facebook付费广告营销的专业…

ODOO17数据库安全策略一(ODOO17 Database Security Policy I)

ODOO17作为ERP软件&#xff0c;其核心优势在于数据安全。凭借强大的原生安全机制及灵活的配置&#xff0c;确保数据安全无忧&#xff1a; ODOO17, as an ERP software, boasts its significant advantage in exceptional data security performance. It effectively ensures wo…

DirClass

DirClass 通过分析&#xff0c;发现当接收到DirClass远控指令后&#xff0c;样本将返回指定目录的目录信息&#xff0c;返回数据中的远控指令为0x2。 相关代码截图如下&#xff1a; DelDir 通过分析&#xff0c;发现当接收到DelDir远控指令后&#xff0c;样本将删除指定目录…

2024暨南大学校赛热身赛解析

文章目录 Uzi 的真身D 基站建设 Uzi 的真身 分析&#xff1a;本来想使用动态规划算法的&#xff0c;但是由于数据计算过大&#xff0c;导致超时 正确的思想&#xff1a;直接线性遍历字符串&#xff0c;分别统计字符“j”的个数&#xff0c;后面对于统计字符“z”和字符“h” 的…

flutter中固定底部按钮,防止键盘弹出时按钮跟随上移

当我们想要将底部按钮固定在底部&#xff0c;我们只需在Widget中的Scaffold里面加一句 resizeToAvoidBottomInset: false, // 设置为false&#xff0c;固定页面不会因为键盘弹出而移动 效果图如下

[法规规划|数据概念]金融行业数据资产和安全管理系列文件解析(2)

“ 金融行业在自身数据治理和资产化建设方面一直走在前列。” 一直以来&#xff0c;金融行业由于其自身需要&#xff0c;都是国内开展信息化建设最早&#xff0c;信息化程度最高的行业。 在当今数据要素资产化的浪潮下&#xff0c;除了行业自身自身数据治理和资产化建设方面&am…

蓝牙模块HC-08+WIFI模块ESP-01S

蓝牙模块 又叫蓝牙串口模块。 串口透传技术&#xff1a;透传即透明传送&#xff0c;是指在数据的传输过程中&#xff0c;通过无线的方式使这组数据不发生任何形式的改变&#xff0c;仿佛传输过程是透明的一样&#xff0c;同时保证传输的质量&#xff0c;原封不动地道了最终接收…