Keras深度学习框架第三十讲:在KerasTuner中处理失败的训练

1、绪论

KerasTuner程序可能会运行很长时间,因为每个模型的训练可能需要很长时间。我们不希望程序仅仅因为一些试验随机失败就停止运行。

本文将讨论在KerasTuner中如何处理失败的训练,包括:

  • 如何在搜索过程中容忍失败的训练
  • 如何在构建和评估模型时将某个训练标记为失败
  • 如何通过抛出FatalError来终止搜索过程

正式讨论使用前,需要进行如下的设置

!pip install keras-tuner -q
import keras
from keras import layers
import keras_tuner
import numpy as np

2、容忍失败的训练

在初始化tuner时,我们将使用max_retries_per_trialmax_consecutive_failed_trials参数。

max_retries_per_trial控制如果一个试验持续失败时,允许重试的最大次数。例如,如果它被设置为3,那么该试验可能会运行4次(1次失败的运行 + 3次失败的重试),之后才最终被标记为失败。max_retries_per_trial的默认值是0。

max_consecutive_failed_trials控制在终止搜索之前,允许连续失败多少个训练(这里的失败试验指的是一个训练在所有的重试中都失败了)。例如,如果它被设置为3,并且试验2、试验3和试验4都失败了,搜索就会终止。但是,如果它被设置为3,但只有试验2、试验3、试验5和试验6失败,搜索不会终止,因为失败的试验不是连续的。max_consecutive_failed_trials的默认值是3。

以下代码展示了这两个参数如何工作。

我们定义了一个搜索空间,包含两个超参数,用于确定两个密集层的单元数。当这两个数的乘积大于800时,我们会为模型太大而引发ValueError

def build_model(hp):# Define the 2 hyperparameters for the units in dense layersunits_1 = hp.Int("units_1", 10, 40, step=10)units_2 = hp.Int("units_2", 10, 30, step=10)# Define the modelmodel = keras.Sequential([layers.Dense(units=units_1, input_shape=(20,)),layers.Dense(units=units_2),layers.Dense(units=1),])model.compile(loss="mse")# Raise an error when the model is too largenum_params = model.count_params()if num_params > 1200:raise ValueError(f"Model too large! It contains {num_params} params.")return model

我们按照以下方式设置tuner。

我们将max_retries_per_trial设置为3。
我们将max_consecutive_failed_trials设置为8。
我们使用GridSearch来枚举所有超参数值的组合。

这样设置后,如果在某个超参数组合下模型训练失败,KerasTuner会为该试验重试最多3次。如果某个试验在所有的重试中都失败了,并且连续有8个这样的失败试验,那么搜索过程将被终止。而GridSearch则会遍历搜索空间中的每个超参数组合,无论它们是否导致试验失败。

tuner = keras_tuner.GridSearch(hypermodel=build_model,objective="val_loss",overwrite=True,max_retries_per_trial=3,max_consecutive_failed_trials=8,
)# Use random data to train the model.
tuner.search(x=np.random.rand(100, 20),y=np.random.rand(100, 1),validation_data=(np.random.rand(100, 20),np.random.rand(100, 1),),epochs=10,
)# Print the results.
tuner.results_summary()

3、将试验标记为失败

当模型过大时,我们不需要重新尝试它。无论使用相同的超参数尝试多少次,模型都会过大。

我们可以将max_retries_per_trial设置为0来实现这一点。但是,这样做的话,无论出现什么错误都不会重试,而我们可能仍然希望对于其他意外的错误进行重试。有没有更好的方式来处理这种情况?

我们可以引发FailedTrialError来跳过重试。每当引发此错误时,试验将不会被重试。当发生其他错误时,重试仍会进行。以下是一个示例:

