TensorFlow入门(十二、分布式训练)

1、按照并行方式来分

        ①模型并行

                假设我们有n张GPU,不同的GPU被输入相同的数据,运行同一个模型的不同部分。

                在实际训练过程中,如果遇到模型非常庞大,一张GPU不够存储的情况,可以使用模型并行的分布式训练,把模型的不同部分交给不同的GPU负责。这种方式存在一定的弊端:①这种方式需要不同的GPU之间通信,从而产生较大的通信成本。②由于每个GPU上运行的模型部分之间存在一定的依赖,导致规模伸缩性差。

        ②数据并行

                假设我们有n张GPU,不同的GPU被输入不同的数据,运行相同的完整的模型。

                如果遇到一张GPU就能够存下一个模型的情况,可以采用数据并行的方式,这种方式的各部分独立,伸缩性好。

2、按照更新方式来分

        采用数据并行方式时,由于每个GPU负责一部分数据,涉及到如何更新参数的问题,因此分为同步更新和异步更新两种方式。

        ①同步更新

                所有GPU计算完每一个batch(也就是每批次数据)后,再统一计算新权值,等所有GPU同步新值后,再开始进行下一轮计算。

                同步更新的好处是loss的下降比较稳定,但是这个的坏处也很明显,这种方式有等待,处理的速度取决于最慢的那个GPU计算的时间。

        ②异步更新

                每个GPU计算完梯度后,无需等待其他GPU更新,立即更新整体权值并同步。

                异步更新的好处是计算速度快,计算资源能得到充分利用,但是缺点是loss的下降不稳定,抖动大。

3、按照算法来分

        ①Parameter Sever算法

                原理:假设我们有n张GPU,GPU0将数据分成n份分到各张GPU上,每张GPU负责自己那一批次数据的训练,得到梯度后,返回给GPU0上做累计,得到更新的权重参数后,再分发给各张GPU。

        ②Ring AllReduce算法

                原理:假设我们有n张GPU,它们以环形相连,每张GPU都有一个左邻和一个右邻,每张GPU向各自的右邻发送数据,并从它的左邻接近数据。循环n-1次完成梯度积累,再循环n-1次做参数同步。整个算法过程分两个步骤进行:首先是scatter_reduce,然后是allgather。在scatter-reduce,然后是allgather。在scatter-reduce步骤中,GPU将交换数据,使每个GPU可得到最终结果的一个块。在allgather步骤中,gpu将交换这些块,以便所有gpu得到完整的最终结果。

tf.distribute API:

        它是TensorFlow在多GPU、多机器上进行分布式训练用的API。使用这个API,可以在尽可能少改动代码的同时,分布式训练模型。

        它的核心API是tf.distribute.Strategy,只需简单几行代码就可以实现单机多GPU,多机多GPU等情况的分布式训练。

        它的主要优点:

                ①简单易用,开箱即用,高性能

                ②便于各种分布式Strategy切换

                ③支持Custom Training Loop、Estimator、Keras

                ④支持eager excution

tf.distribute.Strategy目前主要有四个Strategy:

        ①MirroredStrategy,即镜像策略

                MirroredStrategy用于单机多GPU、数据并行、同步更新的情况,它会在每个GPU上保存一份模型副本,模型中的每个变量都镜像在所有副本中。这些变量一起形成一个名为MirroredVariable的概念变量。通过apply相同的更新,这些变量保持彼此同步。

                创建一个镜像策略的方法如下:

                        mirrored_strategy = tf.distribute.MirroredStrategy()

                也可以自定义用哪些devices,如:

                        mirrored_strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0","/gpu:1"])

                训练过程中,镜像策略用了高效的All-reduce算法来实现设备之间变量的传递更新。默认情况下它使用NVIDA NCCL (tf.distribute.NcclAllReduce)作为all-reduce算法的实现。通过apply相同的更新,这些变量保持彼此同步。

                官方也提供了其他的一些all-reduce实现方法,可供选择,如:

                        tf.distribute.CrossDeviceOps

                        tf.distribute.HierarchicalCopyAllReduce

                        tf.distribute.ReductionToOneDevice

        ②CentralStorageStrategy,即中心存储策略

                使用该策略时,参数被统一存在CPU里,然后复制到所有GPU上,它的优点是通过这种方式,GPU是负载均衡的,但一般情况下CPU和GPU通信代价比较大。

                创建一个中心存储策略的方法如下:

                             central_storage_strategy = tf.distribute.experimental.CentralStorageStratygy()

        ③MultiWorkerMirroredStrategy,即多端镜像策略

                该API和MirroredStrategy类似,它是其多机多GPU分布式训练的版本。

                创建一个多端镜像策略的方法如下:

                             multiworker_strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

        ④ParameterServerStrategy,即参数服务策略

                简称PS策略,由于计算速度慢和负载不均衡,很少使用这种策略。

                创建一个参数服务策略的方法如下:

                              ps_strategy = tf.distribute.experimental.ParameterServerStrategy()

