机器学习模型超参数优化最常用的5个工具包

优化超参数始终是确保模型性能最佳的关键任务。通常,网格搜索、随机搜索和贝叶斯优化等技术是主要使用的方法。

今天分享几个常用于模型超参数优化的 Python 工具包,如下所示:

  • scikit-learn:使用在指定参数值上进行的网格搜索或随机搜索。
  • HyperparameterHunter:构建在scikit-learn之上,以使其更易于使用。
  • Optuna:使用随机搜索、Parzen估计器(TPE)和基于群体的训练。
  • Hyperopt:使用随机搜索和TPE。
  • Talos:构建在Keras之上,以使其更易于使用。

技术交流

技术要学会分享、交流,不建议闭门造车。一个人可以走的很快、一堆人可以走的更远。

本文由粉丝群小伙伴总结与分享,如果你也想学习交流,资料获取,均可加交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、添加微信号:dkl88194,备注:来自CSDN + 加群
方式②、微信搜索公众号:Python学习与数据挖掘,后台回复:加群

现在,让我们看一些使用这些库进行自动编码器模型超参数优化的Python代码示例:

from keras.layers import Input, Dense
from keras.models import Model# define the Autoencoder
input_layer = Input(shape=(784,))
encoded = Dense(32, activation='relu')(input_layer)
decoded = Dense(784, activation='sigmoid')(encoded)
autoencoder = Model(input_layer, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
autoencoder.fit(X_train, X_train, epochs=100, batch_size=256, validation_data=(X_test, X_test))

scikit-learn

from sklearn.model_selection import GridSearchCV# define the parameter values that should be searched
param_grid = {'batch_size': [64, 128, 256], 'epochs': [50, 100, 150]}# create a KFold cross-validator
kfold = KFold(n_splits=10, random_state=7)# create the grid search object
grid = GridSearchCV(estimator=autoencoder, param_grid=param_grid, cv=kfold)# fit the grid search object to the training data
grid_result = grid.fit(X_train, X_train)# print the best parameters and the corresponding score
print(f'Best parameters: {grid_result.best_params_}')
print(f'Best score: {grid_result.best_score_}')

HyperparameterHunter

import HyperparameterHunter as hh# create a HyperparameterHunter object
hunter = hh.HyperparameterHunter(input_data=X_train, output_data=X_train, model_wrapper=hh.ModelWrapper(autoencoder))# define the hyperparameter search space
hunter.setup(objective='val_loss', metric='val_loss', optimization_mode='minimize', max_trials=100)
hunter.add_experiment(parameters=hh.Real(0.1, 1, name='learning_rate', digits=3, rounding=4))
hunter.add_experiment(parameters=hh.Real(0.1, 1, name='decay', digits=3, rounding=4))# perform the hyperparameter search
hunter.hunt(n_jobs=1, gpu_id='0')# print the best hyperparameters and the corresponding score
print(f'Best hyperparameters: {hunter.best_params}')
print(f'Best score: {hunter.best_score}')

Hyperopt

from hyperopt import fmin, tpe, hp# define the parameter space
param_space = {'batch_size': hp.quniform('batch_size', 64, 256, 1), 'epochs': hp.quniform('epochs', 50, 150, 1)}# define the objective function
def objective(params):autoencoder.compile(optimizer='adam', loss='binary_crossentropy')autoencoder.fit(X_train, X_train, batch_size=params['batch_size'], epochs=params['epochs'], verbose=0)scores = autoencoder.evaluate(X_test, X_test, verbose=0)return {'loss': scores, 'status': STATUS_OK}# perform the optimization
best = fmin(objective, param_space, algo=tpe.suggest, max_evals=100)# print the best parameters and the corresponding score
print(f'Best parameters: {best}')
print(f'Best score: {objective(best)}')

Optuna

import optuna# define the objective function
def objective(trial):batch_size = trial.suggest_int('batch_size', 64, 256)epochs = trial.suggest_int('epochs', 50, 150)autoencoder.compile(optimizer='adam', loss='binary_crossentropy')autoencoder.fit(X_train, X_train, batch_size=batch_size, epochs=epochs, verbose=0)score = autoencoder.evaluate(X_test, X_test, verbose=0)return score# create the Optuna study
study = optuna.create_study()# optimize the hyperparameters
study.optimize(objective, n_trials=100)# print the best parameters and the corresponding score
print(f'Best parameters: {study.best_params}')
print(f'Best score: {study.best_value}')

Talos

import talos# define the parameter space
param_space = {'learning_rate': [0.1, 0.01, 0.001], 'decay': [0.1, 0.01, 0.001]}# define the objective function
def objective(params):autoencoder.compile(optimizer='adam', loss='binary_crossentropy', lr=params['learning_rate'], decay=params['decay'])history = autoencoder.fit(X_train, X_train, epochs=100, batch_size=256, validation_data=(X_test, X_test), verbose=0)score = history.history['val_loss'][-1]return score# perform the optimization
best = talos.Scan(X_train, X_train, params=param_space, model=autoencoder, experiment_name='autoencoder').best_params(objective, n_trials=100)# print the best parameters and the corresponding score
print(f'Best parameters: {best}')
print(f'Best score: {objective(best)}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/152679.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode每日一题31

搜索旋转排序数组 那……二分法呗 数组中的数可以相同 比 33. 搜索旋转排序数组 多了一个「有重复元素」,导致无法根据 num > nums[0] 来判断 num 在哪一半,比如 [1,1,1,1,1,2,1,1,1] 旋转数组两头相等,元素 1 可能在左半边可能在右半边 …

vue2 - SuperMap3D加载基于Nginx服务生成的3DTileset模型切片服务地址

文章目录 🍍开发环境🍉1:nginx发布3Dtileset模型切片服务🍍1.1:准备3DTileset文件🍍1.2:安装nginx服务,配置相关文件1.2.1:下载nginx1.2.2:下载完解压文件如下1.2.3:将3Dtileset模型文件放置 nginx-1.24.0/html/gc 新建文件中如下:1.2.4:配置nginx服务🍉2:…

基于Docker的安装和配置Canal

基本介绍 Canal介绍:Canal 是用 Java 开发的基于数据库增量日志解析,提供增量数据订阅&消费的中间件(数据库同步需要阿里的 Otter 中间件,基于 Canal)。 Canal背景:阿里巴巴 B2B 公司,因为…

AI绘画使用Stable Diffusion(SDXL)绘制三星堆风格的图片

一、前言 三星堆文化是一种古老的中国文化,它以其精湛的青铜铸造技术闻名,出土文物中最著名的包括青铜面具、青铜人像、金杖、玉器等。这些文物具有独特的艺术风格,显示了高度的工艺水平和复杂的社会结构。 青铜面具的巨大眼睛和突出的颧骨&a…

【Web】Ctfshow Nodejs刷题记录

目录 ①web334 ②web335 ③web336 ④web337 ⑤web338 ⑥web339 ⑦web340 ⑧web341 ⑨web342-343 ⑩web344 ①web334 进来是一个登录界面 下载附件,简单代码审计 表单传ctfshow 123456即可 ②web335 进来提示 get上传eval参数执行nodejs代码 payload: …

【力扣面试经典150题】(链表)K 个一组翻转链表

题目描述 力扣原文链接 给你链表的头节点 head ,每 k 个节点一组进行翻转,请你返回修改后的链表。 k 是一个正整数,它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍,那么请将最后剩余的节点保持原有顺序。 你不能只…

面试题:设置view点击事件不回调的几种方式和原理

如何设置view 点击事件不回调,如何实现?有什么区别? setEnabled(false) 这个方案用于设置view是否可以响应用户的其他交互事件如触摸,轨迹球等。 setClickable(false) 这个方法用于设置view是否可以响应用户的点击事件。 set…

技术分享 | 如何写好测试用例?

对于软件测试工程师来说,设计测试用例和提交缺陷报告是最基本的职业技能。是非常重要的部分。一个好的测试用例能够指示测试人员如何对软件进行测试。在这篇文章中,我们将介绍测试用例设计常用的几种方法,以及如何编写高效的测试用例。 ## 一…

vue和uni-app的递归组件排坑

有这样一个数组数据,实际可能有很多级。 tree: [{id: 1,name: 1,children: [{ id: 2, name: 1-1, children: [{id: 7, name: 1-1-1,children: []}]},{ id: 3, name: 1-2 }]},{id: 4,name: 2,children: [{ id: 5, name: 2-1 },{ id: 6, name: 2-2 }]} ]要渲染为下面…

java算法学习索引之二叉树问题

一 分别用递归和非递归方式实现二叉树先序、中序和后序遍历 用递归和非递归方式,分别按照二叉树先序、中序和后序打印所有的节点。我们约定:先序遍历顺序为根、左、右;中序遍历顺序为左、根、右;后序遍历顺序为左、右、根。 递…

春秋云境靶场CVE-2022-30887漏洞复现(任意文件上传漏洞)

文章目录 前言一、CVE-2022-30887描述和介绍二、CVE-2021-41402漏洞复现1、信息收集2、找可能可以进行任意php代码执行的地方3、漏洞利用找flag 总结 前言 此文章只用于学习和反思巩固渗透测试知识,禁止用于做非法攻击。注意靶场是可以练习的平台,不能随…

C语言中的指针(上)

目录 一、基本概念 1.变量的存储空间 2.定义指针 3.引用与解引用 二、指针的算术运算、类型以及通用指针 1.指针的算数运算 2.指针类型以及通用型指针 三、指向指针的指针(pointers to pointers) 四、函数传值以及传引用 1.局部变量 2.从存储地…

IOS输入框聚焦会把内容区域顶起

前几天做了一个类似qq布局的h5的聊天界面,输入框固定在最底下。本来初始情况会有默认的两条聊天记录,但是当点击底部的输入框时,输入框聚焦,弹起键盘,然后整个界面就被顶上去了,然后那两条默认的聊天记录也…

gitlab环境准备

1.准备环境 gitlab只支持linux系统,本人在虚拟机下使用Ubuntu作为操作系统,gitlab镜像要使用和操作系统版本对应的版本,(ubuntu18.04,gitlab-ce_13.2.3-ce.0_amd64 .deb) book100ask:/$ lsb_release -a No LSB modules are available. Dist…

机器学习二元分类 二元交叉熵 二元分类例子

二元交叉熵损失函数 深度学习中的二元分类损失函数通常采用二元交叉熵(Binary Cross-Entropy)作为损失函数。 二元交叉熵损失函数的基本公式是: L(y, y_pred) -y * log(y_pred) - (1 - y) * log(1 - y_pred)其中,y是真实标签&…

【C++11】右值引用使用详解

系列文章目录 C11新特性使用详解-持续更新 文章目录 系列文章目录前言一、关联特性1.1 左值/右值 二、使用方法2.1 获得右值引用2.2 对象移动方法2.2.1 移动构造函数/移动赋值运算符2.2.2 标记为noexcept2.2.3 使移动源对象进入是可析构状态 三、使用场景3.1 移动语义3.1 完美…

中贝通信-603220 三季报分析(20231120)

中贝通信-603220 基本情况 公司名称:中贝通信集团股份有限公司 A股简称:中贝通信 成立日期:1999-12-29 上市日期:2018-11-15 所属行业:软件和信息技术服务业 周期性:1 主营业务:通信网络技术服务…

Qt ListWidget

先创建QListWidgetItem: QListWidgetItem* pListItem1 new QListWidgetItem(QIcon(":/resources/editor.png"),u8"editor");QListWidgetItem* pListItem2 new QListWidgetItem(QIcon(":/resources/env.png"),u8"env");Q…

通信网络安全防护定级备案流程介绍(附流程图)

通信网络安全防护定级备案是拥有增值电信业务经营许可证并且有开展电信业务的企业要做的一件事情。刚接触这块的家人们在填报操作的时候可能对具体通信网络安全防护定级备案流程还不是很清楚,所以就给大家画张具体的流程图吧,可以更加直观的了解。 通信…