模型评估方法

目录

数据集切分

交叉验证

交叉验证实例

混淆矩阵

实例

代码实现

阈值

全局阈值处理

自适应阈值处理

阈值对结果的影响

ROC曲线


数据集切分

数据集切分是指将一个数据集分割成训练集和测试集的过程。常用的方法是随机切分,即将数据集中的样本按照一定比例分配到训练集和测试集中。切分数据集的目的是为了评估模型在未见过的数据上的性能,以便更好地了解模型的泛化能力。

在Python中,可以使用train_test_split函数来进行数据集切分。该函数位于sklearn.model_selection模块中,可以根据指定的比例将数据集切分成训练集和测试集。

from sklearn.model_selection import train_test_split# 假设x为特征数据,y为标签数据
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)# test_size参数指定了测试集的比例,这里设置为0.2,即将20%的数据分配给测试集
# random_state参数用于设置随机种子,保证每次切分的结果一致# 接下来可以使用x_train和y_train进行模型训练,使用x_test和y_test进行模型评估

本实验使用sklearn内置数据集Mnist手写数字识别数据进行实验

# 数据集读取
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')
X, y = mnist["data"], mnist["target"]
X.shape
# (70000, 784)
# 784个像素点 即784个特征值 28*28*1(长*宽*颜色通道[灰度图为1])
y.shape# 划分数据集
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]# 洗牌操作 打乱数据顺序(样本是独立的)
import numpy as npshuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
shuffle_index
# array([12628, 37730, 39991, ...,   860, 15795, 56422])

交叉验证

交叉验证是一种常用的模型评估方法,通过将数据集划分为训练集和验证集,来评估模型的性能和泛化能力。交叉验证可以帮助我们更好地了解模型在未知数据上的表现,并选择最佳的模型参数。

常见的交叉验证方法有K折交叉验证和留一交叉验证。

  1. K折交叉验证(K-fold Cross Validation):将数据集分成K个子集,其中K-1个子集用于训练模型,剩下的1个子集用于验证模型。这个过程会重复K次,每次选择不同的验证集。最后,将K次验证结果的平均值作为模型的性能指标。

  2. 留一交叉验证(Leave-One-Out Cross Validation,LOO-CV):将每个样本单独作为验证集,其余样本作为训练集。这个过程会重复N次,其中N是数据集的样本数量。最后,将N次验证结果的平均值作为模型的性能指标。

交叉验证的优点是能够更准确地评估模型的性能,减少因数据集划分不合理而引起的偏差。同时,交叉验证还可以帮助我们选择最佳的模型参数,以提高模型的泛化能力。

通俗解释交叉验证:一位高三的学生小明在高考前一直在刷题做53模拟试卷,53模拟试卷就表示测试集,最终的高考即代表训练集,由于高考只有一次机会,而小明想验证学习成果的时候只能先通过53模拟试卷的题来进行测试,由于53模拟试卷的题量多为了更好的验证学习成果则需要多次测试,最终测试结果就等于其多次测试的平均值。

交叉验证实例
# 使用StratifiedKFold进行了交叉验证,并在每个折叠中训练了一个克隆的分类器。
# 然后,使用训练好的分类器对测试集进行预测,并计算了预测准确率
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
# 将训练集分成3个折叠,并在每个折叠上进行训练和测试
skflods = StratifiedKFold(n_splits=3,random_state=42)
for train_index,test_index in skflods.split(X_train,y_train_5):clone_clf = clone(sgd_clf)X_train_folds = X_train[train_index]y_train_folds = y_train_5[train_index]X_test_folds = X_train[test_index]y_test_folds = y_train_5[test_index]clone_clf.fit(X_train_folds,y_train_folds)y_pred = clone_clf.predict(X_test_folds)n_correct = sum(y_pred == y_test_folds)print(n_correct/len(y_pred))

混淆矩阵

混淆矩阵是用于评估分类模型性能的一种工具,它将实际目标值与机器学习模型预测的目标值进行比较。混淆矩阵是一个N x N的矩阵,其中N是目标类别的数量。

混淆矩阵的元素含义如下:

  • True Positive(真正,TP):实际为正例,模型预测为正的样本数。
  • True Negative(真负,TN):实际为负例,模型预测为负例的样本数。
  • False Positive(假正,FP):实际为负例,模型预测为正例的样本数。
  • False Negative(假负,FN):实际为正例,模型预测为负例的样本数。

混淆矩阵可以帮助我们计算出各种分类指标,例如正确率、召回率、精确率和F1值等,从而评估模型的性能和效果。

实例

