政安晨:【Keras机器学习示例演绎】(四十二)—— 使用 KerasNLP 和 tf.distribute 进行数据并行训练

目录

简介

导入

基本批量大小和学习率

计算按比例分配的批量大小和学习率


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:使用 KerasNLP 和 tf.distribute 进行数据并行训练。

简介


分布式训练是一种在多台设备或机器上同时训练深度学习模型的技术。它有助于缩短训练时间,并允许使用更多数据训练更大的模型。KerasNLP 是一个为自然语言处理任务(包括分布式训练)提供工具和实用程序的库。

在本文中,我们将使用 KerasNLP 在 wikitext-2 数据集(维基百科文章的 200 万字数据集)上训练基于 BERT 的屏蔽语言模型 (MLM)。MLM 任务包括预测句子中的屏蔽词,这有助于模型学习单词的上下文表征。

本指南侧重于数据并行性,尤其是同步数据并行性,即每个加速器(GPU 或 TPU)都拥有一个完整的模型副本,并查看不同批次的部分输入数据。部分梯度在每个设备上计算、汇总,并用于计算全局梯度更新。

具体来说,本文将教您如何在以下两种设置中使用 tf.distribute API 在多个 GPU 上训练 Keras 模型,只需对代码做最小的改动:

—— 在一台机器上安装多个 GPU(通常为 2 至 8 个)(单主机、多设备训练)。这是研究人员和小规模行业工作流程最常见的设置。
—— 在由多台机器组成的集群上,每台机器安装一个或多个 GPU(多设备分布式训练)。这是大规模行业工作流程的良好设置,例如在 20-100 个 GPU 上对十亿字数据集进行高分辨率文本摘要模型训练。

!pip install -q --upgrade keras-nlp
!pip install -q --upgrade keras  # Upgrade to Keras 3.

导入

import osos.environ["KERAS_BACKEND"] = "tensorflow"import tensorflow as tf
import keras
import keras_nlp

在开始任何训练之前,让我们配置一下我们的单 GPU,使其显示为两个逻辑设备。

在使用两个或更多物理 GPU 进行训练时,这完全没有必要。这只是在默认 colab GPU 运行时(只有一个 GPU 可用)上显示真实分布式训练的一个技巧。

!nvidia-smi --query-gpu=memory.total --format=csv,noheader
physical_devices = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(physical_devices[0],[tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),],
)logical_devices = tf.config.list_logical_devices("GPU")
logical_devicesEPOCHS = 3
24576 MiB

要使用 Keras 模型进行单主机、多设备同步训练,您需要使用 tf.distribute.MirroredStrategy API。下面是其工作原理:

—— 实例化 MirroredStrategy,可选择配置要使用的特定设备(默认情况下,该策略将使用所有可用的 GPU)。
—— 使用该策略对象打开一个作用域,并在该作用域中创建所需的包含变量的所有 Keras 对象。通常情况下,这意味着在分发作用域内创建和编译模型。
—— 像往常一样通过 fit() 训练模型。

strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Number of devices: 2

基本批量大小和学习率

base_batch_size = 32
base_learning_rate = 1e-4

计算按比例分配的批量大小和学习率

scaled_batch_size = base_batch_size * strategy.num_replicas_in_sync
scaled_learning_rate = base_learning_rate * strategy.num_replicas_in_sync

现在,我们需要下载并预处理 wikitext-2 数据集。该数据集将用于预训练 BERT 模型。我们将过滤掉短行,以确保数据有足够的语境用于训练。

keras.utils.get_file(origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",extract=True,
)
wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-2/")# Load wikitext-103 and filter out short lines.
wiki_train_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.train.tokens",).filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)
wiki_val_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.valid.tokens").filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)
wiki_test_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.test.tokens").filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)

在上述代码中,我们下载并提取了 wikitext-2 数据集。然后,我们定义了三个数据集:wiki_train_ds、wiki_val_ds 和 wiki_test_ds。我们对这些数据集进行了过滤,以去除短行,并对其进行批处理,以提高训练效率。

在 NLP 训练/调整中,使用衰减学习率是一种常见的做法。在这里,我们将使用多项式衰减时间表(PolynomialDecay schedule)。

total_training_steps = sum(1 for _ in wiki_train_ds.as_numpy_iterator()) * EPOCHS
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=scaled_learning_rate,decay_steps=total_training_steps,end_learning_rate=0.0,
)class PrintLR(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):print(f"\nLearning rate for epoch {epoch + 1} is {model_dist.optimizer.learning_rate.numpy()}")

我们还要回调 TensorBoard,这样就能在本教程后半部分训练模型时可视化不同的指标。我们将所有回调放在一起,如下所示:

callbacks = [tf.keras.callbacks.TensorBoard(log_dir="./logs"),PrintLR(),
]print(tf.config.list_physical_devices("GPU"))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

准备好数据集后,我们现在要在 strategy.scope() 中初始化并编译模型和优化器:

with strategy.scope():# Everything that creates variables should be under the strategy scope.# In general this is only model construction & `compile()`.model_dist = keras_nlp.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")# This line just sets pooled_dense layer as non-trainiable, we do this to avoid# warnings of this layer being unusedmodel_dist.get_layer("bert_backbone").get_layer("pooled_dense").trainable = Falsemodel_dist.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer=tf.keras.optimizers.AdamW(learning_rate=scaled_learning_rate),weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],jit_compile=False,)model_dist.fit(wiki_train_ds, validation_data=wiki_val_ds, epochs=EPOCHS, callbacks=callbacks)
Epoch 1/3
Learning rate for epoch 1 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 43s 136ms/step - loss: 3.7009 - sparse_categorical_accuracy: 0.1499 - val_loss: 1.1509 - val_sparse_categorical_accuracy: 0.3485
Epoch 2/3239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 122ms/step - loss: 2.6094 - sparse_categorical_accuracy: 0.5284
Learning rate for epoch 2 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 133ms/step - loss: 2.6038 - sparse_categorical_accuracy: 0.5274 - val_loss: 0.9812 - val_sparse_categorical_accuracy: 0.4006
Epoch 3/3239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 123ms/step - loss: 2.3564 - sparse_categorical_accuracy: 0.6053
Learning rate for epoch 3 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 134ms/step - loss: 2.3514 - sparse_categorical_accuracy: 0.6040 - val_loss: 0.9213 - val_sparse_categorical_accuracy: 0.4230

根据范围拟合模型后,我们对其进行正常评估!

model_dist.evaluate(wiki_test_ds)
 29/29 ━━━━━━━━━━━━━━━━━━━━ 3s 60ms/step - loss: 1.9197 - sparse_categorical_accuracy: 0.8527[0.9470901489257812, 0.4373602867126465]

对于跨多台计算机的分布式训练(而不是只利用单台计算机上的多个设备进行训练),您可以使用两种分布式策略:MultiWorkerMirroredStrategy 和 ParameterServerStrategy:

—— tf.distribution.MultiWorkerMirroredStrategy(多工作站策略)实现了一种 CPU/GPU 多工作站同步解决方案,可与 Keras 风格的模型构建和训练循环配合使用,并使用跨副本的梯度同步还原。
—— tf.distribution.experimental.ParameterServerStrategy(参数服务器策略)实现了一种异步 CPU/GPU 多工作站解决方案,其中参数存储在参数服务器上,工作站异步更新梯度到参数服务器。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/835484.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习求数组的迹

机器学习求数组的迹、也叫求矩阵的迹。 矩阵的迹,也称为迹数,是矩阵主对角线上所有元素的和。矩阵的迹具有以下重要性质:- 不变性:矩阵的迹在转置、加法、乘法等运算下保持不变。- 特征值关系:一个方阵的迹等于其所有特…

高标准农田建设项目天空地一体化智慧监管平台

一、建设背景 党中央、国务院高度重视高标准农田建设。国务院办公厅印发的《关于切实加强高标准农田建设提升国家粮食安全保障能力的意见》 明确提出,大力推进高标准农田建设,到2022年,建成10亿亩高标准农田,以此稳定保障1万亿斤以…

《C语言文件处理:从新手到高手的跃迁》

📃博客主页: 小镇敲码人 💚代码仓库,欢迎访问 🚀 欢迎关注:👍点赞 👂🏽留言 😍收藏 🌏 任尔江湖满血骨,我自踏雪寻梅香。 万千浮云遮碧…

寻找最大价值的矿堆 - 矩阵

系列文章目录 文章目录 系列文章目录前言一、题目描述二、输入描述三、输出描述四、Java代码五、测试用例 前言 本人最近再练习算法,所以会发布一些解题思路,希望大家多指教 一、题目描述 给你一个由’0’(空地)、‘1’(银矿)、‘2’(金矿)组成的地图…

Spring Cloud Gateway 全局过滤器

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 全局过滤器作用于所…

Visual Studio 安装教程 超级详细 (亲测有效)

1.1 VS2019安装 网址:Visual Studio: 面向软件开发人员和 Teams 的 IDE 和代码编辑器 下载完成之后双击.exe文件 步骤严格如下安装 默认语音包为中文(简体) 安装位置可以自行选择,完成以后就可以点击安装了。 安装完毕以后需要重…

java技术总结

