模型预测笔记(三):通过交叉验证网格搜索机器学习的最优参数

文章目录

    • 网络搜索
      • 介绍
      • 步骤
      • 参数
      • 代码实现

网络搜索

介绍

网格搜索(Grid Search)是一种超参数优化方法,用于选择最佳的模型超参数组合。在机器学习中,超参数是在训练模型之前设置的参数,无法通过模型学习得到。网格搜索通过尝试所有可能的超参数组合,并使用交叉验证来评估每个组合的性能,从而确定最佳的超参数组合。

步骤

网格搜索的步骤如下:

  1. 定义要调整的超参数范围:确定要调整的每个超参数的可能取值范围。例如,学习率、正则化参数等。
  2. 创建参数网格:将每个超参数的可能取值组合成一个参数网格。
  3. 定义评估指标:选择一个评估指标来衡量每个超参数组合的性能。例如,准确率、均方误差等。
  4. 构建模型和交叉验证:选择一个机器学习模型,并定义交叉验证策略,将数据集分成训练集和验证集。
  5. 执行网格搜索:对于每个超参数组合,在交叉验证的每个训练集上训练模型,并在验证集上评估模型性能。
  6. 选择最佳超参数组合:根据评估指标的结果,选择具有最佳性能的超参数组合。
  7. 用最佳超参数训练模型:使用最佳超参数组合在整个训练数据集上重新训练模型。

网格搜索的优点是能够系统地尝试不同的超参数组合,找到最佳的模型性能。然而,由于需要尝试所有可能的组合,网格搜索的计算成本较高,尤其是超参数的数量较多时。因此,对于大型数据集和复杂模型,网格搜索可能会变得非常耗时。

为了减少计算成本,可以使用随机搜索(Randomized Search)等其他超参数优化方法,或者使用启发式方法来选择最佳超参数组合。

参数

GridSearchCV的参数包括:

  • estimator:要使用的模型或者估计器对象。
  • param_grid:一个字典或者列表,包含要进行网格搜索的参数和对应的取值范围。
  • scoring:评估模型性能的指标,可以是字符串(使用模型的内置评估指标)或者可调用对象(自定义评估指标)。
  • cv:交叉验证的折数或者交叉验证迭代器。
  • n_jobs:并行运行的作业数量。-1表示使用所有可用的处理器。
  • verbose:控制详细程度的整数值。0表示不输出任何信息,大于1表示输出详细的信息。
  • refit:如果为True(默认值),则在找到最佳参数后,使用最佳参数重新拟合整个数据集。
  • return_train_score:如果为True,则同时返回训练集上的得分。
  • error_score:当模型在某些参数组合下发生错误时,用于返回的分数。可以设置为’raise’(抛出错误)或者数字(返回指定的分数)。
  • verbose:控制详细程度的整数值。0表示不输出任何信息,大于1表示输出详细的信息。

注意:

在GridSearchCV中,scoring参数可以选择以下评分指标:

回归问题:

  • ‘explained_variance’:可解释方差
  • ‘neg_mean_absolute_error’:负平均绝对误差
  • ‘neg_mean_squared_error’:负均方误差
  • ‘neg_mean_squared_log_error’:负对数均方误差
  • ‘neg_median_absolute_error’:负中位数绝对误差
  • ‘r2’:R^2决定系数

二分类问题:

  • ‘accuracy’:准确率
  • ‘balanced_accuracy’:平衡准确率
  • ‘average_precision’:平均精确率
  • ‘f1’:F1得分
  • ‘precision’:精确率
  • ‘recall’:召回率
  • ‘roc_auc’:ROC曲线下的面积
    多分类问题:
  • ‘accuracy’:准确率
  • ‘balanced_accuracy’:平衡准确率
  • ‘average_precision’:平均精确率
  • ‘f1_micro’:微观平均F1得分
  • ‘f1_macro’:宏观平均F1得分
  • ‘precision_micro’:微观平均精确率
  • ‘precision_macro’:宏观平均精确率
  • ‘recall_micro’:微观平均召回率
  • ‘recall_macro’:宏观平均召回率
  • ‘roc_auc_ovr’:基于一对多的ROC曲线下的面积

请注意,不同问题类型和评估指标之间的兼容性可能会有所不同。

5折交叉验证就是把数据集分成5份,然后进行5此测试,如model1就是将第一折fold1的数据作为测试集,其余的四份作为数据集。最后每个model都计算出来一个准确度accuracy,求平均后作为此验证集的精确度。

代码实现