def build_model(hp):# Define the 2 hyperparameters for the units in dense layersunits_1 = hp.Int("units_1", 10, 40, step=10)units_2 = hp.Int("units_2", 10, 30, step=10)# Define the modelmodel = keras.Sequential([layers.Dense(units=units_1, input_shape=(20,)),layers.Dense(units=units_2),layers.Dense(units=1),])model.compile(loss="mse")# Raise an error when the model is too largenum_params = model.count_params()if num_params > 1200:# When this error is raised, it skips the retries.raise keras_tuner.errors.FailedTrialError(f"Model too large! It contains {num_params} params.")return modeltuner = keras_tuner.GridSearch(hypermodel=build_model,objective="val_loss",overwrite=True,max_retries_per_trial=3,max_consecutive_failed_trials=8,
)# Use random data to train the model.
tuner.search(x=np.random.rand(100, 20),y=np.random.rand(100, 1),validation_data=(np.random.rand(100, 20),np.random.rand(100, 1),),epochs=10,
)# Print the results.
tuner.results_summary()

4、程序化终止搜索

当代码中存在错误时,我们应该立即终止搜索并修复该错误。当满足您定义的条件时,程序员可以程序化地终止搜索。引发FatalError(或其子类FatalValueError、FatalTypeError或FatalRuntimeError)将终止搜索,而不管max_consecutive_failed_trials参数的值如何。

以下是一个当模型过大时终止搜索的示例:

def build_model(hp):# Define the 2 hyperparameters for the units in dense layersunits_1 = hp.Int("units_1", 10, 40, step=10)units_2 = hp.Int("units_2", 10, 30, step=10)# Define the modelmodel = keras.Sequential([layers.Dense(units=units_1, input_shape=(20,)),layers.Dense(units=units_2),layers.Dense(units=1),])model.compile(loss="mse")# Raise an error when the model is too largenum_params = model.count_params()if num_params > 1200:# When this error is raised, the search is terminated.raise keras_tuner.errors.FatalError(f"Model too large! It contains {num_params} params.")return modeltuner = keras_tuner.GridSearch(hypermodel=build_model,objective="val_loss",overwrite=True,max_retries_per_trial=3,max_consecutive_failed_trials=8,
)try:# Use random data to train the model.tuner.search(x=np.random.rand(100, 20),y=np.random.rand(100, 1),validation_data=(np.random.rand(100, 20),np.random.rand(100, 1),),epochs=10,)
except keras_tuner.errors.FatalError:print("The search is terminated.")

5、总结

本文讨论学习了如何在KerasTuner中处理失败的试验:

在KerasTuner中处理失败的试验有几种不同的方法,具体取决于失败的原因和你想要的行为。以下是几种常见的策略:

5.1 设置重试次数 (max_retries_per_trial)

如果你知道某些类型的错误可能是暂时的(例如,由于网络中断或硬件问题),你可以通过设置max_retries_per_trial参数来指定在标记一个试验为失败之前应该重试多少次。这样,即使试验在首次尝试时失败,Tuner也会尝试使用相同的超参数重新运行试验。

tuner = RandomSearch(build_model,objective='val_loss',max_trials=10,executions_per_trial=1,max_retries_per_trial=3,  # 最多重试3次# ... 其他参数 ...
)

5.2 直接标记试验为失败 (FailedTrialError)

如果你能在构建模型或运行试验的过程中检测到某些条件将导致失败(例如,模型大小超出限制),你可以引发FailedTrialError来直接标记该试验为失败。这将导致Tuner跳过该试验的重试,并继续尝试其他超参数组合。

from keras_tuner.engine.trial import FailedTrialErrordef build_model(hp):# ... 省略模型构建代码 ...if some_condition_that_will_fail:raise FailedTrialError('Model configuration will fail due to ...')# ... 省略模型其余部分的构建代码 ...

5.3 立即终止搜索 (FatalError 及其子类)

如果你遇到了一个严重的错误,该错误表明搜索无法继续进行(例如,数据加载错误或模型构建中的根本性错误),你可以引发FatalError或其子类(如FatalValueErrorFatalTypeErrorFatalRuntimeError)来立即终止搜索。这将停止Tuner的所有活动,并允许你修复代码中的错误。

