第100+23步 ChatGPT学习:概率校准 Sigmoid Calibration

基于Python 3.9版本演示

一、写在前面

最近看了一篇在Lancet子刊《eClinicalMedicine》上发表的机器学习分类的文章:《Development of a novel dementia risk prediction model in the general population: A large, longitudinal, population-based machine-learning study》。

学到一种叫做“概率校准”的骚操作,顺手利用GPT系统学习学习。

文章中用的技术是:保序回归(Isotonic regression)。

为了体现举一反三,顺便问了GPT还有哪些方法也可以实现概率校准。它给我列举了很多,那么就一个一个学习吧。

这一期,介绍一个叫做 Sigmoid Calibration 的方法。

二、Sigmoid Calibration

Sigmoid Calibration是一种后处理技术,用于改进机器学习分类器的概率估计。它通常应用于二元分类器的输出,将原始得分转换为校准后的概率。该技术使用逻辑(Sigmoid)函数将分类器的得分映射到概率上,旨在确保预测的概率更准确地反映真实结果的可能性。

(1)Sigmoid Calibration 的基本步骤

1)训练分类器:在训练数据上训练你的二元分类器。

2)获取原始得分:收集分类器在验证数据集上的原始得分或 logits。

3)拟合逻辑回归模型:使用验证数据集拟合一个逻辑回归模型,将原始得分映射为概率。

4)预测校准后的概率:使用拟合的逻辑回归模型,将分类器的原始得分转换为校准后的概率。

(2)Sigmoid Calibration 的使用

对于逻辑回归模型,通常不需要进行Sigmoid校准,因为逻辑回归本身就是基于Sigmoid函数来计算概率的。然而,在一些情况下,即使是逻辑回归模型,校准仍然可能有帮助。以下是一些可能需要校准的情况:

1)类不平衡问题:如果训练数据集中存在严重的类别不平衡问题,即某个类别的数据明显多于其他类别,逻辑回归模型的概率估计可能会偏向于较多的类别。在这种情况下,校准可以帮助调整概率估计,使其更准确地反映实际的类别分布。

2)模型训练不充分:如果逻辑回归模型没有充分训练,可能会导致概率估计不准确。校准可以在一定程度上纠正这种情况。

3)训练和测试数据分布不同:如果训练数据和测试数据的分布存在差异,逻辑回归模型的概率估计可能不适用于测试数据。在这种情况下,可以使用校准技术对模型的输出进行调整。

4)多模型集成:在多模型集成(例如集成学习)中,不同模型的输出需要组合在一起。校准可以确保不同模型的输出概率具有一致性,从而提高集成模型的性能。

三、Sigmoid Calibration代码实现

下面,我编一个1比3的不太平衡的数据进行测试,对照组使用不进行校准的SVM模型,实验组就是加入校准的SVM模型,看看性能能够提高多少?

