【sklearn练习】模型评估

一、交叉验证 cross_val_score 的使用

1、不用交叉验证的情况:

from __future__ import print_function
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifieriris = load_iris()
X = iris.data
y = iris.target# test train split #
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=4)
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
print(knn.score(X_test, y_test))

输出结果:

0.9736842105263158

2、使用交叉验证

from sklearn.model_selection import cross_val_score
knn2 = KNeighborsClassifier(n_neighbors=5)
scores = cross_val_score(knn2, X, y, cv=5, scoring='accuracy')
print(scores)

输出结果:

[0.96666667 1.         0.93333333 0.96666667 1.        ]

二、确定合适模型参数

1、迭代模型中n_neighbors参数

import matplotlib.pyplot as plt
k_range = range(1, 31)
k_scores = []
for k in k_range:knn = KNeighborsClassifier(n_neighbors=k)
##    loss = -cross_val_score(knn, X, y, cv=10, scoring='mean_squared_error') # for regressionscores = cross_val_score(knn, X, y, cv=10, scoring='accuracy') # for classificationk_scores.append(scores.mean())plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')
plt.show()

画出scores为:

下面是画loss的代码:

k_range = range(1, 31)
k_loss = []
for k in k_range:knn = KNeighborsClassifier(n_neighbors=k)loss = -cross_val_score(knn, X, y, cv=10, scoring='neg_mean_squared_error') # for regression##    scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy') # for classificationk_loss.append(loss.mean())plt.plot(k_range, k_loss)
plt.xlabel('Value of K for KNN')
plt.ylabel('neg_mean_squared_error')
plt.show()

画出loss为:

三、cross_val_score  中的  scoring参数(本标题内容可删,可以是一个链接插入解释这个参数即可)

cross_val_score 函数中的 scoring 参数用于指定评估模型性能的评分指标。评分指标是用来衡量模型预测结果与真实结果之间的匹配程度的方法。在机器学习任务中,选择合适的评分指标对于模型的评估和选择非常重要,因为不同的任务和数据可能需要不同的评估标准。以下是一些常见的评分指标以及它们在 cross_val_score 中的使用方式:

  1. 分类问题的评分指标

    • scoring="accuracy":用于多类分类问题,计算正确分类的样本比例。
    • scoring="precision":计算正类别预测的精确度,即正类别的真正例与所有正类别预测的样本之比。
    • scoring="recall":计算正类别预测的召回率,即正类别的真正例与所有真实正类别的样本之比。
    • scoring="f1":计算 F1 分数,它是精确度和召回率的调和均值,用于综合考虑模型的性能。

    示例使用方法:

    from sklearn.model_selection import cross_val_scorescores_accuracy = cross_val_score(estimator, X, y, cv=5, scoring="accuracy")
    scores_precision = cross_val_score(estimator, X, y, cv=5, scoring="precision")
    scores_recall = cross_val_score(estimator, X, y, cv=5, scoring="recall")
    scores_f1 = cross_val_score(estimator, X, y, cv=5, scoring="f1")
    

  2. 回归问题的评分指标

    • scoring="neg_mean_squared_error":用于回归问题,计算负均方误差(Negative Mean Squared Error),即平均预测值与真实值的平方差。
    • scoring="r2":计算决定系数(R-squared),用于度量模型对目标变量的解释方差程度,取值范围在0到1之间。

    示例使用方法:

    from sklearn.model_selection import cross_val_scorescores_mse = cross_val_score(estimator, X, y, cv=5, scoring="neg_mean_squared_error")
    scores_r2 = cross_val_score(estimator, X, y, cv=5, scoring="r2")
    

  3. 其他评分指标

    • 除了上述常见的评分指标外,还可以使用其他自定义评分函数或指标,例如 AUC、log损失等,只需将评分函数传递给 scoring 参数即可。

    示例使用方法:

    from sklearn.metrics import roc_auc_score
    from sklearn.model_selection import cross_val_scorescoring_function = make_scorer(roc_auc_score)
    scores_auc = cross_val_score(estimator, X, y, cv=5, scoring=scoring_function)
    