from keras_tuner.engine.trial import FatalErrordef build_model(hp):# ... 省略模型构建代码 ...if some_unrecoverable_error:raise FatalError('Unrecoverable error occurred: ...')# ... 省略模型其余部分的构建代码 ...

5.4 处理异常并继续

有时,你可能想要捕获异常并进行一些处理(例如,记录错误或尝试使用不同的策略),而不是直接标记试验为失败或终止搜索。你可以使用Python的异常处理机制(try/except块)来实现这一点。

def build_model(hp):try:# ... 省略模型构建代码 ...# 这里可能会引发异常except SomeSpecificError as e:# 处理异常,例如记录日志或采取其他措施print(f'Caught an exception: {e}')# 但不引发FailedTrialError或FatalError,以便Tuner可以继续# ... 省略模型其余部分的构建代码 ...

5.5 自定义回调

你还可以使用KerasTuner的回调机制来在试验的不同阶段执行自定义逻辑。例如,你可以在试验开始前、结束后或每个epoch结束时运行自定义函数,以检查试验的状态或执行其他操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/17166.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GPT-4o: 未来的智能助手

GPT-4o: 未来的智能助手 在这个信息爆炸的时代,人工智能(AI)已经成为我们生活中不可或缺的一部分。作为OpenAI最新推出的语言模型,GPT-4o不仅继承了前几代模型的优点,还在多个方面进行了显著的提升。本文将带你深入了解…

DreamerV3阅读笔记

