政安晨:【Keras机器学习实践要点】(十三)—— 利用 TensorFlow 进行多 GPU 分布式训练

目录

前言

设置

单主机、多设备同步培训

工作原理

如何使用

使用回调确保容错

tf.data 性能提示

数据集批处理注意事项

调用 dataset.cache()

调用 dataset.prefetch(buffer_size)


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文是使用 TensorFlow 对 Keras 模型进行多 GPU 训练的指南。

前言

在多台设备之间分配计算通常有两种方法:

数据并行即在多个设备或多台机器上复制单个模型。它们各自处理不同批次的数据,然后合并结果。这种设置有很多变体,不同的模型副本合并结果的方式不同,它们是在每个批次保持同步,还是更松散地耦合等。

模型并行即一个模型的不同部分在不同设备上运行,同时处理一批数据。这种方法最适用于具有天然并行架构的模型,例如具有多个分支的模型。

本指南侧重于数据并行性,尤其是同步数据并行性,即模型的不同副本在处理每个批次后保持同步。同步性可使模型收敛行为与单设备训练时的收敛行为保持一致。

具体来说,本文将教您如何使用 tf.distribute API 在单台机器上安装的多个 GPU(通常为 2 到 16 个)上对 Keras 模型进行训练,只需对代码进行最小的修改(单主机、多设备训练)。这是研究人员和小规模行业工作流程最常见的配置。


设置

import osos.environ["KERAS_BACKEND"] = "tensorflow"import tensorflow as tf
import keras

单主机、多设备同步培训

在这种设置中,一台机器上有多个 GPU(通常为 2 到 16 个)。每个设备将运行一个模型副本(称为副本)。为简单起见,在下文中,我们将假设使用 8 个 GPU,但这并不影响其通用性。

工作原理

训练的每个阶段

当前批次的数据(称为全局批次)会被分成 8 个不同的子批次(称为局部批次)例如,如果全局批次有 512 个样本,那么 8 个局部批次中的每个批次将有 64 个样本。
8 个副本中的每个副本都会独立处理一个本地批次它们先运行一个前向传递,然后运行一个后向传递,输出权重相对于本地批次上模型损失的梯度。
源于本地梯度的权重更新会在 8 个副本中有效合并由于这是在每一步结束时进行的,因此各副本始终保持同步。

实际上,同步更新模型副本权重的过程是在每个权重变量的层面上进行的。这是通过镜像变量对象完成的。

如何使用

要使用 Keras 模型进行单主机、多设备同步训练,您需要使用 tf.distribute.MirroredStrategy API。下面是其工作原理:

实例化 MirroredStrategy,可选择配置要使用的特定设备(默认情况下,该策略将使用所有可用的 GPU)。
使用该策略对象打开一个作用域,并在该作用域中创建所需的包含变量的所有 Keras 对象。通常,这意味着在分发作用域内创建和编译模型。在某些情况下,对 fit() 的首次调用也可能会创建变量,因此最好也将 fit() 调用放在该作用域中。
像往常一样通过 fit() 训练模型

重要的是,我们建议您使用 tf.data.Dataset 对象在多设备或分布式工作流中加载数据。

从结构上看,是这样的:

# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))# Open a strategy scope.
with strategy.scope():# Everything that creates variables should be under the strategy scope.# In general this is only model construction & `compile()`.model = Model(...)model.compile(...)# Train the model on all available devices.model.fit(train_dataset, validation_data=val_dataset, ...)# Test the model on all available devices.model.evaluate(test_dataset)

下面是一个简单的端到端可运行示例

