TensorFlow入门(十二、分布式训练)

1、按照并行方式来分

        ①模型并行

                假设我们有n张GPU,不同的GPU被输入相同的数据,运行同一个模型的不同部分。

                在实际训练过程中,如果遇到模型非常庞大,一张GPU不够存储的情况,可以使用模型并行的分布式训练,把模型的不同部分交给不同的GPU负责。这种方式存在一定的弊端:①这种方式需要不同的GPU之间通信,从而产生较大的通信成本。②由于每个GPU上运行的模型部分之间存在一定的依赖,导致规模伸缩性差。

        ②数据并行

                假设我们有n张GPU,不同的GPU被输入不同的数据,运行相同的完整的模型。

                如果遇到一张GPU就能够存下一个模型的情况,可以采用数据并行的方式,这种方式的各部分独立,伸缩性好。

2、按照更新方式来分

        采用数据并行方式时,由于每个GPU负责一部分数据,涉及到如何更新参数的问题,因此分为同步更新和异步更新两种方式。

        ①同步更新

                所有GPU计算完每一个batch(也就是每批次数据)后,再统一计算新权值,等所有GPU同步新值后,再开始进行下一轮计算。

                同步更新的好处是loss的下降比较稳定,但是这个的坏处也很明显,这种方式有等待,处理的速度取决于最慢的那个GPU计算的时间。

        ②异步更新

                每个GPU计算完梯度后,无需等待其他GPU更新,立即更新整体权值并同步。

                异步更新的好处是计算速度快,计算资源能得到充分利用,但是缺点是loss的下降不稳定,抖动大。

3、按照算法来分

        ①Parameter Sever算法

                原理:假设我们有n张GPU,GPU0将数据分成n份分到各张GPU上,每张GPU负责自己那一批次数据的训练,得到梯度后,返回给GPU0上做累计,得到更新的权重参数后,再分发给各张GPU。

        ②Ring AllReduce算法

                原理:假设我们有n张GPU,它们以环形相连,每张GPU都有一个左邻和一个右邻,每张GPU向各自的右邻发送数据,并从它的左邻接近数据。循环n-1次完成梯度积累,再循环n-1次做参数同步。整个算法过程分两个步骤进行:首先是scatter_reduce,然后是allgather。在scatter-reduce,然后是allgather。在scatter-reduce步骤中,GPU将交换数据,使每个GPU可得到最终结果的一个块。在allgather步骤中,gpu将交换这些块,以便所有gpu得到完整的最终结果。

tf.distribute API:

        它是TensorFlow在多GPU、多机器上进行分布式训练用的API。使用这个API,可以在尽可能少改动代码的同时,分布式训练模型。

        它的核心API是tf.distribute.Strategy,只需简单几行代码就可以实现单机多GPU,多机多GPU等情况的分布式训练。

        它的主要优点:

                ①简单易用,开箱即用,高性能

                ②便于各种分布式Strategy切换

                ③支持Custom Training Loop、Estimator、Keras

                ④支持eager excution

tf.distribute.Strategy目前主要有四个Strategy:

        ①MirroredStrategy,即镜像策略

                MirroredStrategy用于单机多GPU、数据并行、同步更新的情况,它会在每个GPU上保存一份模型副本,模型中的每个变量都镜像在所有副本中。这些变量一起形成一个名为MirroredVariable的概念变量。通过apply相同的更新,这些变量保持彼此同步。

                创建一个镜像策略的方法如下:

                        mirrored_strategy = tf.distribute.MirroredStrategy()

                也可以自定义用哪些devices,如:

                        mirrored_strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0","/gpu:1"])

                训练过程中,镜像策略用了高效的All-reduce算法来实现设备之间变量的传递更新。默认情况下它使用NVIDA NCCL (tf.distribute.NcclAllReduce)作为all-reduce算法的实现。通过apply相同的更新,这些变量保持彼此同步。

                官方也提供了其他的一些all-reduce实现方法,可供选择,如:

                        tf.distribute.CrossDeviceOps

                        tf.distribute.HierarchicalCopyAllReduce

                        tf.distribute.ReductionToOneDevice

        ②CentralStorageStrategy,即中心存储策略

                使用该策略时,参数被统一存在CPU里,然后复制到所有GPU上,它的优点是通过这种方式,GPU是负载均衡的,但一般情况下CPU和GPU通信代价比较大。

                创建一个中心存储策略的方法如下:

                             central_storage_strategy = tf.distribute.experimental.CentralStorageStratygy()

        ③MultiWorkerMirroredStrategy,即多端镜像策略

                该API和MirroredStrategy类似,它是其多机多GPU分布式训练的版本。

                创建一个多端镜像策略的方法如下:

                             multiworker_strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

        ④ParameterServerStrategy,即参数服务策略

                简称PS策略,由于计算速度慢和负载不均衡,很少使用这种策略。

                创建一个参数服务策略的方法如下:

                              ps_strategy = tf.distribute.experimental.ParameterServerStrategy()