(1)不进行校准的SVM模型(默认参数)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)# 使用SVM分类器
classifier = SVC(kernel='linear', probability=True)
classifier.fit(X_train, y_train)# 预测结果
y_pred = classifier.predict(X_test)
y_testprba = classifier.decision_function(X_test)y_trainpred = classifier.predict(X_train)
y_trainprba = classifier.decision_function(X_train)# 混淆矩阵
cm_test = confusion_matrix(y_test, y_pred)
cm_train = confusion_matrix(y_train, y_trainpred)
print(cm_train)
print(cm_test)# 绘制测试集混淆矩阵
classes = list(set(y_test))
classes.sort()
plt.imshow(cm_test, cmap=plt.cm.Blues)
indices = range(len(cm_test))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_test)):for second_index in range(len(cm_test[first_index])):plt.text(first_index, second_index, cm_test[first_index][second_index])plt.show()# 绘制训练集混淆矩阵
classes = list(set(y_train))
classes.sort()
plt.imshow(cm_train, cmap=plt.cm.Blues)
indices = range(len(cm_train))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_train)):for second_index in range(len(cm_train[first_index])):plt.text(first_index, second_index, cm_train[first_index][second_index])plt.show()# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c)sep = a / (a + b)precision = d / (b + d)F1 = (2 * precision * sen) / (precision + sen)MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))auc_score = roc_auc_score(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score}return metricsmetrics_test = calculate_metrics(cm_test, y_test, y_testprba)
metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)print("Performance Metrics (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():
print(f"{key}: {value:.4f}")

结果输出:

记住这些个数字。

这个参数的SVM还没有LR好。

(2)进行校准的SVM模型(默认参数)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
from sklearn.calibration import CalibratedClassifierCV# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)# 使用SVM分类器
classifier = SVC(kernel='rbf', C= 0.1, probability=True)
classifier.fit(X_train, y_train)# 进行Sigmoid校准
calibrated_classifier = CalibratedClassifierCV(base_estimator=classifier, method='sigmoid', cv='prefit')
calibrated_classifier.fit(X_test, y_test)# 预测结果
y_pred = calibrated_classifier.predict(X_test)
y_testprba = calibrated_classifier.predict_proba(X_test)[:, 1]y_trainpred = calibrated_classifier.predict(X_train)
y_trainprba = calibrated_classifier.predict_proba(X_train)[:, 1]# 混淆矩阵
cm_test = confusion_matrix(y_test, y_pred)
cm_train = confusion_matrix(y_train, y_trainpred)
print(cm_train)
print(cm_test)# 绘制测试集混淆矩阵
classes = list(set(y_test))
classes.sort()
plt.imshow(cm_test, cmap=plt.cm.Blues)
indices = range(len(cm_test))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_test)):for second_index in range(len(cm_test[first_index])):plt.text(first_index, second_index, cm_test[first_index][second_index])plt.show()# 绘制训练集混淆矩阵
classes = list(set(y_train))
classes.sort()
plt.imshow(cm_train, cmap=plt.cm.Blues)
indices = range(len(cm_train))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_train)):for second_index in range(len(cm_train[first_index])):plt.text(first_index, second_index, cm_train[first_index][second_index])plt.show()# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c)sep = a / (a + b)precision = d / (b + d)F1 = (2 * precision * sen) / (precision + sen)MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))auc_score = roc_auc_score(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score}return metricsmetrics_test = calculate_metrics(cm_test, y_test, y_testprba)
metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)print("Performance Metrics (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():print(f"{key}: {value:.4f}")

看看结果:

总体来看,仅仅训练集起作用了,验证集差强人意。

四、换个策略

参考那篇文章的策略:采用五折交叉验证来建立和评估模型,其中四折用于训练,一折用于评估,在训练集中,其中三折用于建立SVM模型,另一折采用Sigmoid Calibration概率校正,在训练集内部采用交叉验证对超参数进行调参。

代码:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import confusion_matrix, roc_auc_score, make_scorer# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 标准化数据
sc = StandardScaler()
X = sc.fit_transform(X)# 五折交叉验证
kf = KFold(n_splits=5, shuffle=True, random_state=666)# 超参数调优参数网格
param_grid = {'C': [0.1, 1, 10, 100],'kernel': ['linear', 'rbf']
}# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c)sep = a / (a + b)precision = d / (b + d)F1 = (2 * precision * sen) / (precision + sen)MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))auc_score = roc_auc_score(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score}return metrics# 初始化结果列表
results_train = []
results_test = []# 初始化变量以跟踪最优模型
best_auc = 0
best_model = None
best_X_train = None
best_X_test = None
best_y_train = None
best_y_test = None# 交叉验证过程
for train_index, test_index in kf.split(X):X_train, X_test = X[train_index], X[test_index]y_train, y_test = Y[train_index], Y[test_index]# 内部交叉验证进行超参数调优和模型训练inner_kf = KFold(n_splits=4, shuffle=True, random_state=666)grid_search = GridSearchCV(SVC(probability=True), param_grid, cv=inner_kf, scoring='roc_auc')grid_search.fit(X_train, y_train)model = grid_search.best_estimator_# Sigmoid Calibration 概率校准calibrated_svm = CalibratedClassifierCV(model, method='sigmoid', cv='prefit')calibrated_svm.fit(X_train, y_train)# 评估模型y_trainpred = calibrated_svm.predict(X_train)y_trainprba = calibrated_svm.predict_proba(X_train)[:, 1]cm_train = confusion_matrix(y_train, y_trainpred)metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)results_train.append(metrics_train)y_pred = calibrated_svm.predict(X_test)y_testprba = calibrated_svm.predict_proba(X_test)[:, 1]cm_test = confusion_matrix(y_test, y_pred)metrics_test = calculate_metrics(cm_test, y_test, y_testprba)results_test.append(metrics_test)# 更新最优模型if metrics_test['AUC'] > best_auc:best_auc = metrics_test['AUC']best_model = calibrated_svmbest_X_train = X_trainbest_X_test = X_testbest_y_train = y_trainbest_y_test = y_testbest_params = grid_search.best_params_print("Performance Metrics (Train):")for key, value in metrics_train.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Test):")for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\n" + "="*40 + "\n")# 使用最优模型评估性能
y_trainpred = best_model.predict(best_X_train)
y_trainprba = best_model.predict_proba(best_X_train)[:, 1]
cm_train = confusion_matrix(best_y_train, y_trainpred)
metrics_train = calculate_metrics(cm_train, best_y_train, y_trainprba)y_pred = best_model.predict(best_X_test)
y_testprba = best_model.predict_proba(best_X_test)[:, 1]
cm_test = confusion_matrix(best_y_test, y_pred)
metrics_test = calculate_metrics(cm_test, best_y_test, y_testprba)print("Performance Metrics of the Best Model (Train):")
for key, value in metrics_train.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics of the Best Model (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")# 打印最优模型的参数
print("\nBest Model Parameters:")
for key, value in best_params.items():print(f"{key}: {value}")

输出:

还是有提升的,不过并没有Platt Scaling的结果好。

五、最后

各位可以去试一试在其他数据或者在其他机器学习分类模型中使用的效果。

数据不分享啦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/51561.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

0.0 C语言被我遗忘的知识点

文章目录 位移运算(>>和<<)函数指针函数指针的应用场景 strcmp的返回值合法的c语言实数表示sizeof 数组字符串的储存 —— 字符数组与字符指针字符串可能缺少 \0 的情况 用二维数组储存字符串数组其他储存字符串数组的方法 位移运算(>>和<<) 右移(>…

c++中的匿名对象及内存管理

c中的匿名对象 A a;//a的生命周期在整个main函数中 a.Sum(1); //匿名对象生命周期只有一行&#xff0c;只有这一行会创建对象,出了这一行就会调析构 A().Sum(1);//只有这一行需要这个对象&#xff0c;其他地方不需要。 return 0; 日期到天数的转换 计算日期到天数转换_牛客…

【鸿蒙样式初探】多个组件如何共用同一样式

最近开发鸿蒙&#xff0c;刚接触难免二和尚摸不着头脑&#xff0c;尤其是样式...... 背景 在做银行卡显示的一个小需求时&#xff1a; 每个Text都需要设置fontColor:#FFFFFF" 想着是否可以简单点 解决历程 思路一&#xff1a;&#xff08;拒绝) 使用Styles 提取封装公…

爆改YOLOv8|利用可改变核卷积AKConv改进yolov8-轻量涨点

1&#xff0c;本文介绍 AKConv&#xff08;可改变核卷积&#xff09;是一种改进的卷积操作方法&#xff0c;其核心在于动态调整卷积核的形状和大小。与传统卷积层固定核大小不同&#xff0c;AKConv 通过引入可学习的机制&#xff0c;使卷积核在训练过程中能够自适应地调整&…

学生宿舍管理小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;宿舍公告管理&#xff0c;学生管理&#xff0c;宿舍管理&#xff0c;后勤人员管理&#xff0c;楼栋信息管理&#xff0c;宿舍分配管理管理&#xff0c;退宿信息管理 微信端账号功能包括&#xff1a;系…

程序猿成长之路之数据挖掘篇——Kmeans聚类算法

Kmeans 是一种可以将一个数据集按照距离&#xff08;相似度&#xff09;划分成不同类别的算法&#xff0c;它无需借助外部标记&#xff0c;因此也是一种无监督学习算法。 什么是聚类 用官方的话说聚类就是将物理或抽象对象的集合分成由类似的对象组成的多个类的过程。用自己的…

idea import配置

简介 本文记录idea中import相关配置&#xff1a;自动导入依赖、自动删除无用依赖、避免自动导入*包 自动导入依赖 在编辑代码时&#xff0c;当只有一个具有匹配名称的可导入声明时&#xff0c;会自动添加导入 File -> Settings -> Editor -> General -> Auto Imp…

简而不减,极致便捷!泰极预付费解决方案震撼上市

开户麻烦!绑表复杂!用电情况模糊!电费收缴难! 在日常生活中,能源缴费可能经常会遇到运维难管理、缴费收益难计算、支付安全难保障等问题。如何解决呢?正泰物联推出“泰极预付费解决方案”,“简”操作,“不减”功能,有效解决上述问题,助力实现便捷生活。 享轻松:泰极简而不减…

MySQL内部临时表(Using temporary)案例详解及优化解决方法

目录 前言 一.场景案例 二、什么是内部临时表&#xff1f; 三、哪些场景会使用内部临时表&#xff1f; 四、内部临时表如何存储&#xff1f; 1&#xff09;使用内存 2&#xff09;先使用内存&#xff0c;再转化成磁盘文件 3&#xff09;直接使用磁盘文件 五、如何优化…

【软件文档】项目总结报告编制模板(Word原件参考)

1. 项目概要 1.1. 项目基本信息 1.2. 项目期间 1.3. 项目成果 1.4. 开发工具和环境 2. 项目工作分析 2.1. 项目需求变更 2.2. 项目计划与进度实施 2.3. 项目总投入情况 2.4. 项目总收益情况 2.5. 项目质量情况 2.6. 风险管理实施情况 3. 经验与教训 3.1. 经验总结…

【异常错误】pycharm可以在terminal中运行,但是无法在run中运行(没有输出错误就停止了)

问题&#xff1a; pycharm的命令可以在terminal中运行&#xff0c;但是复制到无法在run中运行&#xff08;没有输出错误就停止了&#xff09; run中运行后什么错误提示都没有 搞不懂为什么 解决&#xff1a; 降低run中batch-size的大小&#xff0c;即可以运行 我并没有观察到…

Unity(2022.3.41LTS) - 后处理

目录 一、什么是后处理 二、后处理的工作原理 三、后处理的常见效果 四、如何在 Unity 中实现后处理 五、后处理的性能影响 六. 详细效果 一、什么是后处理 后处理是在场景渲染完成后&#xff0c;对最终图像进行的一系列操作。这些操作可以包括调整颜色、添加特效、模糊…

Windows Geth1.14.3私链搭建

geth下载官网&#xff1a;Downloads | go-ethereum 安装完成的目录 安装完后配置环境变量&#xff0c;在终端输入geth version 第一步&#xff1a;第一种创建账户方式geth account new --keystore keystore 创建一个账户&#xff0c;在当前目录下创建一个keystore的子目录&…

Linux工具使用

Linux编辑器-vim使用 1.vim的基本概念 在vim中&#xff0c;主要的三种模式分别是命令模式&#xff0c;插入模式和底行模式。 正常/普通/命令模式(Normal mode) 控制屏幕光标的移动&#xff0c;字符、字或行的删除&#xff0c;移动复制某区段及进入Insert mode下&#xff0c;…

一本读懂数据库发展史的书

数据库及其存储技术&#xff0c;一直以来都是基础软件的主力。数据库系统的操作接口标准&#xff0c;也是应用型软件的重要接口&#xff0c;关系重大。 作为最“有感”的系统软件&#xff0c;数据库的历史悠久、品类繁多、创新活跃。 对数据库历史发展的介绍&#xff0c;有利…

CSS3视图过渡动画

概述 网站的主题切换无非就是文字、背景图片或者颜色,我们可以先来看下 Element UI 官网的切换主题的动效: PS:Antdesign UI的主题切换动画也是大同小异。 实现的两种方式 CSS 为主 <script setup> const changeTheme = (e) => {if (document.startViewTransi…

深度学习实用方法 - 选择超参数篇

序言 在深度学习的浩瀚领域中&#xff0c;超参数的选择无疑是通往卓越模型性能的一把关键钥匙。超参数&#xff0c;作为训练前设定的、用于控制学习过程而非通过学习自动获得的参数&#xff0c;如学习率、批量大小、网络层数及节点数等&#xff0c;直接影响着模型的收敛速度、…

MySQL索引(三)

MySQL索引(三) 文章目录 MySQL索引(三)为什么建索引&#xff1f;怎么建立索引为什么不是说索引越多越好什么时候不用索引更好 索引怎么优化索引失效如何解决索引失效 学习网站&#xff1a;https://xiaolincoding.com/ 为什么建索引&#xff1f; 1.索引大大减少了MySQL需要扫描…

线性约束最小方差准则(LCMV)波束形成算法及MATLAB深入仿真分析

阵列信号处理——线性约束最小方差准则(LCMV)波束形成算法及MATLAB深入仿真分析 目录 前言 一、LCMV算法 二、仿真参数设置 三、抗干扰权值计算仿真 四、不同干扰方位下抗干扰性能仿真 五、不同信噪比和干噪比下抗干扰性能仿真 总结 前言 在信号处理模块中&#xff0c;通…

day13JS-MoseEvent事件

1. MouseEvent的类别 mousedown &#xff1a;按下键mouseup &#xff1a;释放键click &#xff1a;左键单击dblclick &#xff1a;左键双击contextmenu &#xff1a;右键菜单mousemove &#xff1a;鼠标移动mouseover : 鼠标经过 。 可以做事件委托&#xff0c;子元素可以冒泡…