根据任务和数据类型,选择适当的评分指标非常重要,它有助于衡量模型的性能,确定模型是否满足预期的要求,并在不同模型之间进行比较和选择。不同的评分指标可以反映模型性能的不同方面,因此需要根据具体情况进行选择。

四、learning_curve函数的使用

1、learning_curve函数功能

learning_curve 是一个用于评估机器学习模型性能的可视化工具。它通常用于了解模型在不同训练数据集大小下的性能变化,以帮助决定是否需要更多的训练数据或模型是否已经过拟合。learning_curve 可以帮助你可视化训练集和验证集上的性能指标,通常是准确性(accuracy)或损失函数(loss)随着训练数据集大小的变化而变化的情况。

在 Python 中,可以使用 sklearn.model_selection.learning_curve 函数来创建学习曲线。

2、例子

代码:

from __future__ import print_function
from sklearn.model_selection import learning_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as npdigits = load_digits()
X = digits.data
y = digits.target
train_sizes, train_loss, test_loss = learning_curve(SVC(gamma=0.001), X, y, cv=10, scoring='neg_mean_squared_error',train_sizes=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)plt.plot(train_sizes, train_loss_mean, 'o-', color="r",label = "Training")
plt.plot(train_sizes, test_loss_mean, 'o-', color="g",label = "Cross-validation")plt.xlabel("Training examples")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

结果为:

当把SVC的参数gamma改为0.01后执行程序得到结果为:

可见,对于训练集,模型更加精确了,损失很少,但对于测试集,损失很大,且随着训练的进行损失不会下降,发生了过拟合,gamma参数的作用为【问GPT】。

五、解决过拟合(validation_curve函数的使用)

1、validation_curve函数功能

validation_curve 函数是 scikit-learn(sklearn)库中的一个工具函数,用于评估模型在不同超参数设置下的性能,并帮助你找到最优的超参数配置。它的主要功能是绘制不同超参数值的模型性能曲线,以便你可以直观地看到模型性能如何随着超参数的变化而变化。

validation_curve 函数通常用于调整模型的超参数,例如正则化参数、决策树深度、学习率等。它帮助你了解不同超参数值对模型性能的影响,以便选择最佳的超参数配置。

以下是 validation_curve 函数的一些关键参数:

  1. estimator:要评估的机器学习模型,通常是一个分类器或回归器的实例

  2. X:特征矩阵,包含输入样本的特征值

  3. y:目标向量,包含对应于输入样本的目标值或标签

  4. param_name:要调整的超参数的名称,例如正则化参数、树的深度等。

  5. param_range:超参数的一组不同取值。validation_curve 将在这些不同的取值上评估模型性能。

  6. scoring:用于评估模型性能的评分指标,例如准确度(accuracy)、均方误差(MSE)、F1 分数等。

  7. cv:交叉验证的折数,用于计算性能的平均值和标准差。

  8. n_jobs:并行计算的数量,用于加速计算。

validation_curve 函数返回一个包含训练得分和验证得分的数组,以及对应于每个超参数值的均值和标准差。这些信息可以用于绘制性能曲线,以便可视化超参数的选择。

2、迭代gamma的值,选择合适的gamma:

代码:

from __future__ import print_function
from sklearn.model_selection import validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as npdigits = load_digits()
X = digits.data
y = digits.target
param_range = np.logspace(-6, -2.3, 5)
train_loss, test_loss = validation_curve(SVC(), X, y, param_name='gamma', param_range=param_range, cv=10,scoring='neg_mean_squared_error')
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)plt.plot(param_range, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(param_range, test_loss_mean, 'o-', color="g",label="Cross-validation")plt.xlabel("gamma")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

结果为:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/615969.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ubuntu系统(9):ubuntu 20.02安装pydot

目录 警告信息 1、确保安装了Python和pip 2、安装Graphviz软件包 3、pip安装pydot 验证 在gem5中,pydot库用于生成图形化输出,特别是生成.dot文件和相关的图像文件,如PDF、PNG等。它与gem5结合使用的一个常见用途是生成系统结构图、内存…

基础篇_面向对象(什么是对象,对象演化,继承,多态,封装,接口,Service,核心类库,异常处理)

文章目录 一. 什么是对象1. 抽取属性2. 字段默认值3. this4. 无参构造5. 抽取行为 二. 对象演化1. 对象字段演化2. 对象方法演化3. 贷款计算器 - 对象改造4. 静态变量5. 四种变量 三. 继承1. 继承语法2. 贷款计算器 - 继承改造3. java 类型系统4. 类型转换1) 基本类型转换2) 包…

