政安晨:【掌握AI的深度学习工具Keras API】(二)—— 【使用内置的训练循环和评估循环】

渐进式呈现复杂性,是指采用一系列从简单到灵活的工作流程,并逐步提高复杂性。这个原则也适用于模型训练。Keras提供了训练模型的多种工作流程。这些工作流程可以很简单,比如在数据上调用fit(),也可以很高级,比如从头开始编写新的训练算法。

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: 政安晨的机器学习笔记

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!


开始

你已经熟悉compile()、fit()、evaluate()和predict()的工作流程。

咱们看下面的代码,走一下这个流程:

标准工作流程compile()  ->  fit()  ->   evaluate()  ->  predict()

from tensorflow.keras.datasets import mnist# 创建模型(我们将其包装为一个单独的函数,以便后续复用)
def get_mnist_model(): inputs = keras.Input(shape=(28 * 28,))features = layers.Dense(512, activation="relu")(inputs)features = layers.Dropout(0.5)(features)outputs = layers.Dense(10, activation="softmax")(features)model = keras.Model(inputs, outputs)return model# 加载数据,保留一部分数据用于验证
(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((60000, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28)).astype("float32") / 255
train_images, val_images = images[10000:], images[:10000]
train_labels, val_labels = labels[10000:], labels[:10000]model = get_mnist_model()# (本行及以下3行)编译模型,指定模型的优化器、需要最小化的损失函数和需要监控的指标
model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy",metrics=["accuracy"])# (本行及以下3行)使用fit()训练模型,可以选择提供验证数据来监控模型在前所未见的数据上的性能
model.fit(train_images, train_labels,epochs=3,validation_data=(val_images, val_labels))# 使用evaluate()计算模型在新数据上的损失和指标
test_metrics = model.evaluate(test_images, test_labels) # 使用predict()计算模型在新数据上的分类概率
predictions = model.predict(test_images)

要想自定义这个简单的工作流程,可以采用以下方法:

编写自定义指标;

向fit()方法传入调函数,以便在训练过程中的特定时间点采取行动。

咱们接下来进一步讨论。

编写自定义指标

指标是衡量模型性能的关键,尤其是衡量模型在训练数据上的性能与在测试数据上的性能之间的差异。常用的分类指标和回归指标内置于keras.metrics模块中。大多数情况下,你会使用这些指标。但如果想做一些不寻常的工作,你需要能够编写自定义指标

小伙伴们不要怕,这并不难,很简单!

Keras指标是keras.metrics.Metric类的子类。与层相同的是,指标具有一个存储在TensorFlow变量中的内部状态。与层不同的是,这些变量无法通过反向传播进行更新,所以你必须自己编写状态更新逻辑。这一逻辑由update_state()方法实现。

举个例子,如下代码实现了一个简单的自定义指标,用于衡量均方根误差(RMSE)

通过将Metric类子类化来实现自定义指标

import tensorflow as tf# 将Metric类子类化
class RootMeanSquaredError(keras.metrics.Metric):# (本行及以下4行)在构造函数中定义状态变量。与层一样,你可以访问add_weight()方法def __init__(self, name="rmse", **kwargs):super().__init__(name=name, **kwargs)self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")self.total_samples = self.add_weight(name="total_samples", initializer="zeros", dtype="int32")# 为了匹配MNIST模型,我们需要分类预测值与整数标签def update_state(self, y_true, y_pred, sample_weight=None):# 在update_state()中实现状态更新逻辑。y_true参数是一个数据批量对应的目标(或标签),y_pred则表示相应的模型预测值。你可以忽略sample_weight参数,这里不会用到y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1])mse = tf.reduce_sum(tf.square(y_true - y_pred))self.mse_sum.assign_add(mse)num_samples = tf.shape(y_pred)[0]self.total_samples.assign_add(num_samples)

我们可以使用result()方法返回指标的当前值。

    def result(self):return tf.sqrt(self.mse_sum / tf.cast(self.total_samples, tf.float32))

此外,你还需要提供一种方法来重置指标状态,而无须将其重新实例化。

如此一来,相同的指标对象可以在不同的训练轮次中使用,或者在训练和评估中使用。

这可以用reset_state()方法来实现。

    def reset_state(self):self.mse_sum.assign(0.)self.total_samples.assign(0)

自定义指标的用法与内置指标相同。下面来测试一下我们的自定义指标

model = get_mnist_model()model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy", RootMeanSquaredError()])model.fit(train_images, train_labels,epochs=3,validation_data=(val_images, val_labels))
test_metrics = model.evaluate(test_images, test_labels)

你可以看到fit()的进度条,上面显示模型的RMSE。

使用调函数

使用model.fit()在大型数据集上启动数十轮训练,这样做有点类似于投掷纸飞机:最初给它一点推力,之后你就再也无法控制它的轨迹或着陆点。如果想避免得到不好的结果(从而避免浪费纸飞机),更聪明的做法是,不用纸飞机,而用一架无人机。它可以感知环境,向操作者发送数据,并且能够根据当前状态自主航行。

Keras的回调函数(callback)API可以让model.fit()的调用从纸飞机变为自主飞行的无人机,使其能够观察自身状态并不断采取行动。

回调函数是一个对象(实现了特定方法的类实例),它在调用fit()时被传入模型,并在训练过程中的不同时间点被模型调用。

回调函数可以访问关于模型状态与模型性能的所有可用数据,还可以采取以下行动:中断训练、保存模型、加载一组不同的权重或者改变模型状态。

回调函数的一些用法示例如下:

模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前状态。

提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中的最佳模型)。