示例代码如下:

import tensorflow as tf#设置总训练轮数
num_epochs = 5
#设置每轮训练的批大小
batch_size_per_replica = 64
#设置学习率,指定了梯度下降算法中用于更新权重的步长大小
learning_rate = 0.001#创建镜像策略
strategy = tf.distribute.MirroredStrategy()
#通过同步更新时副本的数量计算出本机的GPU设备数量
print("Number of devices: %d"% strategy.num_replicas_in_sync)
#通过副本数量乘以每轮训练的批大小,得出训练总数据量的大小
batch_size = batch_size_per_replica * strategy.num_replicas_in_sync#函数将输入的图片调整为224x224大小,再将像素值除以255进行归一化,同时返回标签信息
def resize(image,label):image = tf.image.resize(image,[224,224])/255.0return image,label#载入数据集并预处理
dataset,_ = tf.keras.datasets.cifar10.load_data()
images,labels = dataset
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.map(resize).shuffle(1024).batch(batch_size)#在strategy.scope下创建模型和优化器
with strategy.scope():#载入了MobileNetV2模型,该模型在ImageNet上预先训练好了,并可以在分类问题上进行微调model = tf.keras.applications.MobileNetV2()#设置训练时用的优化器、损失函数和准确率评测标准model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate),loss = tf.keras.losses.sparse_categorical_crossentropy,metrics = [tf.keras.metrics.sparse_categorical_accuracy])#执行训练过程
model.fit(dataset,epochs = num_epochs)

对于CIFAR-10数据集下载过慢的问题,可以手动去官网下载

https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gzicon-default.png?t=N7T8https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz下载完成后将其放在如下图的路径下,并将数据集文件改名为cifar-10-batches-py.tar.gz并解压

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/103828.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sklearn处理离散变量的问题——以决策树为例

最近做项目遇到的数据集中,有许多高维类别特征。catboost是可以直接指定categorical_columns的【直接进行ordered TS编码】,但是XGboost和随机森林甚至决策树都没有这个接口。但是在学习决策树的时候(无论是ID3、C4.5还是CART)&am…

嵌入式养成计划-40----C++菱形继承--虚继承--多态--模板--异常

九十四、菱形继承 94.1 概念 菱形继承又称为钻石继承,是由公共基类派生出多个中间子类,又由中间子类共同派生出汇聚子类,汇聚子类会得到多份中间子类从公共基类继承下来的数据成员,会造成空间浪费,没有必要。 所以存…

工程师必须记住的电路元件符号及英语翻译