【算法分析与设计】最大子数组和

题目 给你一个整数数组 nums ,请你找出一个具有最大和的连续子数组(子数组最少包含一个元素),返回其最大和。 子数组 是数组中的一个连续部分。 示例 示例 1: 输入:nums [-2,1,-3,4,-1,2,1,-5,4] 输出&a…

硬核加码!星邦蓝助力全球运力最大固体火箭“引力一号”海上首飞

继助力我国最大固体运载火箭“力箭一号”首飞后,星邦蓝再次有幸参与和见证了全球运力最大的固体火箭“引力一号”首次成功发射。 今日,全球运力最大的固体火箭“引力一号”从山东海阳附近海域完成首次发射,刷新世界最大固体运载火箭纪录&…

关于鸿蒙的ArkUI的自我理解

先不说好不好上手 一些软件必要的基础概念了解 ①瓦片地图 --无或未找到 ②视频播放功能 --未找到能播放直播流(找到个 ohos/ijkplayer不知如何) ③支付功能 微信无 支付宝的是java代码写得,AskUI中如何调用 ④推送 --自己应该有吧 ⑤长…

【一周安全资讯0106】国家标准《信息安全技术 网络安全信息报送指南》正式发布;全球1100万SSH服务器面临“水龟攻击”威胁

要闻速览 1、国家标准GB/T 43557-2023《信息安全技术 网络安全信息报送指南》发布 2、《未成年人网络保护条例》元旦起施行 织密未成年人网络保护立体“安全网” 3、深圳证监局:证券期货经营机构应建立健全网络安全应急处置机制 4、黑客大规模恶意注册与ChatGPT相似…

全面解析微服务

导读 微服务是企业应用及数据变革升级的利器,也是数字化转型及运营不可或缺的助产工具,企业云原生更离不开微服务,同时云原生的既要最大化发挥微服务的价值,也要最大化弥补微服务的缺陷。本文梳理了微服务基础设施组件、服务网格、…

C++重新认知:拷贝构造函数

一、什么是拷贝构造函数 对于简单变量来说&#xff0c;可以轻松完成拷贝。 int a 10; int b a;但是对于复杂的类对象来说&#xff0c;不仅存在变量成员&#xff0c;也存在各种函数等。因此相同类型的类对象是通过拷贝构造函数来完成复制过程的。 #include<iostream>…

基于 TensorFlow.js 构建垃圾评论检测系统

基于 TensorFlow.js 构建垃圾评论检测系统。 准备工作 在过去的十年中,Web 应用变得越来越具有社交性和互动性,而即使是在中等热门的网站上,也有数万人可能实时对多媒体、评论等的支持。这也让垃圾内容发布者有机会滥用此类系统,将不太令人满意的内容与其他人撰写的文章、视…

小程序必看系列!什么是抖音小程序?抖音小程序怎么制作?

随着移动互联网的飞速发展&#xff0c;抖音已经成为了一个广受欢迎的短视频平台。在这个平台上&#xff0c;用户可以分享自己的生活点滴、表达自己的观点&#xff0c;甚至还能通过小程序来丰富自己的社交体验。那么&#xff0c;如何制作抖音小程序呢&#xff1f; 一、抖音小程…

5288 SDH/PDH数字传输分析仪

