机器学习的模型校准

背景知识

之前一直没了解过模型校准是什么东西,最近上班业务需要看了一下:

模型校准是指对分类模型进行修正以提高其概率预测的准确性。在分类模型中,预测结果通常以类别标签形式呈现(例如,0或1),但有时我们更关注的是预测的概率。

当使用某些分类模型(例如支持向量机(SVM)或随机森林)时,其预测的概率并不一定与真实标签的概率分布相匹配。这意味着,即使预测概率较高的类别出现的频率更高,模型的预测概率也可能偏离真实情况。这可能导致对模型的概率输出有误解,或者在需要高度依赖概率预测的任务(例如风险评估或阈值选择)中出现问题。

通过校准分类模型,我们可以将模型的预测概率调整为更准确地反映真实情况。`CalibratedClassifierCV`是Scikit-learn库中提供的用于校准分类器的类。它根据指定的校准方法(`method`),通过拟合后的分类器(`model`)和交叉验证拟合(`cv='prefit'`)来创建一个经过校准的分类器(`calibrated_model`)。

在代码中,使用`calibrated_model.fit(X_train, y_train)`通过使用交叉验证拟合来训练、校准模型。之后,使用`calibrated_model.predict(X_test)`对测试集进行预测,并使用`classification_report`输出校准模型的分类性能报告。

通过校准分类模型,我们可以使得模型的概率预测更为准确,从而提高在概率判断和相关任务中的性能和可靠性。


代码实现

模型校准主要是针对分类模型的,我之前都是做回归,难怪没怎么接触过。也没空找真实数据了,直接模拟数据来实现一下。

导入包和制作数据集

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.calibration import calibration_curve
from sklearn.ensemble import RandomForestClassifierimport matplotlib.pyplot as plt# 生成二分类数据集
X, y = make_classification(n_samples=10000, n_features=40, n_classes=2, weights=[0.9, 0.1], random_state=2, flip_y=0.3)

查看分布:
 

pd.Series(y).value_counts()

不平衡样本。

标准化,划分训练集测试集

# 标准化数据
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y,stratify=y, test_size=0.2, random_state=2)

训练,然后评价,这里就弄了个随机森林模型试试

# 模型训练
model =RandomForestClassifier()
model.fit(X_train, y_train)# 模型评价
y_pred = model.predict(X_test)
print("Classification Report:")
print(classification_report(y_test, y_pred))


画校准曲线。

# 计算校准曲线
prob_true, prob_pred = calibration_curve(y_test, model.predict_proba(X_test)[:, 1], n_bins=10)# 绘制校准曲线
plt.figure(figsize=(7, 4),dpi=128)
plt.plot(prob_pred, prob_true, marker='o', label='uncalibrated')
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='perfectly calibrated')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.title('Calibration Curve (Uncalibrated)')
plt.legend()
plt.show()

这玩意怎么看,,,我也不太懂,反正就是要单调,并且越靠近对角线越好。这个明显在0.1-0.2区间不单调,还有0.7-0.9也在下降。

来校准一下:


模型校准

模型校准很多方法,目前这个是用了 method='sigmoid',这个方法,好像叫做什么p系数校准。

from sklearn.calibration import CalibratedClassifierCV
calibrated_model = CalibratedClassifierCV(model, method='sigmoid', cv='prefit')
calibrated_model.fit(X_train, y_train)# 模型评价(校准后)
y_pred_calibrated = calibrated_model.predict(X_test)
print("Classification Report (Calibrated Model):")
print(classification_report(y_test, y_pred_calibrated))

emmm,效果好像没有明显提升。

method='isotonic',这个是什么保序回归方法校准。

calibrated_model2 = CalibratedClassifierCV(model, method='isotonic', cv='prefit')
calibrated_model2.fit(X_train, y_train)# 模型评价(校准后)
y_pred_calibrated2 = calibrated_model2.predict(X_test)
print("Classification Report (Calibrated Model):")
print(classification_report(y_test, y_pred_calibrated2))

效果也差不多。

画出校准曲线的对比图:
 

# 计算校准后的校准曲线
prob_true_calibrated, prob_pred_calibrated = calibration_curve(y_pred_calibrated,calibrated_model.predict_proba(X_test)[:, 1], n_bins=10)
prob_true_calibrated2, prob_pred_calibrated2 = calibration_curve(y_pred_calibrated2,calibrated_model2.predict_proba(X_test)[:, 1], n_bins=10)
# 绘制校准后的校准曲线
plt.figure(figsize=(7, 4),dpi=128)
plt.plot(prob_pred, prob_true, marker='o', label='uncalibrated')
plt.plot(prob_pred_calibrated, prob_true_calibrated, marker='o', label='sigmoid calibrated')
plt.plot(prob_pred_calibrated2, prob_true_calibrated2, marker='o', label='isotonic calibrated')
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='perfectly calibrated')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.title('Calibration Curve (Calibrated)')
plt.legend()
plt.show()

 

