机器学习的模型校准

背景知识

之前一直没了解过模型校准是什么东西,最近上班业务需要看了一下:

模型校准是指对分类模型进行修正以提高其概率预测的准确性。在分类模型中,预测结果通常以类别标签形式呈现(例如,0或1),但有时我们更关注的是预测的概率。

当使用某些分类模型(例如支持向量机(SVM)或随机森林)时,其预测的概率并不一定与真实标签的概率分布相匹配。这意味着,即使预测概率较高的类别出现的频率更高,模型的预测概率也可能偏离真实情况。这可能导致对模型的概率输出有误解,或者在需要高度依赖概率预测的任务(例如风险评估或阈值选择)中出现问题。

通过校准分类模型,我们可以将模型的预测概率调整为更准确地反映真实情况。`CalibratedClassifierCV`是Scikit-learn库中提供的用于校准分类器的类。它根据指定的校准方法(`method`),通过拟合后的分类器(`model`)和交叉验证拟合(`cv='prefit'`)来创建一个经过校准的分类器(`calibrated_model`)。

在代码中,使用`calibrated_model.fit(X_train, y_train)`通过使用交叉验证拟合来训练、校准模型。之后,使用`calibrated_model.predict(X_test)`对测试集进行预测,并使用`classification_report`输出校准模型的分类性能报告。

通过校准分类模型,我们可以使得模型的概率预测更为准确,从而提高在概率判断和相关任务中的性能和可靠性。


代码实现

模型校准主要是针对分类模型的,我之前都是做回归,难怪没怎么接触过。也没空找真实数据了,直接模拟数据来实现一下。

导入包和制作数据集

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.calibration import calibration_curve
from sklearn.ensemble import RandomForestClassifierimport matplotlib.pyplot as plt# 生成二分类数据集
X, y = make_classification(n_samples=10000, n_features=40, n_classes=2, weights=[0.9, 0.1], random_state=2, flip_y=0.3)

查看分布:
 

pd.Series(y).value_counts()

不平衡样本。

标准化,划分训练集测试集

# 标准化数据
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y,stratify=y, test_size=0.2, random_state=2)

训练,然后评价,这里就弄了个随机森林模型试试

# 模型训练
model =RandomForestClassifier()
model.fit(X_train, y_train)# 模型评价
y_pred = model.predict(X_test)
print("Classification Report:")
print(classification_report(y_test, y_pred))


画校准曲线。

# 计算校准曲线
prob_true, prob_pred = calibration_curve(y_test, model.predict_proba(X_test)[:, 1], n_bins=10)# 绘制校准曲线
plt.figure(figsize=(7, 4),dpi=128)
plt.plot(prob_pred, prob_true, marker='o', label='uncalibrated')
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='perfectly calibrated')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.title('Calibration Curve (Uncalibrated)')
plt.legend()
plt.show()

这玩意怎么看,,,我也不太懂,反正就是要单调,并且越靠近对角线越好。这个明显在0.1-0.2区间不单调,还有0.7-0.9也在下降。

来校准一下:


模型校准

模型校准很多方法,目前这个是用了 method='sigmoid',这个方法,好像叫做什么p系数校准。

from sklearn.calibration import CalibratedClassifierCV
calibrated_model = CalibratedClassifierCV(model, method='sigmoid', cv='prefit')
calibrated_model.fit(X_train, y_train)# 模型评价(校准后)
y_pred_calibrated = calibrated_model.predict(X_test)
print("Classification Report (Calibrated Model):")
print(classification_report(y_test, y_pred_calibrated))

emmm,效果好像没有明显提升。

method='isotonic',这个是什么保序回归方法校准。

calibrated_model2 = CalibratedClassifierCV(model, method='isotonic', cv='prefit')
calibrated_model2.fit(X_train, y_train)# 模型评价(校准后)
y_pred_calibrated2 = calibrated_model2.predict(X_test)
print("Classification Report (Calibrated Model):")
print(classification_report(y_test, y_pred_calibrated2))

效果也差不多。

画出校准曲线的对比图:
 

# 计算校准后的校准曲线
prob_true_calibrated, prob_pred_calibrated = calibration_curve(y_pred_calibrated,calibrated_model.predict_proba(X_test)[:, 1], n_bins=10)
prob_true_calibrated2, prob_pred_calibrated2 = calibration_curve(y_pred_calibrated2,calibrated_model2.predict_proba(X_test)[:, 1], n_bins=10)
# 绘制校准后的校准曲线
plt.figure(figsize=(7, 4),dpi=128)
plt.plot(prob_pred, prob_true, marker='o', label='uncalibrated')
plt.plot(prob_pred_calibrated, prob_true_calibrated, marker='o', label='sigmoid calibrated')
plt.plot(prob_pred_calibrated2, prob_true_calibrated2, marker='o', label='isotonic calibrated')
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='perfectly calibrated')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.title('Calibration Curve (Calibrated)')
plt.legend()
plt.show()

 