示例代码如下:

import tensorflow as tf#设置总训练轮数
num_epochs = 5
#设置每轮训练的批大小
batch_size_per_replica = 64
#设置学习率,指定了梯度下降算法中用于更新权重的步长大小
learning_rate = 0.001#创建镜像策略
strategy = tf.distribute.MirroredStrategy()
#通过同步更新时副本的数量计算出本机的GPU设备数量
print("Number of devices: %d"% strategy.num_replicas_in_sync)
#通过副本数量乘以每轮训练的批大小,得出训练总数据量的大小
batch_size = batch_size_per_replica * strategy.num_replicas_in_sync#函数将输入的图片调整为224x224大小,再将像素值除以255进行归一化,同时返回标签信息
def resize(image,label):image = tf.image.resize(image,[224,224])/255.0return image,label#载入数据集并预处理
dataset,_ = tf.keras.datasets.cifar10.load_data()
images,labels = dataset
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.map(resize).shuffle(1024).batch(batch_size)#在strategy.scope下创建模型和优化器
with strategy.scope():#载入了MobileNetV2模型,该模型在ImageNet上预先训练好了,并可以在分类问题上进行微调model = tf.keras.applications.MobileNetV2()#设置训练时用的优化器、损失函数和准确率评测标准model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate),loss = tf.keras.losses.sparse_categorical_crossentropy,metrics = [tf.keras.metrics.sparse_categorical_accuracy])#执行训练过程
model.fit(dataset,epochs = num_epochs)

对于CIFAR-10数据集下载过慢的问题,可以手动去官网下载

https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gzicon-default.png?t=N7T8https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz下载完成后将其放在如下图的路径下,并将数据集文件改名为cifar-10-batches-py.tar.gz并解压

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/103828.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sklearn处理离散变量的问题——以决策树为例

最近做项目遇到的数据集中,有许多高维类别特征。catboost是可以直接指定categorical_columns的【直接进行ordered TS编码】,但是XGboost和随机森林甚至决策树都没有这个接口。但是在学习决策树的时候(无论是ID3、C4.5还是CART)&am…

嵌入式养成计划-40----C++菱形继承--虚继承--多态--模板--异常

九十四、菱形继承 94.1 概念 菱形继承又称为钻石继承,是由公共基类派生出多个中间子类,又由中间子类共同派生出汇聚子类,汇聚子类会得到多份中间子类从公共基类继承下来的数据成员,会造成空间浪费,没有必要。 所以存…

工程师必须记住的电路元件符号及英语翻译