可以看到模型校准之后这个线都是单调上升的了。但是都很奇怪,而且预测效果也没太多改善,可能是我这个数据集是随便造的原因。

校准曲线的单调性在模型校准中确实非常重要。校准曲线的单调性指的是在横轴表示预测概率的均值,纵轴表示实际观测到的正例比例时,曲线应该是单调递增的,即预测概率越高,观测到的正例比例也应该越高。

校准曲线的单调性反映了模型输出的概率与实际观测之间的一致性。如果校准曲线的单调性较差,意味着模型的输出概率与实际观测之间存在较大的偏差,可能会导致模型在实际应用中表现不稳定或不可靠。因此,单调的校准曲线通常被认为是一个良好校准的指标之一。

在实际应用中,如果模型的校准曲线不单调,可能需要进一步考虑以下问题:

模型的输出概率是否准确反映了样本的真实概率:如果模型输出的概率存在系统性的偏差,可能需要对模型进行校准,使其输出更加准确地反映样本的真实概率。

模型是否过度自信或不足自信:校准曲线的不单调性可能反映了模型在某些概率范围内过度自信或不足自信的问题。对于过度自信的模型,可能需要降低其输出概率;对于不足自信的模型,可能需要提高其输出概率。

模型的可靠性:校准曲线的单调性也反映了模型的可靠性。单调递增的校准曲线意味着模型的输出概率与实际观测之间的一致性较好,通常更可靠。

因此,校准曲线的单调性对于评估模型的校准效果和可靠性具有重要意义,在模型校准过程中应该注意观察和优化校准曲线的单调性。

嗯,都是gpt的话,看看了解一下就行。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/791032.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【THM】Nmap Advanced Port Scans(高级端口扫描)-初级渗透测试

介绍 本房间是Nmap系列的第三个房间(网络安全简介模块的一部分)。在前两个房间中,我们了解了实时主机发现和基本端口扫描。 Nmap实时主机发现Nmap基本端口扫描Nmap高级端口扫描Nmap后端口扫描在Nmap基本端口扫描中,我们介绍了TCP标志并回顾了TCP 3 路握手。要启动连接,TC…

AcWing刷题-约数个数

约数的个数 代码 # 计数 def f(x)->int:cnt 0i 1while i * i < x:if x % i 0:cnt 1if i * i < x:cnt 1i 1return cntn int(input()) a list(map(int,input().split())) for i in a:print(f(i))

蓝桥杯练习——拼出一个未来

选中 index.html 右键启动 Web Server 服务&#xff08;Open with Live Server&#xff09;&#xff0c;让项目运行起来。接着&#xff0c;打开环境右侧的【Web 服务】&#xff0c;就可以在浏览器中看到如下效果&#xff1a; 目标 完善 js/index.js 的 TODO 部分&#xff0c;实…

【leetcode】 c++ 数字全排列, test ok

1. 问题 2. 思路 3. 代码实现 #if 0 class Solution { private:vector<int> path; // 满足条件的一个结果 vector<vector<int>> res; // 结果集 void backtracking(vector<int> nums, vector<bool> used){// 若path的个数和nums个数相等&…

算法整理:排序

快速排序 首先不妨以第一个数为基准数&#xff0c;在一轮遍历后&#xff0c;使基准数左边的数都小于基准数&#xff0c;基准数右边的数都大于基准数。 当然也可以取中间的数为基准数。 void quick_sort(vector<int>&nums,int l,int r){if(l>r)return;int idxl;//…

硬件工程师职责与核心技能有哪些?

作为一个优秀的硬件工程师&#xff0c;必须要具备优秀的职业技能。那么&#xff0c;有些刚入行的工程师及在校的学生经常会问到&#xff1a;硬件工程师需要哪些核心技能&#xff1f;要回答这个问题&#xff0c;首先要明白硬件工程师的职责&#xff0c;然后才能知道核心技能要求…

神经网络学习笔记10——RNN、ELMo、Transformer、GPT、BERT

系列文章目录 参考博客1 参考博客2 文章目录 系列文章目录前言一、RNN1、简介2、模型结构3、RNN公式分析4、RNN的优缺点及优化1&#xff09;LSTM是RNN的优化结构2&#xff09;GRU是LSTM的简化结构 二、ELMo1、简介2、模型结构1&#xff09;输入2&#xff09;左右双向上下文信…

Gemini即将收费,GPT无需注册?GPT3.5白嫖和升级教程

&#x1f310;Gemini 即将开始收费 开发者“白嫖”的好日子到头了 - Gemini将开始收费&#xff0c;影响使用Google AI for Developers提供的Gemini API的用户。 - Gemini API将引入按量付费定价&#xff0c;需要注意新的服务条款。 - 用户需在5月2日之前停止使用Gemini API和Go…

使用Java拓展本地开源大模型的网络搜索问答能力