在训练过程中动态调节某些参数值:比如调节优化器的学习率。

在训练过程中记录训练指标和验证指标,或者将模型学到的表示可视化(这些表示在不断更新):fit()进度条实际上就是一个回调函数。

keras.callbacks模块包含许多内置的回调函数,下面列出了其中一些,还有很多没有列出来:

keras.callbacks.ModelCheckpoint
keras.callbacks.EarlyStopping
keras.callbacks.LearningRateScheduler
keras.callbacks.ReduceLROnPlateau
keras.callbacks.CSVLogger

下面介绍两个回调函数EarlyStopping  ModelCheckpoint,让你大致了解回调函数的用法。

调函数EarlyStopping和ModelCheckpoint

训练模型时,很多事情一开始无法预测,尤其是你无法预测需要多少轮才能达到最佳验证损失。

前面所有例子都采用这样一种策略:训练足够多的轮次,这时模型已经开始过拟合,利用第一次运行确定最佳训练轮数,然后用这个最佳轮数从头开始重新训练一次。当然,这种方法很浪费资源。一种更好的处理方法是,发现验证损失不再改善时,停止训练。这可以通过EarlyStopping回调函数来实现。

如果监控的目标指标在设定的轮数内不再改善,那么可以用EarlyStopping回调函数中断训练。比如,这个回调函数可以在刚开始过拟合时就立即中断训练,从而避免用更少的轮数重新训练模型。这个回调函数通常与ModelCheckpoint结合使用,后者可以在训练过程中不断保存模型(你也可以选择只保存当前最佳模型,即每轮结束后具有最佳性能的模型)。

如下代码展示了如何在fit()方法中使用callbacks参数。

在fit()方法中使用callbacks参数

# 通过fit()的callbacks参数将回调函数传入模型中,该参数接收一个回调函数列表,可以传入任意数量的回调函数
callbacks_list = [# 如果不再改善,则中断训练keras.callbacks.EarlyStopping(# 监控模型的验证精度monitor="val_accuracy",# 如果精度在两轮内都不再改善,则中断训练patience=2,),# 在每轮过后保存当前权重keras.callbacks.ModelCheckpoint(# 模型文件的保存路径filepath="checkpoint_path.keras",# (本行及以下1行)这两个参数的含义是,只有当val_loss改善时,才会覆盖模型文件,这样就可以一直保存训练过程中的最佳模型monitor="val_loss",save_best_only=True,)
]
model = get_mnist_model()
model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",# 监控精度,它应该是模型指标的一部分metrics=["accuracy"])
# (本行及以下3行)因为回调函数要监控验证损失和验证指标,所以在调用fit()时需要传入validation_data(验证数据)
model.fit(train_images, train_labels,epochs=10,callbacks=callbacks_list,validation_data=(val_images, val_labels))

