第100+23步 ChatGPT学习:概率校准 Sigmoid Calibration

基于Python 3.9版本演示

一、写在前面

最近看了一篇在Lancet子刊《eClinicalMedicine》上发表的机器学习分类的文章:《Development of a novel dementia risk prediction model in the general population: A large, longitudinal, population-based machine-learning study》。

学到一种叫做“概率校准”的骚操作,顺手利用GPT系统学习学习。

文章中用的技术是:保序回归(Isotonic regression)。

为了体现举一反三,顺便问了GPT还有哪些方法也可以实现概率校准。它给我列举了很多,那么就一个一个学习吧。

这一期,介绍一个叫做 Sigmoid Calibration 的方法。

二、Sigmoid Calibration

Sigmoid Calibration是一种后处理技术,用于改进机器学习分类器的概率估计。它通常应用于二元分类器的输出,将原始得分转换为校准后的概率。该技术使用逻辑(Sigmoid)函数将分类器的得分映射到概率上,旨在确保预测的概率更准确地反映真实结果的可能性。

(1)Sigmoid Calibration 的基本步骤

1)训练分类器:在训练数据上训练你的二元分类器。

2)获取原始得分:收集分类器在验证数据集上的原始得分或 logits。

3)拟合逻辑回归模型:使用验证数据集拟合一个逻辑回归模型,将原始得分映射为概率。

4)预测校准后的概率:使用拟合的逻辑回归模型,将分类器的原始得分转换为校准后的概率。

(2)Sigmoid Calibration 的使用

对于逻辑回归模型,通常不需要进行Sigmoid校准,因为逻辑回归本身就是基于Sigmoid函数来计算概率的。然而,在一些情况下,即使是逻辑回归模型,校准仍然可能有帮助。以下是一些可能需要校准的情况:

1)类不平衡问题:如果训练数据集中存在严重的类别不平衡问题,即某个类别的数据明显多于其他类别,逻辑回归模型的概率估计可能会偏向于较多的类别。在这种情况下,校准可以帮助调整概率估计,使其更准确地反映实际的类别分布。

2)模型训练不充分:如果逻辑回归模型没有充分训练,可能会导致概率估计不准确。校准可以在一定程度上纠正这种情况。

3)训练和测试数据分布不同:如果训练数据和测试数据的分布存在差异,逻辑回归模型的概率估计可能不适用于测试数据。在这种情况下,可以使用校准技术对模型的输出进行调整。

4)多模型集成:在多模型集成(例如集成学习)中,不同模型的输出需要组合在一起。校准可以确保不同模型的输出概率具有一致性,从而提高集成模型的性能。

三、Sigmoid Calibration代码实现

下面,我编一个1比3的不太平衡的数据进行测试,对照组使用不进行校准的SVM模型,实验组就是加入校准的SVM模型,看看性能能够提高多少?