背景 开源大模型通常不具备最新语料的问答能力。因此需要外部插件的拓展&#xff0c;目前主流的langChain框架已经集成了网络搜索的能力。但是作为一个倔强的Java程序员&#xff0c;还是想要用Java去实现。 注册SerpAPI Serpapi 提供了多种搜索引擎的搜索API接口。 访问 Ser…

数据结构(二)----线性表(顺序表,链表)

目录 1.线性表的概念 2.线性表的基本操作 3.存储线性表的方式 &#xff08;1&#xff09;顺序表 •顺序表的概念 •顺序表的实现 静态分配&#xff1a; 动态分配&#xff1a; 顺序表的插入&#xff1a; 顺序表的删除&#xff1a; 顺序表的按位查找&#xff1a; 顺序…

自我认识的方法模型图

在漫长的人生旅途中&#xff0c;我们都在不断地探索、追寻&#xff0c;努力寻找那个最真实、最完整的自我。因为只有真正了解自己&#xff0c;才能战胜内心的种种困惑与恐惧&#xff0c;进而战胜外在的一切挑战与困难。自我认识&#xff0c;是每个人成长的必经之路&#xff0c;…

探索未来外贸电商系统的创新架构

在全球化、数字化的时代背景下&#xff0c;外贸电商行业呈现出蓬勃发展的态势。为了适应市场竞争的激烈和用户需求的多样化&#xff0c;外贸电商系统的架构设计显得尤为重要。本文将深入探讨未来外贸电商系统的创新架构&#xff0c;以期为行业发展提供新的思路和方向。 随着全…

使用Flutter创建带有图标提示的TextField

在移动应用开发中&#xff0c;TextField是一种常用的用户输入小部件。然而&#xff0c;有时向用户提供有关他们应该输入什么的提示或说明是很有帮助的。在本教程中&#xff0c;我们将创建一个Flutter应用程序&#xff0c;演示如何在TextField旁边包含一个图标提示。 编写代码 …

初次在 GitHub 建立仓库以及公开代码的流程 - 公开代码

初次在 GitHub 建立仓库以及公开代码的流程 - 公开代码 References 在已有仓库中添加代码并公开。 git clone 已有仓库 将已有仓库 clone 到本地的开发环境中。 strongforeverstrong:~$ mkdir github_work strongforeverstrong:~$ cd github_work/ strongforeverstrong:~/git…

【.NET全栈】.NET全栈学习路线

一、微软官方C#学习 https://learn.microsoft.com/zh-cn/dotnet/csharp/tour-of-csharp/ C#中的数据类型 二、2021 ASP.NET Core 开发者路线图 GitHub地址&#xff1a;https://github.com/MoienTajik/AspNetCore-Developer-Roadmap/blob/master/ReadMe.zh-Hans.md 三、路线…

把标注数据导入到知识图谱

文章目录 简介数据导入Doccano标注数据&#xff0c;导入到Neo4j寻求帮助 简介 团队成员使用 Doccano 标注了一些数据&#xff0c;包括 命名实体识别、关系和文本分类 的标注的数据&#xff1b; 工作步骤如下&#xff1a; 首先将标注数据导入到Doccano&#xff0c;查看一下标注…

Idea2023创建Servlet项目

① Java EE 只是一个抽象的规范&#xff0c;具体实现称为应用服务器。 ② Java EE 只需要两个包 jsp-api.jar 和 servlet-api.jar&#xff0c;而这两个包是没有官方版本的。也就是说&#xff0c;Java 没有提供这两个包&#xff0c;只提供了一个规范。那么这两个包是谁提供的…

逻辑回归(Logistic Regression)详解

逻辑回归&#xff08;Logistic Regression&#xff09;是一种常用的统计学习方法&#xff0c;用于解决二分类问题。虽然名字中包含“回归”&#xff0c;但逻辑回归实际上是一种分类算法&#xff0c;而不是回归算法。它的基本原理是使用逻辑函数&#xff08;也称为Sigmoid函数&a…

【TI毫米波雷达】IWR6843AOP的官方文件资源名称BUG,选择xwr68xx还是xwr64xx,及需要注意的问题

【TI毫米波雷达】IWR6843AOP的官方文件资源名称BUG&#xff0c;选择xwr68xx还是xwr64xx&#xff0c;及需要注意的问题 文章目录 demo工程out_of_box文件调试bin文件名称需要注意的问题附录&#xff1a;结构框架雷达基本原理叙述雷达天线排列位置芯片框架Demo工程功能CCS工程导…

大模型学习笔记八:手撕AutoGPT

文章目录 一、功能需求二、演示用例三、核心模块流程图四、代码分析1&#xff09;Agent类目录创建智能体对象2&#xff09;开始主流程3&#xff09;在prompt的main目录输入主prompt和最后prompt4&#xff09;增加实际的工具集tools&#xff08;也就是函数&#xff09;5&#xf…