注意,你也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。

要重新加载已保存的模型,只需使用下面这行代码:

model = keras.models.load_model("checkpoint_path.keras")

编写自定义调函数

如果想在训练过程中采取特定行动,而这些行动又没有包含在内置回调函数中,那么你可以编写自定义回调函数。

回调函数的实现方式是将keras.callbacks.Callback类子类化。

然后,你可以实现下列方法(从名称中即可看出这些方法的作用),它们在训练过程中的不同时间点被调用。

# 在每轮开始时被调用
on_epoch_begin(epoch, logs)# 在每轮结束时被调用
on_epoch_end(epoch, logs)# 在处理每个批量之前被调用
on_batch_begin(batch, logs)# 在处理每个批量之后被调用
on_batch_end(batch, logs)# 在训练开始时被调用
on_train_begin(logs)# 在训练结束时被调用
on_train_end(logs)

调用这些方法时,都会用到参数logs。

这个参数是一个字典,它包含前一个批量、前一个轮次或前一次训练的信息,比如训练指标和验证指标等。on_epoch_*方法和on_batch_*方法还将轮次索引或批量索引作为第一个参数(整数)。

如下代码给出了一个简单示例,它在训练过程中保存每个批量损失值组成的列表,还在每轮结束时保存这些损失值组成的图。

如下代码通过对Callback类子类化来创建自定义回调函数

from matplotlib import pyplot as pltclass LossHistory(keras.callbacks.Callback):def on_train_begin(self, logs):self.per_batch_losses = []def on_batch_end(self, batch, logs):self.per_batch_losses.append(logs.get("loss"))def on_epoch_end(self, epoch, logs):plt.clf()plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,label="Training loss for each batch")plt.xlabel(f"Batch (epoch {epoch})")plt.ylabel("Loss")plt.legend()plt.savefig(f"plot_at_epoch_{epoch}")self.per_batch_losses = []

咱们来测试一下

model = get_mnist_model()model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy"])model.fit(train_images, train_labels,epochs=10,callbacks=[LossHistory()],validation_data=(val_images, val_labels))

自定义回调函数LossHistory的输出图像

利用TensorBoard进行监控和可视化

要想做好研究或开发出好的模型,你在实验过程中需要获得丰富且频繁的反馈,从而了解模型内部发生了什么。

这正是运行实验的目的:获取关于模型性能好坏的信息,并且越多越好。

取得进展是一个反复迭代的过程,或者说是一个循环:

首先,你有一个想法,并将其表述为一个实验,用于验证你的想法是否正确;

然后,你运行这个实验并处理生成的信息;

这又激发了你的下一个想法。在这个循环中,重复实验的次数越多,你的想法就会变得越来越精确、越来越强大。

Keras可以帮你尽快将想法转化成实验,高速GPU则可以帮你尽快得到实验结果。但如何处理实验结果呢?这就需要TensorBoard发挥作用了,如下图所示:

TensorBoard是一个基于浏览器的应用程序,可以在本地运行。它是在训练过程中监控模型的最佳方式。利用TensorBoard,你可以做以下工作:

在训练过程中以可视化方式监控指标

将模型架构可视化

将激活函数和梯度的直方图可视化

以三维形式研究嵌入

如果监控除模型最终损失之外的更多信息,则可以更清楚地了解模型做了什么、没做什么,并且能够更快地取得进展。

要将TensorBoard与Keras模型和fit()方法一起使用,最简单的方式就是使用keras.callbacks.TensorBoard回调函数。

在最简单的情况下,只需指定让回调函数写入日志的位置即可。

model = get_mnist_model()model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy"])tensorboard = keras.callbacks.TensorBoard(log_dir="/full_path_to_your_log_dir",
)model.fit(train_images, train_labels,epochs=10,validation_data=(val_images, val_labels),callbacks=[tensorboard])

一旦开始运行,模型就将在目标位置写入日志。

