政安晨:【Keras机器学习示例演绎】(四十二)—— 使用 KerasNLP 和 tf.distribute 进行数据并行训练

目录

简介

导入

基本批量大小和学习率

计算按比例分配的批量大小和学习率


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:使用 KerasNLP 和 tf.distribute 进行数据并行训练。

简介


分布式训练是一种在多台设备或机器上同时训练深度学习模型的技术。它有助于缩短训练时间,并允许使用更多数据训练更大的模型。KerasNLP 是一个为自然语言处理任务(包括分布式训练)提供工具和实用程序的库。

在本文中,我们将使用 KerasNLP 在 wikitext-2 数据集(维基百科文章的 200 万字数据集)上训练基于 BERT 的屏蔽语言模型 (MLM)。MLM 任务包括预测句子中的屏蔽词,这有助于模型学习单词的上下文表征。

本指南侧重于数据并行性,尤其是同步数据并行性,即每个加速器(GPU 或 TPU)都拥有一个完整的模型副本,并查看不同批次的部分输入数据。部分梯度在每个设备上计算、汇总,并用于计算全局梯度更新。

具体来说,本文将教您如何在以下两种设置中使用 tf.distribute API 在多个 GPU 上训练 Keras 模型,只需对代码做最小的改动:

—— 在一台机器上安装多个 GPU(通常为 2 至 8 个)(单主机、多设备训练)。这是研究人员和小规模行业工作流程最常见的设置。
—— 在由多台机器组成的集群上,每台机器安装一个或多个 GPU(多设备分布式训练)。这是大规模行业工作流程的良好设置,例如在 20-100 个 GPU 上对十亿字数据集进行高分辨率文本摘要模型训练。

!pip install -q --upgrade keras-nlp
!pip install -q --upgrade keras  # Upgrade to Keras 3.

导入

import osos.environ["KERAS_BACKEND"] = "tensorflow"import tensorflow as tf
import keras
import keras_nlp

在开始任何训练之前,让我们配置一下我们的单 GPU,使其显示为两个逻辑设备。

在使用两个或更多物理 GPU 进行训练时,这完全没有必要。这只是在默认 colab GPU 运行时(只有一个 GPU 可用)上显示真实分布式训练的一个技巧。

!nvidia-smi --query-gpu=memory.total --format=csv,noheader
physical_devices = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(physical_devices[0],[tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),],
)logical_devices = tf.config.list_logical_devices("GPU")
logical_devicesEPOCHS = 3
24576 MiB

要使用 Keras 模型进行单主机、多设备同步训练,您需要使用 tf.distribute.MirroredStrategy API。下面是其工作原理:

—— 实例化 MirroredStrategy,可选择配置要使用的特定设备(默认情况下,该策略将使用所有可用的 GPU)。
—— 使用该策略对象打开一个作用域,并在该作用域中创建所需的包含变量的所有 Keras 对象。通常情况下,这意味着在分发作用域内创建和编译模型。
—— 像往常一样通过 fit() 训练模型。

strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Number of devices: 2

基本批量大小和学习率

base_batch_size = 32
base_learning_rate = 1e-4

计算按比例分配的批量大小和学习率

scaled_batch_size = base_batch_size * strategy.num_replicas_in_sync
scaled_learning_rate = base_learning_rate * strategy.num_replicas_in_sync

现在,我们需要下载并预处理 wikitext-2 数据集。该数据集将用于预训练 BERT 模型。我们将过滤掉短行,以确保数据有足够的语境用于训练。

keras.utils.get_file(origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",extract=True,
)
wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-2/")# Load wikitext-103 and filter out short lines.
wiki_train_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.train.tokens",).filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)
wiki_val_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.valid.tokens").filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)
wiki_test_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.test.tokens").filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)

在上述代码中,我们下载并提取了 wikitext-2 数据集。然后,我们定义了三个数据集:wiki_train_ds、wiki_val_ds 和 wiki_test_ds。我们对这些数据集进行了过滤,以去除短行,并对其进行批处理,以提高训练效率。

在 NLP 训练/调整中,使用衰减学习率是一种常见的做法。在这里,我们将使用多项式衰减时间表(PolynomialDecay schedule)。

total_training_steps = sum(1 for _ in wiki_train_ds.as_numpy_iterator()) * EPOCHS
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=scaled_learning_rate,decay_steps=total_training_steps,end_learning_rate=0.0,
)class PrintLR(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):print(f"\nLearning rate for epoch {epoch + 1} is {model_dist.optimizer.learning_rate.numpy()}")

我们还要回调 TensorBoard,这样就能在本教程后半部分训练模型时可视化不同的指标。我们将所有回调放在一起,如下所示:

callbacks = [tf.keras.callbacks.TensorBoard(log_dir="./logs"),PrintLR(),
]print(tf.config.list_physical_devices("GPU"))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

准备好数据集后,我们现在要在 strategy.scope() 中初始化并编译模型和优化器:

with strategy.scope():# Everything that creates variables should be under the strategy scope.# In general this is only model construction & `compile()`.model_dist = keras_nlp.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")# This line just sets pooled_dense layer as non-trainiable, we do this to avoid# warnings of this layer being unusedmodel_dist.get_layer("bert_backbone").get_layer("pooled_dense").trainable = Falsemodel_dist.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer=tf.keras.optimizers.AdamW(learning_rate=scaled_learning_rate),weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],jit_compile=False,)model_dist.fit(wiki_train_ds, validation_data=wiki_val_ds, epochs=EPOCHS, callbacks=callbacks)
Epoch 1/3
Learning rate for epoch 1 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 43s 136ms/step - loss: 3.7009 - sparse_categorical_accuracy: 0.1499 - val_loss: 1.1509 - val_sparse_categorical_accuracy: 0.3485
Epoch 2/3239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 122ms/step - loss: 2.6094 - sparse_categorical_accuracy: 0.5284
Learning rate for epoch 2 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 133ms/step - loss: 2.6038 - sparse_categorical_accuracy: 0.5274 - val_loss: 0.9812 - val_sparse_categorical_accuracy: 0.4006
Epoch 3/3239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 123ms/step - loss: 2.3564 - sparse_categorical_accuracy: 0.6053
Learning rate for epoch 3 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 134ms/step - loss: 2.3514 - sparse_categorical_accuracy: 0.6040 - val_loss: 0.9213 - val_sparse_categorical_accuracy: 0.4230

根据范围拟合模型后,我们对其进行正常评估!

model_dist.evaluate(wiki_test_ds)
 29/29 ━━━━━━━━━━━━━━━━━━━━ 3s 60ms/step - loss: 1.9197 - sparse_categorical_accuracy: 0.8527[0.9470901489257812, 0.4373602867126465]

对于跨多台计算机的分布式训练(而不是只利用单台计算机上的多个设备进行训练),您可以使用两种分布式策略:MultiWorkerMirroredStrategy 和 ParameterServerStrategy:

—— tf.distribution.MultiWorkerMirroredStrategy(多工作站策略)实现了一种 CPU/GPU 多工作站同步解决方案,可与 Keras 风格的模型构建和训练循环配合使用,并使用跨副本的梯度同步还原。
—— tf.distribution.experimental.ParameterServerStrategy(参数服务器策略)实现了一种异步 CPU/GPU 多工作站解决方案,其中参数存储在参数服务器上,工作站异步更新梯度到参数服务器。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/835484.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Excel日期数字转化成时间格式

1、5位数字转化成yyyy/mm/dd 要考虑闰年的小细节 // 输入数字转成日期(5位,excel表格日期),默认转换成YYYY-MM-DD export function numberToDate(number, format) {if (number ! undefined) {let date new Date((number - 1) *…

IDEA中的常见注解

下面是对每个注解的详细解释: Override:这个注解用于标记一个方法覆盖或实现了父类或接口中的方法。如果一个方法标记为Override,但实际上没有覆盖或实现父类或接口中的方法,编译器会报错。 Deprecated:这个注解用于标…

机器学习求数组的迹

机器学习求数组的迹、也叫求矩阵的迹。 矩阵的迹,也称为迹数,是矩阵主对角线上所有元素的和。矩阵的迹具有以下重要性质:- 不变性:矩阵的迹在转置、加法、乘法等运算下保持不变。- 特征值关系:一个方阵的迹等于其所有特…

微服务全局异常处理

1.使用两个注解RestControllerAdvice 和 Excetionhandler(valueExcetption.class) 2.第一个注解RestcontrollerAdvice用于注解类,RestControllerAdvice可以捕获整个应用程序中抛出的异常,并对它们进行处理。这样可以实现在整个应用程序范围内统一处理异…

高标准农田建设项目天空地一体化智慧监管平台

一、建设背景 党中央、国务院高度重视高标准农田建设。国务院办公厅印发的《关于切实加强高标准农田建设提升国家粮食安全保障能力的意见》 明确提出,大力推进高标准农田建设,到2022年,建成10亿亩高标准农田,以此稳定保障1万亿斤以…

《C语言文件处理:从新手到高手的跃迁》