可以看到模型校准之后这个线都是单调上升的了。但是都很奇怪,而且预测效果也没太多改善,可能是我这个数据集是随便造的原因。

校准曲线的单调性在模型校准中确实非常重要。校准曲线的单调性指的是在横轴表示预测概率的均值,纵轴表示实际观测到的正例比例时,曲线应该是单调递增的,即预测概率越高,观测到的正例比例也应该越高。

校准曲线的单调性反映了模型输出的概率与实际观测之间的一致性。如果校准曲线的单调性较差,意味着模型的输出概率与实际观测之间存在较大的偏差,可能会导致模型在实际应用中表现不稳定或不可靠。因此,单调的校准曲线通常被认为是一个良好校准的指标之一。

在实际应用中,如果模型的校准曲线不单调,可能需要进一步考虑以下问题:

模型的输出概率是否准确反映了样本的真实概率:如果模型输出的概率存在系统性的偏差,可能需要对模型进行校准,使其输出更加准确地反映样本的真实概率。

模型是否过度自信或不足自信:校准曲线的不单调性可能反映了模型在某些概率范围内过度自信或不足自信的问题。对于过度自信的模型,可能需要降低其输出概率;对于不足自信的模型,可能需要提高其输出概率。

模型的可靠性:校准曲线的单调性也反映了模型的可靠性。单调递增的校准曲线意味着模型的输出概率与实际观测之间的一致性较好,通常更可靠。

因此,校准曲线的单调性对于评估模型的校准效果和可靠性具有重要意义,在模型校准过程中应该注意观察和优化校准曲线的单调性。

嗯,都是gpt的话,看看了解一下就行。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/791032.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python程序设计 单例模式

1. 单例设计模式 设计模式设计模式 是 前人工作的总结和提炼,通常,被人们广泛流传的设计模式都是针对 某一特定问题 的成熟的解决方案使用 设计模式 是为了可重用代码、让代码更容易被他人理解、保证代码可靠性单例设计模式目的 —— 让 类 创建的对象&…

mac 上通过命令行挂载NTFS硬盘,使其可以进行读写

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言1. 安装 osxfuse 和 ntfs-3g2. 挂载 NTFS 硬盘3. 卸载 NTFS 硬盘4. 自动挂载1. 找出设备UUID2. 编辑 /etc/fstab 文件3. 添加挂载信息4. 保存并退出编辑器5. 重…

【THM】Nmap Advanced Port Scans(高级端口扫描)-初级渗透测试

介绍 本房间是Nmap系列的第三个房间(网络安全简介模块的一部分)。在前两个房间中,我们了解了实时主机发现和基本端口扫描。 Nmap实时主机发现Nmap基本端口扫描Nmap高级端口扫描Nmap后端口扫描在Nmap基本端口扫描中,我们介绍了TCP标志并回顾了TCP 3 路握手。要启动连接,TC…

AcWing刷题-约数个数

约数的个数 代码 # 计数 def f(x)->int:cnt 0i 1while i * i < x:if x % i 0:cnt 1if i * i < x:cnt 1i 1return cntn int(input()) a list(map(int,input().split())) for i in a:print(f(i))

HDFSRPC通信框架参数详解

写在前面 请先阅读HDFSRPC通信框架详解&#xff0c;对整体框架先有一定的了解。 参数列表 参数默认值描述ipc.server.read.connection-queue.size100readeripc.server.read.threadpool.size1readeripc.server.listen.queue.size128Listener:backlogipc.server.tcpnodelaytru…

Generative AI for Beginners

Generative AI for Beginners 微软推出的面向初学者的免费生成式人工智能课程。 课程章节相关教学内容学习目标课程介绍和学习环境设置学习环境配置和课程结构在学习本课程的同时帮助您取得成功生成式人工智能和 LLMs 介绍知识点: 生成式人工智能以及我们如何适应当前的技术格…

蓝桥杯练习——拼出一个未来

选中 index.html 右键启动 Web Server 服务&#xff08;Open with Live Server&#xff09;&#xff0c;让项目运行起来。接着&#xff0c;打开环境右侧的【Web 服务】&#xff0c;就可以在浏览器中看到如下效果&#xff1a; 目标 完善 js/index.js 的 TODO 部分&#xff0c;实…

使用Pointpillar神经网络识别rosbag中的障碍物

PointPillar-ROS-Node https://github.com/MengWoods/pointpillar-ros-node 这个仓库包含一个ROS节点&#xff0c;用于处理点云数据。它使用了PointPillar神经网络模型&#xff0c;允许用户在ROS环境中处理ROSbags中的点云数据。通过简单的命令&#xff0c;用户可以克隆该仓库…

【leetcode】 c++ 数字全排列, test ok