代码实现
# 判断标签是否等于 5 
y_train_5 = (y_train==5)
y_test_5 = (y_test==5)
# 查看前十个
y_train_5[:10]
# array([False, False, False, False, False, False, False, False, False, True])# 使用Scikit-learn库中的SGDClassifier类来训练一个二分类模型
from sklearn.linear_model import SGDClassifier
# 创建一个SGDClassifier对象,并设置参数max_iter=5和random_state=42。max_iter参数表示迭代次数,
# random_state参数用于控制随机数生成器的种子,以确保结果的可重复性
sgd_clf = SGDClassifier(max_iter=5,random_state=42)
# 使用fit()方法来训练模型。fit()方法接受训练数据集X_train和对应的标签y_train_5作为输入。
# 模型将根据这些数据进行学习,以便能够对新的数据进行预测
sgd_clf.fit(X_train,y_train_5)# 使用分类器预测结果
sgd_clf.predict([X[35000]]) #array([ True])
# 查看标签的真实值
y[35000]  # 5.0
#混淆矩阵 二分类任务
from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(sgd_clf,X_train,y_train_5,cv=3)
y_train_pred.shape #(60000,)
X_train.shape #(60000, 784)
from sklearn.metrics import confusion_matrix
# 需要两个参数 标签值、预测值
confusion_matrix(y_train_5,y_train_pred)
"""
array([[53272,  1307],[ 1077,  4344]], dtype=int64)
"""

negative class [[ true negatives , false positives ],

positive class [ false negatives , true positives ]]

  • true negatives: 53,272个数据被正确的分为非5类别
  • false positives:1307张被错误的分为5类别

  • false negatives:1077张错误的分为非5类别

  • true positives: 4344张被正确的分为5类别

一个完美分类器应该只有true positives 和 true negatives, 即主对角线元素不为0,其余元素为0

精度、召回率计算

from sklearn.metrics import precision_score,recall_score 
# 精度
precision_score(y_train_5,y_train_pred) # 0.7687135020350381
# 召回率
recall_score(y_train_5,y_train_pred) #0.801328168234643

F1 score指标

Precision 和 Recall结合到一个称为F1 score 的指标,调和平均值给予低值更多权重。 因此,如果召回和精确度都很高,分类器将获得高F1分数。

代码实现

from sklearn.metrics import f1_score
f1_score(y_train_5,y_train_pred) #0.7846820809248555

阈值

阈值是图像处理中的一个重要概念,它是指将图像转换为二值图像的临界点。在阈值处理中,根据设定的阈值,将图像中的像素值分为两个类别,一类大于阈值,另一类小于阈值。大于阈值的像素被赋予一个固定的值(通常是白色),小于阈值的像素被赋予另一个固定的值(通常是黑色)。这样就可以将图像转换为黑白图像,以突出图像中的目标物体或特定区域。

阈值处理在图像分割、边缘检测、目标检测等领域有广泛的应用。常见的阈值处理方法有全局阈值和自适应阈值。

全局阈值是指将整个图像的像素值与设定的阈值进行比较,根据比较结果将像素分为两类。全局阈值处理适用于图像的整体对比度较好的情况。

自适应阈值是根据图像的局部特性来确定阈值。它将图像分成多个小区域,针对每个小区域计算局部阈值。自适应阈值处理适用于图像的局部对比度不均匀的情况。