如果在本地计算机上运行Python脚本,那么可以使用下列命令来启动TensorBoard本地服务器。(注意,如果你是通过pip安装TensorFlow的,那么tensorboard可执行文件应该已经可用;如果不可用,你可以通过pip install tensorboard手动安装TensorBoard。

tensorboard --logdir /full_path_to_your_log_dir

然后可以访问该命令返回的URL,以显示TensorBoard界面。

如果在Colab笔记本中运行脚本,则可以使用以下命令,将TensorBoard嵌入式实例作为笔记本的一部分运行。

%load_ext tensorboard
%tensorboard --logdir /full_path_to_your_log_dir

在TensorBoard界面中,你可以实时监控训练指标和评估指标的图像,如下图所示:

TensorBoard可用于监控训练指标和评估指标


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/710578.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp实战:父子组件传参之子组件数量动态变化

需求说明 现有的设置单元列表,每个带有虚线加号的可以看做是一组设置单元,点击加号可以添加一组设置单元.点击设置单元右上角可以删除对应的设置单元. 实现思路说明 利用数组元素添加或是删除的方式实现页面数量动态变化.由于每个设置单元内容都相同所以单独封装了一个子组件.…

高效备考2025年AMC8数学竞赛:2000-2024年AMC8真题练一练

如何提高小学和初中数学成绩?小学和初中可以参加的数学竞赛有哪些?不妨了解一下AMC8美国数学竞赛,现在许多小学生和初中生都在参加这个比赛。如果孩子有兴趣,有余力的话可以系统研究AMC8的历年真题,即使不参加AMC8竞赛…

YOLOv9大幅度按比例减小模型计算量!加快训练!

一、代码及论文链接: 代码链接:GitHub - WongKinYiu/yolov9: Implementation of paper - YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information 论文链接:https://github.com/WongKinYiu/yolov9/tree/main 二…

02| JVM堆中垃圾回收的大致过程

如果一直在创建对象,堆中年轻代中Eden区会逐渐放满,如果Eden放满,会触发minor GC回收,创建对象的时GC Roots,如果存在于里面的对象,则被视为非垃圾对象,不会被此次gc回收,就会被移入…

深度学习500问——Chapter02:机器学习基础(1)

文章目录 前言 2.1 基本概念 2.1.1 大话理解机器学习本质 2.1.2 什么是神经网络 2.1.3 各种常见算法图示 2.1.4 计算图的导数计算 2.1.5 理解局部最优与全局最优 2.1.5 大数据与深度学习之间的关系 2.2 机器学习学习方式 2.2.1 监督学习 2.2.2 非监督式学习 2.2.3 …

TVM 和模型优化的概述(1)

文章目录 1. 从 Tensorflow、PyTorch 或 Onnx 等框架导入模型(model)。2.翻译成 Relay3. lower 到 张量表达式。4. 使用 auto-tuning 模块 AutoTVM 或 AutoScheduler 搜索最佳 schedule。5. 选择最佳配置进行模型编译。6. lower 到 TIR。7. 编译成机器码…

波奇学Linux:共享内存

进程通信的前提:不同的进程看到同一份的资源 直接原理:同一块物理内存映射到不同进程的共享区 共享内存拆解: 1.申请内存,通过页表映射到进程地址空间 2.返回首地址,便于进程利用 3.释放共享内存,去关联 4.内存的申请…

flex的5种常见使用

Flex 布局教程:语法篇 文章目录 一.基本概念二 例子 其实我每次记一个样式标签,都是根据英文来记,但是justify-content和align-items确实让我迷惑,这次我打算只记 justify-content属性定义了项目在主轴上的对齐方式,好好总结一下用法~ 一.基本概念 采用 Flex 布局…

SpringBoot 事务失效及其对应解决办法

简介 本文主要讲述Spring事务会去什么情况下失效及其解决办法 Spring 通过AOP 进行事务控制,如果操作数据库报异常,则会进行回滚;如果没有报异常则会提交事务;但是,如果Spring 事务失效,会导致数据缺失/重…

【STM32】STM32学习笔记-独立看门狗和窗口看门狗(47)

00. 目录 文章目录 00. 目录01. WDG概述02. 独立看门狗相关API2.1 IWDG_WriteAccessCmd2.2 IWDG_SetPrescaler2.3 IWDG_SetReload2.4 IWDG_ReloadCounter2.5 IWDG_Enable2.6 IWDG_GetFlagStatus2.7 RCC_GetFlagStatus 03. 独立看门狗接线图04. 独立看门狗程序示例105. 独立看门…

OD(12)之Mermaid思维导图(Mindmap)

OD(12)之Mermaid思维导图(Mindmap)使用详解 Author: Once Day Date: 2024年2月29日 漫漫长路才刚刚开始… 全系列文章可参考专栏: Mermaid使用指南_Once_day的博客-CSDN博客 参考文章: 关于 Mermaid | Mermaid 中文网 (nodejs.cn)Mermaid | Diagramming and charting tool…

postman传参与返回值切换为左右显示的操作

目录 第一步 点击“Settings”,在下拉框选择“Settings” 第二步 在默认打开的General页面,参照下图改动两处 第一步 点击“Settings”,在下拉框选择“Settings” 第二步 在默认打开的General页面,参照下图改动两处 附上修改后…

opencv中的rgb转gray的计算方法

转换原理 在opencv中,可以使用cv2.cvtColor函数将rgb图像转换为gray图像。示例代码如下, import cv2img_path "image.jpg" image cv2.imread(img_path) gray_image cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) mean gray_image.mean() pri…