1. 问题 2. 思路 3. 代码实现 #if 0 class Solution { private:vector<int> path; // 满足条件的一个结果 vector<vector<int>> res; // 结果集 void backtracking(vector<int> nums, vector<bool> used){// 若path的个数和nums个数相等&…

Qt控件样式设置其一(常见方法及优缺点)

如果你对Qt有基本的了解&#xff0c;应该知道它的一大优点是跨平台&#xff0c;可以在不同的系统中编译运行。但在我看来&#xff0c;Qt还有另外一个优点&#xff0c;就是制作界面比较方便和灵活&#xff0c;能够实现主流静态效果的桌面应用。&#xff08;如果需要实现比较灵动…

4款免费可用的数据集成平台亮点

在众多免费的数据集成工具中&#xff0c;我们选出了四个平台&#xff0c;它们分别是Apache Nifi、FineDataLink、kettle、ETLCLoud。现在&#xff0c;让我们快速浏览一下这四个平台的亮点。 Apache Nifi&#xff1a; Apache NiFi 是一款强大的数据集成和处理平台&#xff0c;它…

DockerFile启动jar程序

1.创建Dockerfile 在项目的根目录下创建一个名为Dockerfile的文件&#xff0c;并使用文本编辑器打开它。Dockerfile的内容如下&#xff1a; # 基础镜像 FROM openjdk:8-jre # 创建目录 RUN mkdir -p /usr/app/ # 设置工作目录 WORKDIR /usr/app # 将JAR文件复制到容器中,注:…

算法整理:排序

快速排序 首先不妨以第一个数为基准数&#xff0c;在一轮遍历后&#xff0c;使基准数左边的数都小于基准数&#xff0c;基准数右边的数都大于基准数。 当然也可以取中间的数为基准数。 void quick_sort(vector<int>&nums,int l,int r){if(l>r)return;int idxl;//…

硬件工程师职责与核心技能有哪些?

作为一个优秀的硬件工程师&#xff0c;必须要具备优秀的职业技能。那么&#xff0c;有些刚入行的工程师及在校的学生经常会问到&#xff1a;硬件工程师需要哪些核心技能&#xff1f;要回答这个问题&#xff0c;首先要明白硬件工程师的职责&#xff0c;然后才能知道核心技能要求…

神经网络学习笔记10——RNN、ELMo、Transformer、GPT、BERT

系列文章目录 参考博客1 参考博客2 文章目录 系列文章目录前言一、RNN1、简介2、模型结构3、RNN公式分析4、RNN的优缺点及优化1&#xff09;LSTM是RNN的优化结构2&#xff09;GRU是LSTM的简化结构 二、ELMo1、简介2、模型结构1&#xff09;输入2&#xff09;左右双向上下文信…

Gemini即将收费,GPT无需注册?GPT3.5白嫖和升级教程

&#x1f310;Gemini 即将开始收费 开发者“白嫖”的好日子到头了 - Gemini将开始收费&#xff0c;影响使用Google AI for Developers提供的Gemini API的用户。 - Gemini API将引入按量付费定价&#xff0c;需要注意新的服务条款。 - 用户需在5月2日之前停止使用Gemini API和Go…

使用Java拓展本地开源大模型的网络搜索问答能力

背景 开源大模型通常不具备最新语料的问答能力。因此需要外部插件的拓展&#xff0c;目前主流的langChain框架已经集成了网络搜索的能力。但是作为一个倔强的Java程序员&#xff0c;还是想要用Java去实现。 注册SerpAPI Serpapi 提供了多种搜索引擎的搜索API接口。 访问 Ser…

Linux初学(十三)中间件

一、Nginx 简介 Nginx是一个高性能的HTTP和反向代理web服务器 轻量级、高性能 1.1 Nginx安装 方法一&#xff1a;编译安装 依赖&#xff1a;openssl-devel、zlib-devel、ncurses-devel、pcre-devel、gcc、gcc-c 方法二&#xff1a;yum安装 Nginx的rpm包在epel源中 编译安…

2024.3.10力扣每日一题——猜数字游戏

2024.3.10 题目来源我的题解方法一 哈希表方法二 使用数组优化 题目来源 力扣每日一题&#xff1b;题序&#xff1a;299 我的题解 方法一 哈希表 使用哈希表记录secret中每个数字出现的次数&#xff0c;然后遍历guess的每一位&#xff0c;再判断与secret对应位置是否相同&am…

数据结构(二)----线性表(顺序表,链表)

目录 1.线性表的概念 2.线性表的基本操作 3.存储线性表的方式 &#xff08;1&#xff09;顺序表 •顺序表的概念 •顺序表的实现 静态分配&#xff1a; 动态分配&#xff1a; 顺序表的插入&#xff1a; 顺序表的删除&#xff1a; 顺序表的按位查找&#xff1a; 顺序…