def get_compiled_model():# Make a simple 2-layer densely-connected neural network.inputs = keras.Input(shape=(784,))x = keras.layers.Dense(256, activation="relu")(inputs)x = keras.layers.Dense(256, activation="relu")(x)outputs = keras.layers.Dense(10)(x)model = keras.Model(inputs, outputs)model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy()],)return modeldef get_dataset():batch_size = 32num_val_samples = 10000# Return the MNIST dataset in the form of a [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset).(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# Preprocess the data (these are Numpy arrays)x_train = x_train.reshape(-1, 784).astype("float32") / 255x_test = x_test.reshape(-1, 784).astype("float32") / 255y_train = y_train.astype("float32")y_test = y_test.astype("float32")# Reserve num_val_samples samples for validationx_val = x_train[-num_val_samples:]y_val = y_train[-num_val_samples:]x_train = x_train[:-num_val_samples]y_train = y_train[:-num_val_samples]return (tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),)# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))# Open a strategy scope.
with strategy.scope():# Everything that creates variables should be under the strategy scope.# In general this is only model construction & `compile()`.model = get_compiled_model()# Train the model on all available devices.train_dataset, val_dataset, test_dataset = get_dataset()model.fit(train_dataset, epochs=2, validation_data=val_dataset)# Test the model on all available devices.model.evaluate(test_dataset)

结果如下: 

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Number of devices: 1
Epoch 1/21563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.3830 - sparse_categorical_accuracy: 0.8884 - val_loss: 0.1361 - val_sparse_categorical_accuracy: 0.9574
Epoch 2/21563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 3ms/step - loss: 0.1068 - sparse_categorical_accuracy: 0.9671 - val_loss: 0.0894 - val_sparse_categorical_accuracy: 0.9724313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0988 - sparse_categorical_accuracy: 0.9673

使用回调确保容错

使用分布式训练时,应始终确保有从故障中恢复的策略(容错)。最简单的处理方法是将 ModelCheckpoint 回调传递给 fit(),以定期保存模型(例如每 100 个批次或每个历元)。然后,您可以从保存的模型重新开始训练。

这里有一个简单的例子:

# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):os.makedirs(checkpoint_dir)def make_or_restore_model():# Either restore the latest model, or create a fresh one# if there is no checkpoint available.checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]if checkpoints:latest_checkpoint = max(checkpoints, key=os.path.getctime)print("Restoring from", latest_checkpoint)return keras.models.load_model(latest_checkpoint)print("Creating a new model")return get_compiled_model()def run_training(epochs=1):# Create a MirroredStrategy.strategy = tf.distribute.MirroredStrategy()# Open a strategy scope and create/restore the modelwith strategy.scope():model = make_or_restore_model()callbacks = [# This callback saves a SavedModel every epoch# We include the current epoch in the folder name.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir + "/ckpt-{epoch}.keras",save_freq="epoch",)]model.fit(train_dataset,epochs=epochs,callbacks=callbacks,validation_data=val_dataset,verbose=2,)# Running the first time creates the model
run_training(epochs=1)# Calling the same function again will resume from where we left off
run_training(epochs=1)

执行结果如下:

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Creating a new model
1563/1563 - 7s - 4ms/step - loss: 0.2275 - sparse_categorical_accuracy: 0.9320 - val_loss: 0.1373 - val_sparse_categorical_accuracy: 0.9571
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Restoring from ./ckpt/ckpt-1.keras
1563/1563 - 6s - 4ms/step - loss: 0.0944 - sparse_categorical_accuracy: 0.9717 - val_loss: 0.0972 - val_sparse_categorical_accuracy: 0.9710

tf.data 性能提示

在进行分布式训练时,加载数据的效率往往至关重要。以下是一些确保 tf.data 管道尽可能快速运行的技巧。

数据集批处理注意事项

创建数据集时,请确保使用全局批处理大小对数据集进行批处理。例如,如果 8 个 GPU 中的每个都能运行 64 个样本的批次,则全局批次大小为 512。

调用 dataset.cache()

如果在数据集上调用 .cache(),数据集的数据将在第一次迭代后被缓存。随后的每次迭代都将使用缓存数据。缓存数据可以是内存中的数据(默认),也可以是你指定的本地文件中的数据。

这可以在以下情况下提高性能

每次迭代时数据不会发生变化
从远程分布式文件系统读取数据
从本地磁盘读取数据,但数据可以放在内存中,而且工作流程对 IO 有很大限制(例如读取和解码图像文件)。