#调用网格搜索和决策树
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, roc_curve, auc
parameters = {'max_depth':[3, 5, 7, 9], 'min_samples_leaf': [1, 2, 3, 4]}# 选择两个超参数 树的深度max_depth和叶子的最小值min_samples_leafclf = GridSearchCV(DecisionTreeClassifier(), parameters, cv=3, scoring='accuracy')# 进行网格搜索得到最优参数组合
clf.fit(X_train, y_train) #通过有最优参数组合的最优模型进行训练print('最优参数:', clf.best_params_)
print('验证集最高得分:', clf.best_score_)
# 获取最优模型
best_model = clf.best_estimator_
print('测试集上准确率:', best_model.score(X_test, y_test))# 得到预测概率
y_prob_DT = clf.predict_proba(X_test)[:, 1]# 得到预测标签
y_pred_DT = clf.predict(X_test)# 得到分类报告
print(classification_report(y_pred = y_pred_DT, y_true = y_test))# 绘制ROC图
fpr, tpr, threshold = roc_curve(y_score = y_prob_DT, y_true = y_test)
print("AUC值", auc(fpr, tpr))
plt.plot(fpr, tpr,"r-")
plt.plot([0, 1], [0, 1],"b-")
plt.xlable("FPR")
plt.ylable("TPR")
plt.title("ROC Curve")# 输出结果文件
result = pd.DataFrame()
result["load_ID"] = pd.read_csv("***.csv")["**ID"]
result["predict_labels"] = y_pred_DT
result.to_csv("result.csv", index = False)# 特征重要性评估
best_DT = clf.best_estimator_
best_DT.fit(X_train, y_train)# 重要性绘制
plt.figure(figsize(8, 6))
pd.Series(best_DT.feature_importances_, index=X_train.columns).sort_values().plot(kind="barh")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/60995.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

gif怎么转换成mp4格式视频

gif怎么转换成mp4格式视频?GIF格式是一种广泛应用的公用图像文件格式标准,具有许多优势。它占用的内存较小,可以实现自动循环播放,并且兼容多个平台。然而,GIF格式也存在一些缺点。例如,它无法处理复杂的图…

laravel excel导入导出