很多电子小白第一次接触印刷电路板(PCB)时,总会头痛那些密密麻麻的元件字母符号,这些电路元件符号基本上都是采用英语缩写,下面我们来看看这些电路元件的英语符号有哪些? 电阻器(Resistor&#…

C++入门指南:类和对象总结友元类笔记(下)

C入门指南:类和对象总结友元类笔记(下) 一、深度剖析构造函数1.1 构造函数体赋值1.2 初始化列表1.3 explicit关键字 二、static成员2.1 概念2.2 特性 三、友元3.1 友元函数3.2 友元类 四、 内部类4.1 概念4.2 特征 五、拷贝对象时的一些编译器优化六、深…

Linux进阶-加深进程印象

目录 进程 进程状态转换 进程状态 启动新进程 system()函数 system.c文件 Makefile文件 执行过程 fork()函数 函数原型 fork.c文件 Makefile文件 执行过程 exec系列函数 函数原型 execl.c文件 Makrfile文件 执行过程 终止进程 exit()函数和_exit()函数 头…

机器人制作开源方案 | 杠杆式6轮爬楼机器人

1. 功能描述 本文示例将实现R281b样机杠杆式6轮爬楼机器人爬楼梯的功能(注意:演示视频中为了增加轮胎的抓地力,在轮胎上贴了双面胶,请大家留意)。 2. 结构说明 杠杆式6轮爬楼机器人是一种专门用于爬升楼梯或不平坦地面…

【elasticsearch】elasticsearch8.0.1使用rpm包安装并启用TLS

背景 公司的业务需要在加密的情况下使用,为此,研究测试了一下es8是如何启用TLS的。以下是测试使用过程。 x-pack了解 在 Elasticsearch 7.11.0 版本及更高版本中,X-Pack 功能在默认情况下已经整合到 Elastic Stack 的各个组件中&#xff0…

Element-UI的使用——表格el-table组件去除边框、滚动条设置、隔行变色、去除鼠标悬停变色效果(基于less)

// Element-ui table表格去掉所有边框,如下: // 备注:若去掉所有边框,可自行将头部边框注释掉即可 // 该样式写在style scoped外面在el-table 中添加class"customer-table"类名 //去掉每行的下边框/deep/ .el-table td.el-table__c…

LCR 161.连续天数最高销售额

​​题目来源: leetcode题目,网址:LCR 161. 连续天数的最高销售额 - 力扣(LeetCode) 解题思路: 动态规划。对于第 i 个元素 sales[i],若以第 i-1 个元素 sales[i-1] 为结尾的最大连续和 f(n-2)…

使用js怎么设置视频背景

要使用JavaScript设置网页的视频背景&#xff0c;你需要将视频元素添加到你的HTML文档中&#xff0c;然后使用JavaScript来控制它 首先&#xff0c;在你的HTML文件中添加一个 <video> 元素 <video id"video-background" autoplay muted loop><sourc…

勒索病毒最新变种.Malloxx勒索病毒来袭,如何恢复受感染的数据?

导言&#xff1a; 勒索病毒已经成为网络威胁的一个突出问题。其中&#xff0c;.Malloxx勒索病毒是一个危险的勒索软件&#xff0c;它能够加密你的数据文件&#xff0c;使其无法访问。本文91数据恢复将向您介绍.Malloxx勒索病毒的特点&#xff0c;以及如何恢复被其加密的数据文…

【MySQL】表的内连和外连

文章目录 一. 内连接二. 外连接1. 左外连接2. 右外连接 一. 内连接 利用where子句对两种表形成的笛卡尔积进行筛选&#xff0c;其实就是内连接的一种方式 另一种方式是inner join select 字段 from 表1 inner join 表2 on 连接条件 and 其他条件现在有如下表 mysql> desc…

M2芯片的Mac上安装Linux虚拟机——提前帮你踩坑

M2芯片的Mac上安装Linux虚拟机——提前帮你踩坑 1. 前言1.1 系统说明1.2 Linux系统选择——提前避坑1.3 下载vmware_fusion1.3.1 官网下载1.3.2 注册 CAPTCHA验证码问题1.3.3 产品说明 1.4 下载操作系统镜像1.4.1 下载centos&#xff08;如果版本合适的&#xff09;1.4.2 下载…

c 几种父进程主动终止子进程的方法

1.如子进程是循环状态 子进程循环等待键盘的输入&#xff0c;如父进程模拟键盘输入一个字符&#xff0c;子进程收到就跳出scanf&#xff08;&#xff09;。同理&#xff0c;如是socket 的accept&#xff08;&#xff09;等待&#xff0c;也可以发送一字符&#xff0c;让子进程…

Excel 自动提取某一列不重复值

IFERROR(INDEX($A$1:$A$14,MATCH(0,COUNTIF($C$1:C1,$A$1:$A$14),0)),"")注意&#xff1a;C1要空置&#xff0c;从C2输入公式 参考&#xff1a; https://blog.csdn.net/STR_Liang/article/details/105182654 https://zhuanlan.zhihu.com/p/55219017?utm_id0

c++视觉处理---直方图均衡化

直方图均衡化 直方图均衡化是一种用于增强图像对比度的图像处理技术。它通过重新分布图像的像素值&#xff0c;以使图像的直方图变得更均匀&#xff0c;从而提高图像的视觉质量。在OpenCV中&#xff0c;您可以使用 cv::equalizeHist 函数来执行直方图均衡化。以下是 cv::equal…

06-Zookeeper选举Leader源码剖析

上一篇&#xff1a;05-Zookeeper典型使用场景实战 一、为什么要看源码 提升技术功底&#xff1a;学习源码里的优秀设计思想&#xff0c;比如一些疑难问题的解决思路&#xff0c;还有一些优秀的设计模式&#xff0c;整体提升自己的技术功底深度掌握技术框架&#xff1a;源码看多…

Jenkins更换主目录

Jenkins储存所有的数据文件在这个目录下. 你可以通过以下几种方式更改&#xff1a; 使用你Web容器的管理工具设置JENKINS_HOME环境参数.在启动Web容器之前设置JENKINS_HOME环境变量.(不推荐)更改Jenkins.war(或者在展开的Web容器)内的web.xml配置文件. 这个值在Jenkins运行时…

ExcelBDD Python指南

在Python里面支持BDD Excel BDD Tool Specification By ExcelBDD Method This tool is to get BDD test data from an excel file, its requirement specification is below The Essential of this approach is obtaining multiple sets of test data, so when combined with…

【【萌新的SOC学习之自定义IP核 AXI4接口】】

萌新的SOC学习之自定义IP核 AXI4接口 自定义IP核-AXI4接口 AXI接口时序 对于一个读数据信号 AXI突发读 不要忘记 最后还有拉高RLAST 表示信号的中止 实验任务 &#xff1a; 通过自定义一个AXI4接口的IP核 &#xff0c;通过AXI_HP接口对PS端 DDR3 进行读写测试 。 S_AXI…