调用 dataset.prefetch(buffer_size)

创建数据集后,几乎总是要调用 .prefetch(buffer_size)。这意味着您的数据管道将与模型异步运行,在当前批次样本用于训练模型时,新样本将被预处理并存储在缓冲区中。当前批次结束时,下一批样本将被预取到 GPU 内存中。


这就是全部内容啦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/786451.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ssm015基于java的健身房管理系统的设计与实现+vue

健身房管理系统设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本健身房管理系统就是在这样的大环境下诞生,其可以帮助管理者在短时间…

MacBook 访达使用技巧【mac 入门】

快捷键 打开访达搜索窗口默认快捷键【⌥ ⌘ 空格键】可以在键盘【系统偏好设置 -> 键盘->快捷键->聚焦】修改 但是我不会去修改它,因为我不常用访达的搜索窗口,更多的是想快速打开访达文件夹窗口,可以通过第三方软件定义访达的快…

利用Node.js实现拉勾网数据爬取

引言 拉勾网作为中国领先的互联网招聘平台,汇集了丰富的职位信息,对于求职者和人力资源专业人士来说是一个宝贵的数据源。通过编写网络爬虫程序,我们可以自动化地收集这些信息,为求职决策和市场研究提供数据支持。Node.js以其非阻…

Qt 实现简易的视频播放器,功能选择视频,播放,暂停,前进,后退,进度条拖拉,视频时长显示

1.效果图 2.代码实现 2.1 .pro文件 QT core gui multimedia multimediawidgets 2.2 .h文件 #ifndef VIDEOPLAYING_H #define VIDEOPLAYING_H#include <QWidget> #include<QFileDialog>#include<QMediaPlayer> #include<QMediaRecorder> #in…

vue3中播放flv流视频,以及组件封装超全

实现以上功能的播放&#xff0c;只需要传入一个流的地址即可&#xff0c;当然组件也只有简单的实时播放功能 下面直接上组件 里面的flvjs通过npm i flv.js直接下载 <template><div class"player" style"position: relative;"><p style&…

深入了解 Vue 3 中的 Keyframes 动画

在本文中&#xff0c;我们将探讨如何在 Vue 3 中实现 Keyframes 动画。Keyframes 动画允许我们通过定义关键帧来创建复杂的动画效果&#xff0c;从而为用户提供更吸引人的界面体验。 transition动画适合用来创建简单的过渡效果。CSS3中支持使用animation属性来配置更加复杂的动…

Day5-

Hive 窗口函数 案例 需求&#xff1a;连续三天登陆的用户数据 步骤&#xff1a; -- 建表 create table logins (username string,log_date string ) row format delimited fields terminated by ; -- 加载数据 load data local inpath /opt/hive_data/login into table log…

蓝桥杯刷题day13——乘飞机【算法赛】

一、问题描述 等待登机的你看着眼前有老有小长长的队伍十分无聊&#xff0c;你突然想要知道&#xff0c;是否存在两个年龄相仿的乘客。每个乘客的年龄用一个 0 到 36500 的整数表示&#xff0c;两个乘客的年龄相差 365 以内就认为是相仿的。 具体来说&#xff0c;你有一个长度…

测开——Java、python、SQL、数据结构面试题整理

一、Java 1.Java中finally、final、finalize的区别 1.性质不同 &#xff08;1&#xff09;final为关键字; &#xff08;2&#xff09;finalize()为方法; &#xff08;3&#xff09;finally为为区块标志,用于try语句中; 2. 作用 &#xff08;1&#xff09;final为用于标识…

农业信息管理(源码+文档)

农业信息管理系统&#xff08;小程序、ios、安卓都可部署&#xff09; 文件包含内容程序简要说明功能项目截图客户端首页我的今日动态动态详情登录修改资料今日价格今日报价注册页 后端管理文章管理用户管理分类管理 文件包含内容 1、搭建视频 2、流程图 3、开题报告 4、数据库…