(1)不进行校准的SVM模型(默认参数)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)# 使用SVM分类器
classifier = SVC(kernel='linear', probability=True)
classifier.fit(X_train, y_train)# 预测结果
y_pred = classifier.predict(X_test)
y_testprba = classifier.decision_function(X_test)y_trainpred = classifier.predict(X_train)
y_trainprba = classifier.decision_function(X_train)# 混淆矩阵
cm_test = confusion_matrix(y_test, y_pred)
cm_train = confusion_matrix(y_train, y_trainpred)
print(cm_train)
print(cm_test)# 绘制测试集混淆矩阵
classes = list(set(y_test))
classes.sort()
plt.imshow(cm_test, cmap=plt.cm.Blues)
indices = range(len(cm_test))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_test)):for second_index in range(len(cm_test[first_index])):plt.text(first_index, second_index, cm_test[first_index][second_index])plt.show()# 绘制训练集混淆矩阵
classes = list(set(y_train))
classes.sort()
plt.imshow(cm_train, cmap=plt.cm.Blues)
indices = range(len(cm_train))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_train)):for second_index in range(len(cm_train[first_index])):plt.text(first_index, second_index, cm_train[first_index][second_index])plt.show()# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c)sep = a / (a + b)precision = d / (b + d)F1 = (2 * precision * sen) / (precision + sen)MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))auc_score = roc_auc_score(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score}return metricsmetrics_test = calculate_metrics(cm_test, y_test, y_testprba)
metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)print("Performance Metrics (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():
print(f"{key}: {value:.4f}")

结果输出:

记住这些个数字。

这个参数的SVM还没有LR好。

(2)进行校准的SVM模型(默认参数)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
from sklearn.calibration import CalibratedClassifierCV# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)# 使用SVM分类器
classifier = SVC(kernel='rbf', C= 0.1, probability=True)
classifier.fit(X_train, y_train)# 进行Sigmoid校准
calibrated_classifier = CalibratedClassifierCV(base_estimator=classifier, method='sigmoid', cv='prefit')
calibrated_classifier.fit(X_test, y_test)# 预测结果
y_pred = calibrated_classifier.predict(X_test)
y_testprba = calibrated_classifier.predict_proba(X_test)[:, 1]y_trainpred = calibrated_classifier.predict(X_train)
y_trainprba = calibrated_classifier.predict_proba(X_train)[:, 1]# 混淆矩阵
cm_test = confusion_matrix(y_test, y_pred)
cm_train = confusion_matrix(y_train, y_trainpred)
print(cm_train)
print(cm_test)# 绘制测试集混淆矩阵
classes = list(set(y_test))
classes.sort()
plt.imshow(cm_test, cmap=plt.cm.Blues)
indices = range(len(cm_test))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_test)):for second_index in range(len(cm_test[first_index])):plt.text(first_index, second_index, cm_test[first_index][second_index])plt.show()# 绘制训练集混淆矩阵
classes = list(set(y_train))
classes.sort()
plt.imshow(cm_train, cmap=plt.cm.Blues)
indices = range(len(cm_train))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_train)):for second_index in range(len(cm_train[first_index])):plt.text(first_index, second_index, cm_train[first_index][second_index])plt.show()# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c)sep = a / (a + b)precision = d / (b + d)F1 = (2 * precision * sen) / (precision + sen)MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))auc_score = roc_auc_score(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score}return metricsmetrics_test = calculate_metrics(cm_test, y_test, y_testprba)
metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)print("Performance Metrics (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():print(f"{key}: {value:.4f}")

看看结果:

总体来看,仅仅训练集起作用了,验证集差强人意。

四、换个策略

参考那篇文章的策略:采用五折交叉验证来建立和评估模型,其中四折用于训练,一折用于评估,在训练集中,其中三折用于建立SVM模型,另一折采用Sigmoid Calibration概率校正,在训练集内部采用交叉验证对超参数进行调参。

代码:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import confusion_matrix, roc_auc_score, make_scorer# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 标准化数据
sc = StandardScaler()
X = sc.fit_transform(X)# 五折交叉验证
kf = KFold(n_splits=5, shuffle=True, random_state=666)# 超参数调优参数网格
param_grid = {'C': [0.1, 1, 10, 100],'kernel': ['linear', 'rbf']
}# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c)sep = a / (a + b)precision = d / (b + d)F1 = (2 * precision * sen) / (precision + sen)MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))auc_score = roc_auc_score(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score}return metrics# 初始化结果列表
results_train = []
results_test = []# 初始化变量以跟踪最优模型
best_auc = 0
best_model = None
best_X_train = None
best_X_test = None
best_y_train = None
best_y_test = None# 交叉验证过程
for train_index, test_index in kf.split(X):X_train, X_test = X[train_index], X[test_index]y_train, y_test = Y[train_index], Y[test_index]# 内部交叉验证进行超参数调优和模型训练inner_kf = KFold(n_splits=4, shuffle=True, random_state=666)grid_search = GridSearchCV(SVC(probability=True), param_grid, cv=inner_kf, scoring='roc_auc')grid_search.fit(X_train, y_train)model = grid_search.best_estimator_# Sigmoid Calibration 概率校准calibrated_svm = CalibratedClassifierCV(model, method='sigmoid', cv='prefit')calibrated_svm.fit(X_train, y_train)# 评估模型y_trainpred = calibrated_svm.predict(X_train)y_trainprba = calibrated_svm.predict_proba(X_train)[:, 1]cm_train = confusion_matrix(y_train, y_trainpred)metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)results_train.append(metrics_train)y_pred = calibrated_svm.predict(X_test)y_testprba = calibrated_svm.predict_proba(X_test)[:, 1]cm_test = confusion_matrix(y_test, y_pred)metrics_test = calculate_metrics(cm_test, y_test, y_testprba)results_test.append(metrics_test)# 更新最优模型if metrics_test['AUC'] > best_auc:best_auc = metrics_test['AUC']best_model = calibrated_svmbest_X_train = X_trainbest_X_test = X_testbest_y_train = y_trainbest_y_test = y_testbest_params = grid_search.best_params_print("Performance Metrics (Train):")for key, value in metrics_train.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Test):")for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\n" + "="*40 + "\n")# 使用最优模型评估性能
y_trainpred = best_model.predict(best_X_train)
y_trainprba = best_model.predict_proba(best_X_train)[:, 1]
cm_train = confusion_matrix(best_y_train, y_trainpred)
metrics_train = calculate_metrics(cm_train, best_y_train, y_trainprba)y_pred = best_model.predict(best_X_test)
y_testprba = best_model.predict_proba(best_X_test)[:, 1]
cm_test = confusion_matrix(best_y_test, y_pred)
metrics_test = calculate_metrics(cm_test, best_y_test, y_testprba)print("Performance Metrics of the Best Model (Train):")
for key, value in metrics_train.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics of the Best Model (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")# 打印最优模型的参数
print("\nBest Model Parameters:")
for key, value in best_params.items():print(f"{key}: {value}")

输出:

还是有提升的,不过并没有Platt Scaling的结果好。

五、最后

各位可以去试一试在其他数据或者在其他机器学习分类模型中使用的效果。

数据不分享啦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/51561.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

0.0 C语言被我遗忘的知识点

文章目录 位移运算(>>和<<)函数指针函数指针的应用场景 strcmp的返回值合法的c语言实数表示sizeof 数组字符串的储存 —— 字符数组与字符指针字符串可能缺少 \0 的情况 用二维数组储存字符串数组其他储存字符串数组的方法 位移运算(>>和<<) 右移(>…

c++中的匿名对象及内存管理

c中的匿名对象 A a;//a的生命周期在整个main函数中 a.Sum(1); //匿名对象生命周期只有一行&#xff0c;只有这一行会创建对象,出了这一行就会调析构 A().Sum(1);//只有这一行需要这个对象&#xff0c;其他地方不需要。 return 0; 日期到天数的转换 计算日期到天数转换_牛客…

Linux不可靠信号和可靠信号

1.不可靠信号和可靠信号 建立在早期的信号处理机制上**(1~31)的信号是不可靠信号** 不支持排队、可能会丢失&#xff0c;如果同一个信号连续产生多次&#xff0c;进程可能只相应了一次 如果解除屏蔽&#xff0c;取决于可靠性&#xff01;&#xff01;&#xff01;1 建立在新的信…

pytest自定义命令行选项

在 pytest 中&#xff0c;您可以通过多种方式自定义命令行选项。以下是一些常用的方法&#xff1a; 使用 pytest 的 addoption 方法 您可以在 conftest.py 文件中使用 pytest_addoption 钩子函数来添加自定义命令行选项。 下面是一个示例&#xff0c;展示如何添加一个名为 --…

python 执行mysql文件

MySQL 5.7 之前的版本在某些方面对存储过程的支持可能有限&#xff0c;通过 Python 脚本调用 SQL 文件是一种有效的方法来自动化数据库任务和提高运维效率。具体操作如下&#xff1a; # python 执行sql脚本 def execute_sql_script(mysql_params, script_name):# 使用 with 语…

硬件调试经验积累 关于RTC 时钟问题。

1. 电脑的 RTC 问题可以排查有这几个地方 1. 硬件问题 2. BIOS 问题 3. 系统问题 2. 排查问题大致的操作 1. 使用计算系统xx 通讯读取 RTC 芯片的寄存器&#xff0c;查看芯片是否有问题。 2. 再BIOS 下查看时钟是否准确。 查看芯片的连接性是否有问题。 3..多次----断电后开机…

【鸿蒙样式初探】多个组件如何共用同一样式

最近开发鸿蒙&#xff0c;刚接触难免二和尚摸不着头脑&#xff0c;尤其是样式...... 背景 在做银行卡显示的一个小需求时&#xff1a; 每个Text都需要设置fontColor:#FFFFFF" 想着是否可以简单点 解决历程 思路一&#xff1a;&#xff08;拒绝) 使用Styles 提取封装公…

虚拟化、双层虚拟化、容器是什么,应用场景有哪些

一、基本概念 虚拟化是一种资源管理技术&#xff0c;通过将计算机的实体资源&#xff08;如CPU、内存、磁盘空间和网络适配器等&#xff09;进行抽象和转换&#xff0c;使其能够被分割、组合成一个或多个电脑配置环境。多层虚拟化则是在虚拟化的基础上&#xff0c;再进行一次虚…

爆改YOLOv8|利用可改变核卷积AKConv改进yolov8-轻量涨点

1&#xff0c;本文介绍 AKConv&#xff08;可改变核卷积&#xff09;是一种改进的卷积操作方法&#xff0c;其核心在于动态调整卷积核的形状和大小。与传统卷积层固定核大小不同&#xff0c;AKConv 通过引入可学习的机制&#xff0c;使卷积核在训练过程中能够自适应地调整&…

学生宿舍管理小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;宿舍公告管理&#xff0c;学生管理&#xff0c;宿舍管理&#xff0c;后勤人员管理&#xff0c;楼栋信息管理&#xff0c;宿舍分配管理管理&#xff0c;退宿信息管理 微信端账号功能包括&#xff1a;系…

程序猿成长之路之数据挖掘篇——Kmeans聚类算法

Kmeans 是一种可以将一个数据集按照距离&#xff08;相似度&#xff09;划分成不同类别的算法&#xff0c;它无需借助外部标记&#xff0c;因此也是一种无监督学习算法。 什么是聚类 用官方的话说聚类就是将物理或抽象对象的集合分成由类似的对象组成的多个类的过程。用自己的…

idea import配置

简介 本文记录idea中import相关配置&#xff1a;自动导入依赖、自动删除无用依赖、避免自动导入*包 自动导入依赖 在编辑代码时&#xff0c;当只有一个具有匹配名称的可导入声明时&#xff0c;会自动添加导入 File -> Settings -> Editor -> General -> Auto Imp…

简而不减,极致便捷!泰极预付费解决方案震撼上市

开户麻烦!绑表复杂!用电情况模糊!电费收缴难! 在日常生活中,能源缴费可能经常会遇到运维难管理、缴费收益难计算、支付安全难保障等问题。如何解决呢?正泰物联推出“泰极预付费解决方案”,“简”操作,“不减”功能,有效解决上述问题,助力实现便捷生活。 享轻松:泰极简而不减…

FPGA工程师成长路线(持续更新ing,欢迎补充)

一、开发能力 1、FPGA基础知识 &#xff08;1&#xff09;数电基础知识 逻辑门 锁存器 触发器 进制 码制 状态机 竞争与冒险 verilog语法 &#xff08;2&#xff09;FPGA片上资源 可配置逻辑块 嵌入式块RAM 时钟管理资源 可编程输入输出单元&#xff08;IOB&#xff09; 丰富的…

MySQL内部临时表(Using temporary)案例详解及优化解决方法

目录 前言 一.场景案例 二、什么是内部临时表&#xff1f; 三、哪些场景会使用内部临时表&#xff1f; 四、内部临时表如何存储&#xff1f; 1&#xff09;使用内存 2&#xff09;先使用内存&#xff0c;再转化成磁盘文件 3&#xff09;直接使用磁盘文件 五、如何优化…

【软件文档】项目总结报告编制模板(Word原件参考)

1. 项目概要 1.1. 项目基本信息 1.2. 项目期间 1.3. 项目成果 1.4. 开发工具和环境 2. 项目工作分析 2.1. 项目需求变更 2.2. 项目计划与进度实施 2.3. 项目总投入情况 2.4. 项目总收益情况 2.5. 项目质量情况 2.6. 风险管理实施情况 3. 经验与教训 3.1. 经验总结…

【异常错误】pycharm可以在terminal中运行,但是无法在run中运行(没有输出错误就停止了)

问题&#xff1a; pycharm的命令可以在terminal中运行&#xff0c;但是复制到无法在run中运行&#xff08;没有输出错误就停止了&#xff09; run中运行后什么错误提示都没有 搞不懂为什么 解决&#xff1a; 降低run中batch-size的大小&#xff0c;即可以运行 我并没有观察到…

Unity(2022.3.41LTS) - 后处理

目录 一、什么是后处理 二、后处理的工作原理 三、后处理的常见效果 四、如何在 Unity 中实现后处理 五、后处理的性能影响 六. 详细效果 一、什么是后处理 后处理是在场景渲染完成后&#xff0c;对最终图像进行的一系列操作。这些操作可以包括调整颜色、添加特效、模糊…

Windows Geth1.14.3私链搭建

geth下载官网&#xff1a;Downloads | go-ethereum 安装完成的目录 安装完后配置环境变量&#xff0c;在终端输入geth version 第一步&#xff1a;第一种创建账户方式geth account new --keystore keystore 创建一个账户&#xff0c;在当前目录下创建一个keystore的子目录&…

Linux工具使用

Linux编辑器-vim使用 1.vim的基本概念 在vim中&#xff0c;主要的三种模式分别是命令模式&#xff0c;插入模式和底行模式。 正常/普通/命令模式(Normal mode) 控制屏幕光标的移动&#xff0c;字符、字或行的删除&#xff0c;移动复制某区段及进入Insert mode下&#xff0c;…