【AI Agent系列】【MetaGPT多智能体学习】4. 基于MetaGPT的Team组件开发你的第一个智能体团队

本系列文章跟随《MetaGPT多智能体课程》(https://github.com/datawhalechina/hugging-multi-agent),深入理解并实践多智能体系统的开发。 本文为该课程的第四章(多智能体开发)的第二篇笔记。主要是对MetaGPT中Team组件…

Payment Without Change

题目链接&#xff1a;Problem - 1256A - Codeforces 解题思路&#xff1a; 题目的大致意思就是手中的硬币数拿出若干枚正好等于s&#xff0c;分三种情况 .如果n > s && b < s,输出no .如果b > s,输出yes .如果n * (a < (s / n) ? a : (s / n)) b >…

【iOS ARKit】RealityKit 同步机制

协作 Session 可以很方便地实现多用户之间的AR体验实时共享&#xff0c;但开发者需要自行负责并确保AR场景的完整性&#xff0c;自行负责虚拟物体的创建与销毁。为简化同步操作&#xff0c;RealityKit 内建了同步机制&#xff0c;RealityKit 同步机制基于 Multipeer Connectivi…

Python标准库sys常用函数、方法及代码实战解析【第108篇—标准库sys常用函数】

Python标准库sys常用函数、方法及代码实战解析 在Python的标准库中&#xff0c;sys 模块是一个常用而强大的工具&#xff0c;它提供了与Python解释器交互的函数和变量。本文将介绍sys模块的一些常用函数和方法&#xff0c;并通过实际的代码实例来解析它们的用法。 1. sys.argv…

2024.2.19

1.TCP模型 服务器端 #include <myhead.h> #define SER_IP "192.168.199.129" #define SER_PORT 8899int main(int argc, const char *argv[]) {//1.创建用于连接的套接字文件int sfdsocket(AF_INET,SOCK_STREAM,0);if(sfd-1){perror("socket error"…

react 原理揭秘

1.目标 A. 能够知道setState()更新数据是异步的 B. 能够知道JSX语法的转化过程 C. 能够说出React组件的更新机制 D. 能够对组件进行性能优化 E. 能够说出虚拟DOM和Diff算法 2.目录 A. setState()的说明 B. JSX语法的转化过程 C. 组件更新机制 D. 组件性能优化 E. 虚拟DOM和D…

[Vulnhub]靶场 Web Machine(N7)

kali:192.168.56.104 主机探测: arp-scan -l 靶机ip:192.168.56.104 端口扫描 nmap -p- 192.168.56.106 看一下web 目录扫描 gobuster dir -u http://192.168.56.106 -x html,txt,php,bak,zip --wordlist/usr/share/wordlists/dirbuster/directory-list-2.3-medium.txt exp…