全局阈值处理
import cv2# 读取图像
image = cv2.imread('image.jpg', 0)# 设定阈值
threshold_value = 127# 对图像进行全局阈值处理
_, thresholded_image = cv2.threshold(image, threshold_value, 255, cv2.THRESH_BINARY)# 显示结果
cv2.imshow('Thresholded Image', thresholded_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
自适应阈值处理
import cv2# 读取图像
image = cv2.imread('image.jpg', 0)# 对图像进行自适应阈值处理
thresholded_image = cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 2)# 显示结果
cv2.imshow('Thresholded Image', thresholded_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
阈值对结果的影响

Scikit-Learn不允许直接设置阈值,但它可以得到决策分数,调用其decision_function()方法,而不是调用分类器的predict()方法,该方法返回每个实例的分数,然后使用想要的阈值根据这些分数进行预测。

y_scores = sgd_clf.decision_function([X[35000]])
y_scores # array([43349.73739616])
t = 50000
y_pred = (y_scores > t)
y_pred #array([False])y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,method="decision_function")y_scores[:10]
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
y_train_5.shape #(60000,)
thresholds.shape #(59698,)
precisions[:10]
precisions.shape #(59699,)
recalls.shape #(59699,)def plot_precision_recall_vs_threshold(precisions,recalls,thresholds):plt.plot(thresholds,precisions[:-1],"b--",label="Precision")plt.plot(thresholds,recalls[:-1],"g-",label="Recall")plt.xlabel("Threshold",fontsize=16)plt.legend(loc="upper left",fontsize=16)plt.ylim([0,1])plt.figure(figsize=(8, 4))
plot_precision_recall_vs_threshold(precisions,recalls,thresholds)
plt.xlim([-700000, 700000])
plt.show()

ROC曲线

ROC曲线(Receiver Operating Characteristic curve)是一种用于评估二分类模型性能的图形工具。它以虚警率(False Positive Rate)为横轴,命中率(True Positive Rate)为纵轴,绘制出的曲线可以反映出模型在不同阈值下的性能表现。

ROC曲线的横轴是虚警率,表示将负例错误地判定为正例的概率。纵轴是命中率,表示将正例正确地判定为正例的概率。ROC曲线上的每个点代表了在不同阈值下模型的性能表现,而曲线上的每个点都对应着一个不同的阈值。

ROC曲线的形状可以帮助我们评估模型的性能。曲线越靠近左上角,说明模型的性能越好,虚警率较低的同时命中率较高。曲线越接近对角线,说明模型的性能越差,虚警率和命中率的比例相对均衡。

通过比较不同模型的ROC曲线,我们可以选择最佳的模型。通常情况下,我们会选择曲线下面积(Area Under Curve,AUC)较大的模型作为最佳模型,因为AUC值表示了模型在所有阈值下的平均性能。

from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)def plot_roc_curve(fpr, tpr, label=None):plt.plot(fpr, tpr, linewidth=2, label=label)plt.plot([0, 1], [0, 1], 'k--')plt.axis([0, 1, 0, 1])plt.xlabel('False Positive Rate', fontsize=16)plt.ylabel('True Positive Rate', fontsize=16)plt.figure(figsize=(8, 6))
plot_roc_curve(fpr, tpr)
plt.show()

receiver operating characteristic (ROC) 曲线是二元分类中的常用评估方法

  • 它与精确度/召回曲线非常相似,但ROC曲线不是绘制精确度与召回率,而是绘制true positive rate(TPR) 与false positive rate(FPR)

  • 要绘制ROC曲线,首先需要使用roc_curve()函数计算各种阈值的TPR和FPR

TPR = TP / (TP + FN) (Recall)

FPR = FP / (FP + TN)

虚线表示纯随机分类器的ROC曲线; 一个好的分类器尽可能远离该线(朝左上角)。

比较分类器的一种方法是测量曲线下面积(AUC)。完美分类器的ROC AUC等于1,而纯随机分类器的ROC AUC等于0.5。 Scikit-Learn提供了计算ROC AUC的函数

from sklearn.metrics import roc_auc_scoreroc_auc_score(y_train_5, y_scores) #0.9624496555967156

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/237957.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenAI 官方 Prompt 工程指南:写好 Prompt 的六个策略

其实一直有很多人问我,Prompt 要怎么写效果才好,有没有模板。 我每次都会说,能清晰的表达你的想法,才是最重要的,各种技巧都是其次。但是,我还是希望发给他们一些靠谱的文档。 但是,网上各种所…

APEX后台弱密码增强改造出现的问题及解决方法

为了加强APEX后台密码的安全性和可靠性,对其进行弱密码改造,通过改写登录函数,判断密码可靠性,在密码不符合条件(密码长度必须大于8位小于16位,其包含数字、大小写字母与特殊符号)时跳转到密码修…

【Docker】基于华为 openEuler 应用 Docker 镜像体积压缩

书接 openEuler 系列文章(可以翻看测试系列),本次跟大家说说如何将 Java 包轻量化地构建到 openEuler 镜像中且保持镜像内操作系统是全补丁状态。 之前我们都是使用现成的 jdk 镜像进行构建的,如下图: FROM ibm-seme…

智能数字人1688直播软件系统源码有哪些适用的场景?

智能数字人1688直播软件系统源码适用于多个场景,小编给大家列举了一些。 以下是部分代码的示例: 适用场景: 1.电商直播:1688智能数字人直播软件系统源码可以用于电商直播平台,为商家提供智能化的直播服务。数字人主播…

macOS制作dmg包

macOS制作dmg包 准备:磁盘工具、以及要制作的软件,这里以Firefox为例 图片素材 背景图: 找到Firefox,点击显示简介,查看包的大小 打开磁盘工具 文件–>新建映像–>空白映像 填写信息,大小…

nodejs微信小程序+python+PHP个性化书籍推荐系统-计算机毕业设计推荐

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性:…

静态HTTP:构建高效、可扩展的Web应用程序的基础

静态HTTP是Web应用程序的重要组成部分,它为构建高效、可扩展的Web应用程序提供了坚实的基础。下面将详细介绍静态HTTP的优势和在Web应用程序中的作用。 一、静态HTTP的优势 高效性能:静态HTTP内容在服务器上预先生成,然后通过HTTP协议传输到…