Python请求示例京东、淘宝商品数据(属性详情,sku价格页面上的数据抓取)

抓取京东、淘宝等电商平台的商品数据是一个复杂且需要谨慎处理的任务&#xff0c;因为这些平台通常会有反爬虫机制&#xff0c;并且页面结构也可能经常变化。以下是一个简化的Python请求示例&#xff0c;展示如何发起HTTP请求来获取页面内容&#xff0c;但这仅作为起点&#xf…

vue2+elementUi的两个el-date-picker日期组件进行联动

vue2elementUi的两个el-date-picker日期组件进行联动 <template><el-form><el-form-item label"起始日期"><el-date-picker v-model"form.startTime" change"startTimeChange" :picker-options"startTimePickerOption…

提升LLM效果的几种简单方法

其实这个文章想写很久了&#xff0c;最近一直在做大模型相关的产品&#xff0c;经过和团队成员一段时间的摸索&#xff0c;对大模型知识库做一下相关的认知和总结。希望最终形成一个系列。 对于知识库问答&#xff0c;现在有两种方案&#xff0c;一种基于llamaindex&#xff0…

Git简介

文章目录 Git是什么&#xff1f;安装Git使用git配置个人标识文件状态状态变化 汇总常用命令 Git是什么&#xff1f; Git是我们的代码管理工具&#xff0c;是一个代码仓库&#xff0c;让我们存储代码的。 安装Git 官网&#xff1a;https://git-scm.com/ &#xff08;傻瓜式安…

数据结构-链表的基本操作

前言&#xff1a; 在dotcpp上碰到了一道题&#xff0c;链接放这了&#xff0c;这道题就是让你自己构建一遍链表的创建&#xff0c;插入节点&#xff0c;删除节点&#xff0c;获取节点&#xff0c;输出链表&#xff0c;题目给了几张代码图&#xff0c;不过不用管那些图&#xf…

STM32一个地址未对齐引起的 HardFault 异常

1. 概述 客户在使用 STM32G070 的时候&#xff0c;KEIL MDK 为编译工具&#xff0c;当编译优化选项设置为Level0 的时候&#xff0c;程序会出现 Hard Fault 异常&#xff0c;而当编译优化选项设置为 Level1 的时候&#xff0c;则程序运行正常。表面上看&#xff0c;这似乎是 K…

ansible-tower安装

特别注意&#xff1a;不需要提前安装ansible&#xff0c;因为ansible tower中的setup.sh脚本会下载对应的ansible版本 ansible tower不支持Ubuntu系统,对cenos系统版本也有一定的限制&#xff0c;建议使用centos7.9。 准备一台全新的机器安装&#xff0c;因为ansible tower需要…

ARMv8-A架构下的外部debug模型(external debug)简介

Armv8-A external debug Armv8-A debug模型一&#xff0c;外部调试 External debug 简介二&#xff0c;Debug state2.1 Debug state的进入与退出 三&#xff0c;DAP&#xff0c;Debug Access Port3.1 EDSCR, External Debug Status and Control Register调试状态标识&#xff0…

Docker搭建LNMP环境实战(07):安装nginx

1、模拟应用场景描述 假设我要搭建一个站点&#xff0c;假设虚拟的域名为&#xff1a;api.test.site&#xff0c;利用docker实现nginxphp-fpmmariadb部署。 2、目录结构 2.1、dockers根目录 由于目前的安装是基于Win10VMWareCentOS虚拟机&#xff0c;同时已经安装了VMWareT…

《2023腾讯云容器和函数计算技术实践精选集》--在 K8s 上跑腾讯云 Serverless 函数,打破传统方式造就新变革

目录 目录 前言 《2023腾讯云容器和函数计算技术实践精选集》带来的思考 1、特色亮点 2、阅读体验 3、实用建议 4、整体评价 Serverless 和 K8s 的优势 1、关于Serverless 函数的特点 2、K8s 的特点 腾讯云 Serverless 函数在 K8s 上的应用对企业服务的影响 案例分…