1.java基本数据类型? byte 1,short 2 ,int 4,long 8 ,float 4,double 8,boolean 1,char 2 2.java为什么要有包装类型? 前 6 个类派生于公共的超类 Number,而 Character 和 Boolean 是 Object 的直接子类。 被 final 修饰, Java 内置的包装类是无法被继承的。 包装…

QT 小项目:登录注册账号和忘记密码(下一章实现远程登录)

一、环境搭建 参考上一章环境 二、项目工程目录 三、主要源程序如下: registeraccountwindow.cpp 窗口初始化: void registeraccountWindow::reginit() {//去掉?号this->setWindowFlags(windowFlags() & ~Qt::WindowContextHelpButt…

用标准的GNU/Linux命令替换Alpine上的精简版命令

Alpine Linux 是一个基于 musl libc 和 busybox 的轻量级Linux发行版,busybox 实现了很多常用类Unix命令的精简版,特点是体积很小,舍弃了很多不常用参数,我们简单对比一下标准Linux自带的 date 命令 和 Alpine下默认的 date 命令便…

【联通支付注册/登录安全分析报告】

联通支付注册/登录安全分析报告 前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨…

格雷希尔GripSeal:E10系列低压信号电测试连接器,应用于新能源汽车的DCR测试和EOL测试

新能源车的电驱动、电池包等都有一些信号接口,从几针到几十针不等,而且每种接口都有独特的电性能要求,这些接口在电池包进DCR测试或是EOL测试时,为了满足这些信号接口的需求,我们设计了E10系列信号针快速接头&#xff…

5月10日学习记录

[NCTF2019]True XML cookbook(xxe漏洞利用) 这题是关于xxe漏洞的实际应用,利用xxe漏洞的外部实体来进行ssrf探针内网的主机 和[NCTF2019]Fake XML cookbook的区别就在于xxe漏洞的利用方向,一个是命令执行,一个是SSRF 看题,打开…

Java进阶08 集合(续)Stream流

Java进阶08 集合(续)&Stream流 一、HashSet集合类(续) 1、JDK7(-)HashSet原理解析 1.1 底层结构 数组链表 1.2 执行过程 ①创建一个默认长度为16的数组,数组名为table ②根据元素的哈希值跟数组的长度求余计…

AcwingWeb应用课学习笔记

VSCode自动格式化 选中Format On Save不起作用 在设置中搜索default formatter,修改成Prettier-Code formatter meta标签 HTML 元素表示那些不能由其它 HTML 元相关(meta-related)元素((、,

网络补充笔记

目录 OSI 开放式系统互联参考模型 --- 7层参考模型 UDP:用户数据报文协议 --- 非面向不可靠的传输协议;传输层基本协议,仅完成传输层的基本工作 --- 分段、端口号 TCP:传输控制协议 --- 面向连接的可靠性传输协议 出了完成传输层…

揭秘APP广告变现:自建平台收益倍增秘诀

在数字广告领域,应用(APP)广告变现项目是实现收益的重要途径。随着移动互联网的蓬勃发展,自建平台进行广告投放和收益优化成为了众多开发者和企业关注的焦点。为了确保最大化收益,我们不仅需要对广告市场有深刻的了解&…

高性能运营级流媒体服务框架:支持多协议互转 | 开源日报 No.250

ZLMediaKit/ZLMediaKit Stars: 12.6k License: NOASSERTION ZLMediaKit 是一个基于 C11 的高性能运营级流媒体服务框架。 使用 C11 开发,避免裸指针,代码稳定可靠,性能优越。支持多种协议 (RTSP/RTMP/HLS/HTTP-FLV/WebSocket-FLV/GB28181 等…

武汉星起航助力新手卖家掌握亚马逊政策,开启跨境电商新征程

在数字化浪潮席卷全球的今天,亚马逊平台以其强大的影响力和广阔的市场前景,吸引了越来越多的卖家涌入其中。然而,对于初涉亚马逊市场的新手卖家而言,如何在激烈的市场竞争中立足,并成功开展跨境电商业务,却…

LaTeX公式学习笔记

\sqrt[3]{100} \frac{2}{3} \sum_{i0}^{n} x^{3} \log_{a}{b} \vec{a} \bar{a} \lim_{x \to \infty} \Delta A B C \alpha αΑ\xiξ\XiΞ\beta βΒ\pi π\PiΠ\gamma γ\GammaΓ\varpiϖ\delta δ\DeltaΔ\rhoρΡ\epsilon ϵΕ\varrho ϱ\varepsilo…

MySql数据库基础知识

大家好,在当今软件世界中,软件测试人员肩负着至关重要的职责,确保软件的质量与稳定性。而对于软件测试工作来说,了解 MySQL 基础知识是一项极具价值的技能。MySQL 作为广泛应用的关系型数据库管理系统,在众多软件项目中…