5288 SDH/PDH数字传输分析仪 数字通信测量仪器 5288 SDH/PDH数字传输分析仪为高性能手持式数字传输分析仪&#xff0c;符合ITU-T SDH/PDH技术规范和我国光同步传输网技术体制的规定,支持2.048、34.368、139.264Mb/s及155.520Mb/s传输速率的测试。可进行SDH/PDH传输设备和网络的…

云畅科技技术中心被认定为湖南省省级企业技术中心

近日&#xff0c;湖南省工业和信息化厅公布《2023年第二批湖南省省级企业技术中心(第29批)》&#xff0c;云畅科技技术中心作为研发设计型代表入选。 省级企业技术中心是强化企业技术创新主体地位&#xff0c;增强企业自主创新能力&#xff0c;推动工业企业高质量发展的一个重要…

SQL-分组查询

&#x1f389;欢迎您来到我的MySQL基础复习专栏 ☆* o(≧▽≦)o *☆哈喽~我是小小恶斯法克&#x1f379; ✨博客主页&#xff1a;小小恶斯法克的博客 &#x1f388;该系列文章专栏&#xff1a;重拾MySQL &#x1f379;文章作者技术和水平很有限&#xff0c;如果文中出现错误&am…

turnjs实现翻书效果

需求&#xff1a;要做一个效果&#xff0c;类似于阅读器上的翻书效果。 咱们要实现这个需求就需要使用turnjs这个插件&#xff0c;他的官网是turnjs官网。 进入官网后可以点击 这个按钮去下载官网的demo。 这个插件依赖于jQuery&#xff0c;所以你的先安装jQuery. npm insta…

Unity URP下阴影锯齿

1.概述 在Unity开发的URP项目中出现阴影有明显锯齿。如下图所示&#xff1a; 并且在主光源的Shadow Type已经是Soft Shadows模式了。 2.URP Asset 阴影出现锯齿说明阴影质量不高&#xff0c;所以要先找到URP Asset文件进行阴影质量参数的设置。 1.打开PlayerSetting找到Graph…

代码签名证书怎么选择?软件开发者必看

随着互联网的高速发展&#xff0c;各种购物、资讯、社交类软件高速增长。而对于软件开发者来说&#xff0c;选择合适的代码签名证书来为软件进行数字签名、确保软件程序代码的完整性和软件的可信任性是很有必要的。但市场上有多种品牌、多种类型的代码签名证书可以选择&#xf…

03.阿里Java开发手册——OOP规约

【强制】避免通过一个类的对象引用访问此类的静态变量或静态方法&#xff0c;无谓增加编译器解析成本&#xff0c;直接用类名来访问即可。 【强制】所有的覆写方法&#xff0c;必须加Override 注解。 说明&#xff1a;getObject()与 get0bject()的问题。一个是字母的 O&#x…

vue前端开发自学,插槽练习,同时渲染父子组件的数据信息

vue前端开发自学,插槽练习,同时渲染父子组件的数据信息&#xff01; 如果想在slot插槽出口里面&#xff0c;同时渲染出来&#xff0c;来自父组件的数据&#xff0c;和子组件自身的数据呢。又有点绕口了。vue官方给的解决办法是。需要借助于&#xff0c;父组件的自定义属性。 …

第二百五十九回

文章目录 知识回顾示例代码经验总结 我们在上一章回中介绍了MethodChannel的使用方法&#xff0c;本章回中将介绍EventChannel的使用方法.闲话休提&#xff0c;让我们一起Talk Flutter吧。 知识回顾 我们在前面章回中介绍了通道的概念和作用&#xff0c;并且提到了通道有不同的…

本地部署Canal笔记-实现MySQL与ElasticSearch7数据同步

背景 本地搭建canal实现mysql数据到es的简单的数据同步&#xff0c;仅供学习参考 建议首先熟悉一下canal同步方式&#xff1a;https://github.com/alibaba/canal/wiki 前提条件 本地搭建MySQL数据库本地搭建ElasticSearch本地搭建canal-server本地搭建canal-adapter 操作步骤…