📃博客主页: 小镇敲码人 💚代码仓库,欢迎访问 🚀 欢迎关注:👍点赞 👂🏽留言 😍收藏 🌏 任尔江湖满血骨,我自踏雪寻梅香。 万千浮云遮碧…

寻找最大价值的矿堆 - 矩阵

系列文章目录 文章目录 系列文章目录前言一、题目描述二、输入描述三、输出描述四、Java代码五、测试用例 前言 本人最近再练习算法,所以会发布一些解题思路,希望大家多指教 一、题目描述 给你一个由’0’(空地)、‘1’(银矿)、‘2’(金矿)组成的地图…

Spring Cloud Gateway 全局过滤器

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 全局过滤器作用于所…

TypeScript:JavaScript的超集

什么是TypeScript? TypeScript是一种由Microsoft开发的开源语言,它在JavaScript的基础上增加了类型系统和编译时的类型检查。TypeScript旨在解决JavaScript在大规模应用开发中遇到的问题,特别是在类型安全性方面。它可以编译成纯JavaScript代…

Visual Studio 安装教程 超级详细 (亲测有效)

1.1 VS2019安装 网址:Visual Studio: 面向软件开发人员和 Teams 的 IDE 和代码编辑器 下载完成之后双击.exe文件 步骤严格如下安装 默认语音包为中文(简体) 安装位置可以自行选择,完成以后就可以点击安装了。 安装完毕以后需要重…

深度探索Java工厂模式:创新与灵活性的结合

在软件设计中,有效地组织对象的创建过程是至关重要的。Java工厂模式是一种优秀的设计模式,它能够在对象创建的过程中提供更大的灵活性和可扩展性。本文将深入探讨工厂模式的不同实现方式,并提供详细的代码示例,以帮助读者更好地理…

docker cuda 宿主机访问docker 内部jupyter notebook

先运行一个容器,并且把宿主机端口映射到jupyter的8888 docker run -it --gpus all -p 9099:8888 --networkmy_network - ubuntu_zzc_0510 1.生成配置文件 jupyter notebook --generate-config 2.修改配置文件 vim ~/.jupyter/jupyter_notebook_config.py c.S…

java技术总结

1.java基本数据类型? byte 1,short 2 ,int 4,long 8 ,float 4,double 8,boolean 1,char 2 2.java为什么要有包装类型? 前 6 个类派生于公共的超类 Number,而 Character 和 Boolean 是 Object 的直接子类。 被 final 修饰, Java 内置的包装类是无法被继承的。 包装…

ubuntu postgresql 安装

在Ubuntu上安装PostgreSQL,你可以按照以下步骤进行: 使用apt包管理器安装 更新系统: 在安装任何软件之前,建议先更新你的操作系统。 sudo apt update sudo apt upgrade 安装PostgreSQL: 使用apt包管理器来安装Postg…

QT 小项目:登录注册账号和忘记密码(下一章实现远程登录)

一、环境搭建 参考上一章环境 二、项目工程目录 三、主要源程序如下: registeraccountwindow.cpp 窗口初始化: void registeraccountWindow::reginit() {//去掉?号this->setWindowFlags(windowFlags() & ~Qt::WindowContextHelpButt…

用标准的GNU/Linux命令替换Alpine上的精简版命令

Alpine Linux 是一个基于 musl libc 和 busybox 的轻量级Linux发行版,busybox 实现了很多常用类Unix命令的精简版,特点是体积很小,舍弃了很多不常用参数,我们简单对比一下标准Linux自带的 date 命令 和 Alpine下默认的 date 命令便…

【联通支付注册/登录安全分析报告】

联通支付注册/登录安全分析报告 前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨…

格雷希尔GripSeal:E10系列低压信号电测试连接器,应用于新能源汽车的DCR测试和EOL测试

新能源车的电驱动、电池包等都有一些信号接口,从几针到几十针不等,而且每种接口都有独特的电性能要求,这些接口在电池包进DCR测试或是EOL测试时,为了满足这些信号接口的需求,我们设计了E10系列信号针快速接头&#xff…

5月10日学习记录

[NCTF2019]True XML cookbook(xxe漏洞利用) 这题是关于xxe漏洞的实际应用,利用xxe漏洞的外部实体来进行ssrf探针内网的主机 和[NCTF2019]Fake XML cookbook的区别就在于xxe漏洞的利用方向,一个是命令执行,一个是SSRF 看题,打开…

Java进阶08 集合(续)Stream流

Java进阶08 集合(续)&Stream流 一、HashSet集合类(续) 1、JDK7(-)HashSet原理解析 1.1 底层结构 数组链表 1.2 执行过程 ①创建一个默认长度为16的数组,数组名为table ②根据元素的哈希值跟数组的长度求余计…