很多电子小白第一次接触印刷电路板(PCB)时,总会头痛那些密密麻麻的元件字母符号,这些电路元件符号基本上都是采用英语缩写,下面我们来看看这些电路元件的英语符号有哪些? 电阻器(Resistor&#…

C++入门指南:类和对象总结友元类笔记(下)

C入门指南:类和对象总结友元类笔记(下) 一、深度剖析构造函数1.1 构造函数体赋值1.2 初始化列表1.3 explicit关键字 二、static成员2.1 概念2.2 特性 三、友元3.1 友元函数3.2 友元类 四、 内部类4.1 概念4.2 特征 五、拷贝对象时的一些编译器优化六、深…

Linux进阶-加深进程印象

目录 进程 进程状态转换 进程状态 启动新进程 system()函数 system.c文件 Makefile文件 执行过程 fork()函数 函数原型 fork.c文件 Makefile文件 执行过程 exec系列函数 函数原型 execl.c文件 Makrfile文件 执行过程 终止进程 exit()函数和_exit()函数 头…

机器人制作开源方案 | 杠杆式6轮爬楼机器人

1. 功能描述 本文示例将实现R281b样机杠杆式6轮爬楼机器人爬楼梯的功能(注意:演示视频中为了增加轮胎的抓地力,在轮胎上贴了双面胶,请大家留意)。 2. 结构说明 杠杆式6轮爬楼机器人是一种专门用于爬升楼梯或不平坦地面…

【elasticsearch】elasticsearch8.0.1使用rpm包安装并启用TLS

背景 公司的业务需要在加密的情况下使用,为此,研究测试了一下es8是如何启用TLS的。以下是测试使用过程。 x-pack了解 在 Elasticsearch 7.11.0 版本及更高版本中,X-Pack 功能在默认情况下已经整合到 Elastic Stack 的各个组件中&#xff0…

M2芯片的Mac上安装Linux虚拟机——提前帮你踩坑

M2芯片的Mac上安装Linux虚拟机——提前帮你踩坑 1. 前言1.1 系统说明1.2 Linux系统选择——提前避坑1.3 下载vmware_fusion1.3.1 官网下载1.3.2 注册 CAPTCHA验证码问题1.3.3 产品说明 1.4 下载操作系统镜像1.4.1 下载centos(如果版本合适的)1.4.2 下载…

Excel 自动提取某一列不重复值

IFERROR(INDEX($A$1:$A$14,MATCH(0,COUNTIF($C$1:C1,$A$1:$A$14),0)),"")注意:C1要空置,从C2输入公式 参考: https://blog.csdn.net/STR_Liang/article/details/105182654 https://zhuanlan.zhihu.com/p/55219017?utm_id0

c++视觉处理---直方图均衡化

直方图均衡化 直方图均衡化是一种用于增强图像对比度的图像处理技术。它通过重新分布图像的像素值,以使图像的直方图变得更均匀,从而提高图像的视觉质量。在OpenCV中,您可以使用 cv::equalizeHist 函数来执行直方图均衡化。以下是 cv::equal…

06-Zookeeper选举Leader源码剖析

上一篇:05-Zookeeper典型使用场景实战 一、为什么要看源码 提升技术功底:学习源码里的优秀设计思想,比如一些疑难问题的解决思路,还有一些优秀的设计模式,整体提升自己的技术功底深度掌握技术框架:源码看多…

Jenkins更换主目录

Jenkins储存所有的数据文件在这个目录下. 你可以通过以下几种方式更改: 使用你Web容器的管理工具设置JENKINS_HOME环境参数.在启动Web容器之前设置JENKINS_HOME环境变量.(不推荐)更改Jenkins.war(或者在展开的Web容器)内的web.xml配置文件. 这个值在Jenkins运行时…

ExcelBDD Python指南

在Python里面支持BDD Excel BDD Tool Specification By ExcelBDD Method This tool is to get BDD test data from an excel file, its requirement specification is below The Essential of this approach is obtaining multiple sets of test data, so when combined with…

【【萌新的SOC学习之自定义IP核 AXI4接口】】

萌新的SOC学习之自定义IP核 AXI4接口 自定义IP核-AXI4接口 AXI接口时序 对于一个读数据信号 AXI突发读 不要忘记 最后还有拉高RLAST 表示信号的中止 实验任务 : 通过自定义一个AXI4接口的IP核 ,通过AXI_HP接口对PS端 DDR3 进行读写测试 。 S_AXI…

软件设计之抽象工厂模式

抽象工厂模式指把一个产品变成一个接口,它的子产品作为接口的实现,所以还需要一个总抽象工厂和它的分抽象工厂。 下面我们用一个案例去说明抽象工厂模式。 在class中可以选择super类和medium类,即选择一个产品的子类。在type中可以选择产品的…

c++处理图像---绘制物体的凸包:cv::convexHull

绘制物体的凸包:cv::convexHull cv::convexHull 是OpenCV中用于计算点集的凸包(convex hull)的函数。凸包是包围点集的最小凸多边形,该多边形的所有内部角都小于或等于 180 度。 cv::convexHull 函数的基本用法如下:…

Android Studio for Platform (ASfP) 使用教程

文章目录 编写脚本下载源代码lunch 查看版本 归纳的很清楚,下载Repo并下载源码->可以参考我的 Framework入门のPiex 6P源码(下载/编译/刷机) 启动图标(重启生效) [Desktop Entry] EncodingUTF-8 NameAndroidStudio …

大模型微调学习

用好大模型的层次:1. 提示词工程(prompt engineering); 2. 大模型微调(fine tuning)为什么要对大模型微调: 1. 大模型预训练成本非常高; 2. 如果prompt engineering的效果达不到要求,企业又有比较好的自有数据,能够通过…

Django实现音乐网站 ⒆

使用Python Django框架做一个音乐网站, 本篇主要为排行榜功能及音乐播放器部分功能实现。 目录 推荐排行榜优化 设置歌手、单曲跳转链接 排行榜列表渲染优化 视图修改如下: 模板修改如下: 单曲详情修改 排行榜列表 设置路由 视图处理…

MySQL建表操作和用户权限

1.创建数据库school,字符集为utf8 mysql> create database school character set utf8; 2.在school数据库中创建Student和Score表 mysql> create table school.student( -> Id int(10) primary key, -> Stu_id int(10) not null, -> C_n…