一、安装第三方 composer require maatwebsite/excel版本2.1和现在版本 有所不一样 二、导入 <?php namespace App\Import; use Maatwebsite\Excel\Concerns\ToCollection;class TestImport implements ToCollection {public function __construct(){}public function c…

如何提高工业网关的数据传输速度?

工业网关是工业物联网系统中不可或缺的设备&#xff0c;提高工业网关的数据采集、传输速度&#xff0c;是保障和优化物联网系统运营效率的基础。如何提高工业物联网关的数据传输速度&#xff1f;本篇就为大家简单介绍一下。 1、选用高品质网络设备 选用具有足够带宽容量的高质…

vue的公共方法封装以及class高阶封装

一、Vue.use与Vue.prototype的区别和用法 1、Vue.use和Vue.prototype区别 相同点&#xff1a;都是注册插件的方式&#xff0c;没有本质区别&#xff0c;都是在vue.prototype上添加了一个方法不同点&#xff1a;vue.use适用于注册vue生态内的插件(vuex、router、elementUI)&…

MySQL提权

参考&#xff1a; mysql提权篇 | Wh0ales Blog MySQL 提权方法整理 - Geekbys Blog MySQL_UDF提权漏洞复现-云社区-华为云 MYSQL UDF手动提权及自动化工具使用_udf提权工具_小直789的博客-CSDN博客 MySQL提权的三种方法 - FreeBuf网络安全行业门户 ...

hive问题总结

往往用了很久的函数却只知道其单一的应用场景&#xff0c;本文将不断完善所遇到的好用的hive内置函数。 1.聚合函数或者求最大最小值函数搭配开窗函数使用可以实现滑动窗口 例&#xff1a; collect_list函数&#xff0c;搭配开窗函数&#xff0c;实现了在滑动窗口内对事件路径…

idea 常用插件和常用快捷键 - 记录

idea 常用插件 记得下载插件完成后&#xff0c;点击 Apply 和 OK Alibaba Java Coding Guidelines 作用&#xff1a;使用该插件可以&#xff0c;自动提示相关的语法格式问题&#xff0c;格式参考 阿里巴巴代码规范 详情链接&#xff1a; 代码规范之Alibaba Java Coding G…

js深拷贝三种方法

使用递归函数实现深拷贝 const obj {name: zzz,age: 18,hobby: [篮球, 足球],family: {baby: baby}} // 深拷贝 数组 对象 一定要先筛数组再筛对象,因为万物皆对象function deepcopy(newObj, oldObj) {for (const k in oldObj) {// 判断值是否属于array类if (oldObj[k] i…

深度学习怎么学?

推荐这本小白看的《深度学习&#xff1a;从基础到实践&#xff08;上下册&#xff09;》。 深度学习&#xff1a;从基础到实践&#xff08;上下册&#xff09; 深入浅出的讲述了深度学习的基本概念与理论知识&#xff0c;不涉及复杂的数学内容&#xff0c;零基础小白也能轻松掌…

在应用条上共享内容

动作提供者是一个动作&#xff0c;能定义自己的外观和行为&#xff0c;下面是为应用条增加一个动作提供者。 共享动作提供者允许用户与其他应用共享应用中的内容。可以使用动作提供者让用户向他们某个联系人发送一个特色披萨的详细信息。 共享动作提供者会定义自己的图标&#…

Scala中的类型检查和转换,以及泛型,scala泛型的协变和逆变

Scala中的类型检查和转换&#xff0c;以及泛型 类型检查和转换 说明 &#xff08;1&#xff09; obj.isInstanceOf[T]&#xff1a;判断 obj 是不是T 类型。 &#xff08;2&#xff09; obj.asInstanceOf[T]&#xff1a;将 obj 强转成 T 类型。 &#xff08;3&#xff09; cla…

2023-8-31 Dijkstra求最短路(二)

题目链接&#xff1a;Dijkstra求最短路 II #include <iostream> #include <cstring> #include <algorithm> #include <vector> #include <queue>using namespace std;typedef pair<int, int> PII;const int N 150010;int n, m; int h[N…

解决 PaddleClas 下载预训练模型报错 ModuleNotFoundError No module named ‘ppcls‘ 的问题

当我们在使用 PaddleClas 进行预训练模型下载时&#xff0c;可能会遇到一个报错&#xff0c;报错信息为 ModuleNotFoundError: No module named ppcls。这个错误通常是因为 Python 解释器无法找到名为 ppcls 的模块&#xff0c;而我们的代码中正尝试导入它。让我们一起来解决这…

HTTP 代理原理及 Python 简单实现

HTTP 代理是一种网络代理服务器(Proxy Server),它能够作为客户端与 HTTP 服务器之间的中介,它的工作原理是: 当客户端向 HTTP 代理发送 HTTP 请求时,HTTP 代理会收到请求。 HTTP 代理会将请求转发给目标 HTTP 服务器。 目标 HTTP 服务器处理请求并生成响应。 HTTP 代理将…

什么是操作系统,数据结构

1、什么是操作系统&#xff1f; 操作系统是一组主管并控制计算机操作、运用和运行硬件、软件资源和提供公共服务来组织用户交互的相互关联的系统软件程序。根据运行的环境&#xff0c;操作系统可以分为桌面操作系统&#xff0c;手机操作系统&#xff0c;服务器操作系统&#x…

QT Creator工具介绍及使用

一、QT的基本概念 QT主要用于图形化界面的开发&#xff0c; QT是基于C编写的一套界面相关的类库&#xff0c;如进程线程库&#xff0c;网络编程的库&#xff0c;数据库操作的库&#xff0c;文件操作的库等。 如何使用这个类库&#xff1a;类库实例化对象(构造函数) --> 学习…

数据结构(Java实现)-二叉树(上)

树型结构 树是一种非线性的数据结构&#xff0c;它是由n&#xff08;n>0&#xff09;个有限结点组成一个具有层次关系的集合。把它叫做树是因为它看起来像一棵倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;而叶朝下的。 有一个特殊的结点&#xff0c;称为根结点&…

Docker搭建elasticsearch+kibana测试

最近需要做大数据画像&#xff0c;所以先简单搭建一个eskibana学习使用&#xff0c;记录一下搭建过程和遇到的问题以及解决办法 1.拉取es和kibana镜像 在拉取镜像之前先搜索一下 elasticsearch发现是存在elasticsearch镜像的&#xff0c;我一般习惯性拉取最新镜像&#xff0c…

信息化发展12

数字民生 数字民生建设重点通常强调&#xff1a; 1 &#xff09; 普惠&#xff1a; 充分开发利用信息技术体系&#xff0c; 扩大民生保障覆盖范围&#xff0c; 助力普惠型民生建设&#xff0c; 解决民生资源配置不均衡等问题。 2&#xff09; 赋能&#xff1a; 信息技术体系与…

若依富文本 html样式 被过滤问题

一.场景 进入页面&#xff0c;富文本编辑框里回显这条新闻内容&#xff0c;如下图&#xff0c; 然后可以在富文本编辑框里对它实现再编辑&#xff0c;编辑之后将html代码提交保存到后台数据库。可以点击详情页进行查看。 出现问题&#xff1a;在提交到后台controller时&#x…