DreamerV3 文章希望解决的一个挑战是用固定的hyperparameter来同时处理不同domain的任务。文章发现,通过结合KL balancing 和free bits可以使得world model learn without tuning(是指上面这件事,即不需要对不同任务改变hyperparameter&#…

2024年电工杯高校数学建模竞赛(B题) 建模解析| 大学生平衡膳食食谱的优化设计

问题重述及方法概述 问题1:膳食食谱的营养分析评价及调整 数学方法:线性规划模型、营养素评价模型、比较分析 可视化数据图:营养素含量表、营养素摄入量对比图、营养素缺乏情况图 问题2:基于附件3的日平衡膳食食谱的优化设计 数…

KingbaseES数据库物理备份还原sys_rman

数据库版本:KingbaseES V008R006C008B0014 简介 sys_rman 是 KingbaseES 数据库中重要的物理备份还原工具,支持不同类型的全量备份、差异备份、增量备份,保证数据库在遇到故障时及时使用 sys_rman 来恢复到数据库先前状态。 文章目录如下 1.…

揭秘爬虫技术:从请求到存储的全方位解析

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、爬虫初探:请求与响应 二、数据解析:从混乱中提炼价值 三、数据…

Megatron-LM源码系列(八): Context Parallel并行

1. Context Parallel并行原理介绍 megatron中的context并行(简称CP)与sequence并行(简称SP)不同点在于,SP只针对Layernorm和Dropout输出的activation在sequence维度上进行切分,CP则是对所有的input输入和所有的输出activation在sequence维度上进行切分&…

M00238-固定翼无人机集群飞行仿真平台MATLAB完整代码含效果

一个小型无人机集群仿真演示平台,使用matlab和simulink搭建。 给出的例子是5架的,当然如果你愿意花时间,也可以把它扩展到10架,20架甚至更多。 输入:5架飞机的规划路径 输出:每架无人机每个时刻的13个状态量…

Docker环境安装并使用Elasticsearch

1、拉取es docker pull elasticsearch:7.10.12、查看镜像 docker images3、启动es docker run -d --name esearch -p 9200:9200 -p 9300:9300 elasticsearch:7.10.14、如果启动ES时出现一下问题 Unable to find image docker.elastic.co/elasticsearch/elasticsearch:7.10.…

python max_min标准化

python max_min标准化 max_min标准化sklearn实现max_min标准化手动实现max_min标准化 max_min标准化 Max-Min标准化(也称为归一化或Min-Max Scaling)是一种将数据缩放到特定范围(通常是0到1)的标准化方法。这种方法通过线性变换将…

用PhpStudy在本地电脑搭建WordPress网站教程(2024版)

对新手来说,明白了建站3要素后,如果直接购买域名、空间去建站,因为不熟练,反复测试主题、框架、插件等费时费力,等网站建成可能要两三个月,白白损失这段时间的建站费用。那么新手怎么建测试网站来练手呢&am…

Python学习:一个简单的登录系统演示了如何使用Python处理JSON数据来管理用户信息

闲来无事,学习一下python AI里搜索:python做一个登录系统json添加删除读取修改 以下是过程和结果: Python学习 环境,window10,Python 3.10.6,pip 24.0 一个简单的登录系统演示了如何使用Python处理JSON数据来管理用户信息 实现登…

06.部署jpress

安装mariadb数据 yum -y install mariadb-server #启动并设置开启自启动 systemctl start mariadb.service systemctl enable mariadb.service数据库准备 [rootweb01 ~]# mysql Welcome to the MariaDB monitor. Commands end with ; or \g. Your MariaDB connection id…

OpenAI 再次刷新认知边界:GPT-4 颠覆语音助手市场,流畅度直逼真人互动?

前言 近日,美国人工智能研究公司 OpenAI 发布了其最新旗舰模型 GPT-4o,这一革命性的进展不仅标志着人工智能领域的新突破,更预示着即将步入一个全新的交互时代?GPT-4o 的发布,对于我们来说,意味着人工智能…

冯喜运:5.28黄金今日走势分析及黄金原油操作策略

【黄金消息面分析】:周一(5月27日)美盘时段,现货黄金止跌回稳,缓慢回升,盘中最高触及2358.4美元。美国商品期货交易委员会(Commodity Futures Trading Commission)的最新交易数据显示,对黄金的投…

数据流的中位数 - LeetCode 热题 76

大家好!我是曾续缘😙 今天是《LeetCode 热题 100》系列 发车第 76 天 堆第 3 题 ❤️点赞 👍 收藏 ⭐再看,养成习惯 数据流的中位数 中位数是有序整数列表中的中间值。如果列表的大小是偶数,则没有中间值,中…

Deploy Tomcat for Centos 7

介绍 Tomcat 是一个免费的开放源代码的Web 应用服务器,属于轻量级应用服务器,在中小型系统和并发访问用户不是很多的场合下被普遍使用,是开发和调试JSP程序的首选。 Tomcat 是Apache软件基金会(Apache Software Foundation&…

空压机的热回收原理介绍

空压机运行时会产生大量的压缩热,通常这部分能量通过机组的风冷或水冷系统释放到大气当中。压缩机的热回收是持续降低空气系统损耗,提高客户生产力的必要手段。 余热回收的节能技术目前研究很多,但大多只针对喷油螺杆式空压机的油路改造而言…

笔试---C++

1.class和struct的默认权限分别是什么? class:private struct:public 2.const和static的作用,说的越多越好 const的了解-CSDN博客 static的了解-CSDN博客 3.c语言中链表 struct node{ int value; struct node * next; } typedef struct node node…

Eureka全面解析:轻松实现高效服务发现与治理!

一、引言 Eureka是Netflix开源的一款服务发现框架,它提供了一种高效的服务注册和发现机制,适用于大规模分布式系统。本文将详细介绍Eureka的相关知识。 二、Eureka简介 Eureka是一个基于REST的服务发现框架,它提供了一种简单的服务注册和发…

如果创办Google

本文是一篇演讲稿,来自于《黑客与画家》一书的作者保罗*格雷厄姆,被称为硅谷创业之父。这是他为14至15岁的孩子们做的一次演讲,内容是关于如果他们将来想创立一家创业公司,现在应该做些什么。很多学校认为应该向学生们传授一些有关…