CloudPulse:一款针对AWS云环境的SSL证书搜索与分析引擎

关于CloudPulse CloudPulse是一款针对AWS云环境的SSL证书搜索与分析引擎,广大研究人员可以使用该工具简化并增强针对SSL证书数据的检索和分析过程。 在网络侦查阶段,我们往往需要收集与目标相关的信息,并为目标创建一个专用文档&#xff0c…

智慧互联网银行引领金融变革,开源网安VulHunter护航数字化发展

某银行作为国内知名的互联网银行,以构建“智慧型互联行”为总体战略目标,始终坚持科技赋能金融的理念。通过AI、大数据、云计算等数字技术与金融业务的探索融合,实现以更低的成本为客户提供便捷、高效和优质体验的互联网金融服务。 架构升级助…

DBeaver中使用外部格式化程序对进行sql格式化

本文介绍了如何在DBeaver中使用pgFormatter、sqlprase、sqlformatter等外部格式化程序对sql进行格式化。 目录 一、pgFormatter 1.准备工作 2.DBeaver中进行配置 二、sqlprase 1.准备工作 2.在DBeaver中配置 三、sql-formatter 1.准备工作 2.在DBeaver中配置 一、pgF…

防火墙安全策略

目录 一、防火墙种类 二、防火墙流量控制手段 1、包过滤技术(传统) 2、状态检测技术 (1)、状态检测机制 三、安全实验 1、拓扑 2、需求 3、配置思路 4、关键配置截图 5、验证 一、防火墙种类 对于防火墙来说就是针对哪…

选型前必看,CRM系统在线演示为什么重要?

在CRM挑选环节中,假如企业需要深入了解CRM管理系统的功能和功能,就需要CRM厂商提供在线演示。简单的说,就是按照企业的需要,检测怎样通过CRM进行。如今我们来谈谈CRM在线演示的作用。 在线演示 1、了解CRM情况 熟悉系统功能&…

姿态识别、目标检测和跟踪的综合应用

引言: 近年来,随着人工智能技术的不断发展,姿态识别、目标检测和跟踪成为了计算机视觉领域的热门研究方向。这三个技术的综合应用为各个行业带来了巨大的变革和机遇。本文将分别介绍姿态识别、目标检测和跟踪的基本概念和算法,并探…

基于Java开发的微信约拍小程序

一、系统架构 前端:vue | element-ui 后端:springboot | mybatis 环境:jdk8 | mysql8 | maven | mysql 二、代码及数据库 三、功能说明 01. 首页 02. 授权登录 03. 我的 04. 我的-编辑个人资料 05. 我的-我的联系方式 06. …

等待队列头实现阻塞 IO(BIO)

文章目录 等待队列头实现阻塞 IO(BIO)模型等待队列头init_waitqueue_headDECLARE_WAIT_QUEUE_HEAD 等待队列项使用方法驱动程序应用程序模块使用参考 等待队列头实现阻塞 IO(BIO) 等待队列是内核实现阻塞和唤醒的内核机制。 等待队列以循环链表为基础结构,链表头和…

苹果如何从iCloud恢复备份?正确方法看这里!

iCloud为所有苹果用户免费提供5G内存空间,用户可以将照片、短信、联系人、备忘录等重要信息备份到iCloud云端,这样可以方便在不同设备之间同步和共享。 同时,iCloud保证这些数据在所有苹果设备上及时自动更新。当遇到手机数据丢失时&#xf…

构建搜索引擎,而非向量数据库(Vector DB) [译]

原文:Build a search engine, not a vector DB 作者: Panda Smith 在过去 12 个月中,我们见证了向量数据库(Vector DB)创业公司的迅猛增长。我此刻并不打算深入探讨它们各自的设计取舍。相反,我更想探讨和…

做外贸多想一步,多走一步

最近在网上给小儿买了一个液晶画画板,自从告诉小儿已经购物需要耐心等待之后,几乎每天小儿要询问几遍,快递到哪里了? 好不容易盼到了,结果打开一看却是个坏的,虽然外包装是好的,但是明显这个快…

数据库客户案例:每个物种都需要一个数据库!

1、GERDH——花卉多组学数据库 项目名称:GERDH:花卉多组学数据库 链接地址:https://dphdatabase.com 项目描述:GERDH包含了来自150多种园艺花卉植物种质的 12961个观赏植物。将不同花卉植物转录组学、表观组学等数据进行比较&am…

读《文明之光》第四册总结

今天来给大家分享一下【吴军】老师的《文明之光》,该书全套共四册,今天给大家分享的是第四册。 人总是要有些理想和信仰。初读这本书,就被本书的第一句话说感动过。 当人们问起我的理想时,我就给他们讲…