分布式深度学习中的数据并行和模型并行

🎀个人主页: https://zhangxiaoshu.blog.csdn.net
📢欢迎大家:关注🔍+点赞👍+评论📝+收藏⭐️,如有错误敬请指正!
💕未来很长,值得我们全力奔赴更美好的生活!

前言

对于深度学习模型的预训练阶段,海量的训练数据、超大规模的模型给深度学习带来了日益严峻的挑战,因此,经常需要使用多加速卡和多节点来并行化训练深度神经网络。目前,数据并行和模型并行作为两种在深度神经网络中常用的并行方式,分别针对不同的适用场景,有时也可将两种并行混合使用。本文对数据并行和模型并行两种在深度神经网络中常用的并行方式原理及其通信容量的计算方法进行介绍。


文章目录

  • 前言
  • 一、深度神经网络求解原理回顾
  • 二、数据并行
  • 三、模型并行
  • 总结


一、深度神经网络求解原理回顾

深度神经网络是通过模仿生物大脑的神经元结构而设计的一种多层互连结构.在其训练过程中,数据输入神经网络经过网络的前向传播过程得到一个输出,然后对输出得预测值和真实值求相对误差将其作为损失函数,接着,对网络进行反向传播求得损失对权重参数得梯度信息,最后,使用得到的梯度信息对权重参数做梯度下降使得损失函数越来越小,如此反复这个过程,使得神经网络的预测结果变得越来越准确。
假设训练数据集为 X = [ x 1 , x 2 . . . x N , ] X=[x_1,x_2...x_N,] X=[x1,x2...xN,],数据集经过前向传播后输出预测值 f ( x i ) f(x_i) f(xi),真实值为 y i y_i yi,则损失函数可以表示为如下式所示。

R e m p ( f ) = 1 N ∑ i = 1 N L ( y i , f ( x i ) ) R_{emp}\left(f\right)=\frac{1}{N}\sum_{i=1}^{N}L\left(y_i,f\left(x_i\right)\right) Remp(f)=N1i=1NL(yi,f(xi))

其中, L ( ∙ ) L(\bullet) L()为损失函数,它主要用于衡量预测值和真实值之间差异的大小,差异越小,说明模型的预测越准确。对于不同问题的求解,往往具有不同的形式。根据上式可以得到求解深度神经网络的最优化表达式如下式所示。

f ∗ = arg ⁡ min ⁡ f ∈ F R e m p ( f ) f^{*}=\underset{f \in \boldsymbol{F}}{\arg \min } R_{\mathrm{emp}}(f) f=fFargminRemp(f)

即在假设空间 F F F中找到一个最优的模型 f ∗ f^\ast f使得 R e m p ( f ) R_{emp}(f) Remp(f)最小。

基于梯度的优化算法是DL中解决上述优化问题应用最广泛的算法。由于二阶梯度下降法的计算复杂度较高,一阶梯度下降法,尤其是带有mini-batch及其变体的随机梯度下降法(SGD)在DL中被广泛使用。SGD的更新规则如下式所示。

G t ( x t ) = ∇ F t ( x t ; ξ t ) G_t\left(x_t\right)=\nabla F_t\left(x_t;\xi_t\right) Gt(xt)=Ft(xt;ξt)

x t + 1 = x t − γ G t ( x t ) x_{t+1}=x_t-\gamma G_t\left(x_t\right) xt+1=xtγGt(xt)

这里的 x t ∈ R N x_t\in R^N xtRN是第 t t t次迭代时的N维模型参数, ξ t \xi_t ξt是随机抽样的小批量数据, γ \gamma γ是学习率(或步长)。SGD是一种迭代算法,迭代过程通常包含几个步骤:

  1. 它对一小批数据(即 ξ t \xi_t ξt)进行采样。
  2. 它执行前馈计算,以计算目标函数的损失值(即 F t ( x t ; ξ t ) F_t\left(x_t;\xi_t\right) Ft(xt;ξt))。
  3. 它执行反向传播以计算关于模型参数的梯度(即 ∇ F t ( x t ; ξ t ) ∇F_t\left(x_t;\xi_t\right) Ft(xt;ξt))。
  4. 最后,通过公式 x t + 1 = x t − γ G t ( x t ) x_{t+1}=x_t-\gamma G_t\left(x_t\right) xt+1=xtγGt(xt)更新模型参数。训练深层模型非常耗时,尤其是对于大型模型或数据集。使用分布式训练技术,利用多个处理器来加速训练过程变得很常见。

二、数据并行

数据并行是在不同设备上放置完整的模型,然后将数据划分在每个设备并行计算,如下图所示。
在这里插入图片描述
数据并行性是深度学习中普遍存在的一种技术,对每个输入批训练数据在所有设备之间分配,每个设备中存储着网络模型完整的权重。在更新模型权重之前,梯度在所有设备之间进行通信和聚合。数据并行性拥有计算效率高和易于实现等优点。然而,数据并行性依赖于数据并行工作块的数量来缩放批处理大小,并且不能在不影响模型质量的情况下任意增大。对于参数不能存储在单个设备的大型模型,数据并行性便不在适应。

在小型分布式规模下,数据并行可以具有非常不错的扩展性。然而,梯度聚合的通信成本随着深度学习模型大小的增大而增加,并极大的限制了大模型和较低通信带宽系统的训练效率。针对分布式深度学习的数据并行训练,其训练过程如下:

  1. 计算节点会从将硬盘或者网络中读出mini-batch大小的数据复制到内存中;
  2. 将数据从 CPU内存复制到 GPU内存;
  3. 加载GPU kernel并从前到后分层进行计算输入数据的预测值(正向传播);
  4. 计算预测值和真实值的损失函数(loss)并进行反向传播,逐层求出损失对权重参数的梯度值;
  5. 将各个结点的梯度值进行同步 (发送和接收梯度,即,梯度通信);
  6. 利用同步后的梯度值结合优化算法对神经网络的权重参数进行更新;

以上6步构成了一个神经网络的学习过程,也就是一个Itera。在实际训练中,为了实现对神经网络的参数进行训练,必须进行多次的训练。在以上的训练过程中,网络通信发生的环节为一、二、五步。在第一步中,如果使用本机磁盘来提供资料,那么就不会有通信处理。第二步包括服务器之间的通信,这是用PCI-e把数据传送到 GPU。在第五步中,网络的参数量大小和规模主要由神经网络的参数和网络层的数目决定。在一般情况下,一个 Iter中的各个结点所需传送和接收的通信数据量均与神经网络的总参数值大小相等,而所需传送的数目则与神经网路的层数有关。所以,在每个层次上传送的通信数据量是不一样的,而频率区间也是由运算速度决定的。对于常见的CNN网络,其卷积层参数量要小于全连接层,所以在反向传输时,各个网络层的通信量会出现先大后小的不平衡问题。

因为数据并行需要每个设备将自己模型参数的梯度信息向其他设备传输。所以其通信容量往往与每一个批次的数据量多少无关,而与模型的大小和并行的设备数量有关。则对于深度神经网络的数据并行训练,其总的通信容量如下式所示。

V c o m m u n i c a t i o n = P a r × B y t e × N × ( N − 1 ) V_{communication}=Par×Byte×N×(N-1) Vcommunication=Par×Byte×N×(N1)

其中 P a r Par Par表示模型的参数量, B y t e Byte Byte是参数的表示形式,单位为字节,深度神经网络训练时通常取4字节即32位来表示参数。 N N N为并行计算的设备数量。

三、模型并行

模型并行是将模型分割成不同的块放到不同的设备上,按照划分方式的不同主要有以下图所示两种形式。
在这里插入图片描述
在数据并行的情况下,整个模型都存储在内存中,不过有时会数据量很大。如果是一般的计算机,那么内存就会不够,面对这种情况,这个巨大的模型可以分解成不同的部分用不同的机器进行计算,从计算角度上讲,就是将张量分成几个部分,从模型上讲,就是将网络的结构分割开来。切分方法有两种,一种是垂直切分(左图),另一种是水平切分(右图)。

垂直切分时形成多个分区,相同的分区放在同一设备上,每一个分区在不同的设备上并行执行。在这种形式下,某一层某个神经元的输入只有此设备上来自上一层的特征,而位于其他设备上的输入却不能得到。因此,为了避免这种情况,需要在关键的一些层处进行设备之间的通信,以融合不同设备上的特征。对于第i层其总的通信容量如下式所示。

V i = o u t × B y t e × N × ( N − 1 ) V_i=out\times Byte\times N\times(N-1) Vi=out×Byte×N×(N1)

其中 o u t out out表示每一个设备上输出的特征数量, B y t e Byte Byte是参数的表示形式,单位为字节, N N N为并行计算的设备数量。故,对于垂直切分时的模型并行来说其总的通信容量如下式所示。

V c o m m u n i c a t i o n = ∑ V i i ∈ ( 1 , 2... L ) V_{communication}=\sum V_i\ \ \ \ \ i\in(1,2...L) Vcommunication=Vi     i(1,2...L)

其中 L L L表示模型总的层数,这里的 i i i根据具体情况选取 1 1 1 L L L中的几个。

从以上两式中可以看到,对于垂直切分的模型并行来说,其通信容量主要受到输出特征值数量、选取的通信层数量、设备数量有关。

而对于水平切分,在这种模型并行形式下,可以将几个层划分给一个设备,不同设备划分得到的层不一致,因为在这种形式下后后面的层需要前面层的输出结果,每个设备要将自己计算的特征传输给下一层。所以前后阶段流水分批工作,然而,在这种情况下,第一个设备计算时,后面的设备都处于不工作状态,这很大程度上降低了并行性。 为了提高并行度,将每一个层再进行按区划分,第一个设备先执行第一个层的分区1,执行完之后开始执行分区2,这时设备2执行第二个层的分区1,如此反复计算传播以得到最终结果。对于水平切分时的模型并行来说其总的通信容量如下式所示。

V c o m m u n i c a t i o n = ∑ i = 1 N o u t i × B y t e V_{communication}=\sum_{i=1}^{N}{{out}_i\times B y t e} Vcommunication=i=1Nouti×Byte

其中 o u t i {out}_i outi表示第 i i i个设备的输出特征量, B y t e Byte Byte是参数的表示形式,单位为字节, N N N为并行计算的设备数量。

从上式中可以看到,对于水平切分的模型并行来说,其通信容量主要受输出特征值数量、设备数量等影响。


总结

数据并行和模型并行是在分布式计算中常用的两种并行计算策略,用于加速机器学习模型的训练过程。以下是它们的主要特点和区别总结:

数据并行(Data Parallelism):

  • 特点: 在数据并行中,不同的处理单元(通常是不同的计算节点或设备)负责处理不同的数据子集。每个处理单元独立地计算模型的梯度,并在一定周期后进行参数更新。
  • 优点: 数据并行易于实现,尤其是在拥有大量相似数据的情况下。它能够有效地利用大规模并行计算资源。
  • 缺点: 数据传输和同步操作可能成为性能瓶颈,尤其是当模型参数量较大时。此外,对于某些较大的模型结构,数据并行可能会受到单卡GPU显存的限制。

模型并行(Model Parallelism):

  • 特点: 在模型并行中,模型被划分成多个部分,不同的处理单元负责计算不同部分的输出。这通常用于处理较大且无法完全放入内存的模型。
  • 优点: 模型并行可以处理超大规模的模型,因为不需要一次性加载整个模型。这对于深度、复杂的模型是一个重要的优势。
  • 缺点: 实现模型并行通常较为复杂,因为需要确保各个部分的输出正确传递并在联合训练中协同工作。此外,同步问题也可能影响性能。

总体而言,数据并行和模型并行通常可以结合使用,以充分发挥分布式计算资源的优势。同时,具体选择使用哪种并行策略取决于问题的性质、模型的结构以及可用的硬件资源。

另外,在部分其他文献或是介绍中,模型并行的垂直切分往往被称之为Tensor并行而模型并行的水平切分往往被称之为流水并行

文中有不对的地方欢迎指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/641449.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LLMs之Vanna:Vanna(利用自然语言查询数据库的SQL工具+底层基于RAG)的简介、安装、使用方法之详细攻略

LLMs之Vanna:Vanna(利用自然语言查询数据库的SQL工具底层基于RAG)的简介、安装、使用方法之详细攻略 目录 Vanna的简介 1、用户界面 2、RAG vs. Fine-Tuning 3、为什么选择Vanna? 4、扩展Vanna Vanna的安装和使用方法 1、安装 2、训练 (1)、使用…

c#中使用UTF-8编码处理多语言文本的有效策略

使用UTF-8编码处理多语言文本的有效策略 在当今的全球化时代,软件开发者常常需要处理包含多种语言的文本。这不仅涉及英文和其他西方语言,还包括中文、日文、韩文等多字节字符系统。在这篇博客中,我将探讨如何有效地使用UTF-8编码来处理混合语…

项目管理认证 | 什么是PMP项目管理?PMP证书有什么用?

01 什么是项目管理? 项目管理?听起来似乎离我们很遥远。其实不然, 学习了项目管理知识后,你会发现,“一切都是项目,一切也将成为项目”。 你可以把港珠澳大桥的建设、开发一款新型手机、开发一个好用的C…

HarmonyOS 发送http网络请求

好 本文 我们来说 http请求 首先 我们要操作网络内容 需要申请权限 项目中找到 main目录下的module.json5 最下面加上 "requestPermissions": [{"name": "ohos.permission.INTERNET"} ]这里 我在本地写了一个get接口 大家可以想办法 弄一个后…

RabbitMQ交换机

目录 交换机类型 直连交换机:Direct exchange 主题交换机:Topic exchange 扇形交换机:Fanout exchange 首部交换机:Headers exchange 死信交换机:Dead Letter Exchange 交换机的属性 代码实战 直连&#…

x-cmd pkg | frp - 内网穿透工具

简介 frp(Fast Reverse Proxy)是一个专注于内网穿透的高性能反向代理应用,可以将内网服务以安全、便捷的方式通过具有公网 IP 节点的中转暴露到公网。 它采用 C/S 模式,将服务端部署在具有公网 IP 的机器上,客户端部…

使用torch实现RNN

在实验室的项目遇到了困难,弄不明白LSTM的原理。到网上搜索,发现LSTM是RNN的变种,那就从RNN开始学吧。 带隐藏状态的RNN可以用下面两个公式来表示: 可以看出,一个RNN的参数有W_xh,W_hh,b_h&am…

[AutoSar]BSW_OS 06 Autosar OS_Alarms

一、 目录 一、关键词平台说明一、Timer1.1 配置1.2Periodical Interrupt Timer (PIT)和High Resolution Timer (HRT) 二、Alarm 工作机制三、Code3.1创建一个15ms的runnable3.2mapping到basic task3.3生成代码 关键词 嵌入式、C语言、autosar、OS、BSW 平台说明 项目ValueO…

k8s的helm

1、在没有helm之前,部署deployment、service、ingress等等 2、helm的作用:通过打包的方式,deployment、service、ingress这些打包在一块,一键部署服务、类似于yum功能 3、helm:官方提供的一种类似于仓库的功能&#…

时间轮设计

目录 基本概念 函数定义 函数实现与测试 测试1结果如下 测试2结果如下 基本概念 时间轮 是一种 实现延迟功能(定时器) 的 巧妙算法。如果一个系统存在大量的任务调度,时间轮可以高效的利用线程资源来进行批量化调度。把大批量的调度任务…

React16源码: React中的resetChildExpirationTime的源码实现

resetChildExpirationTime 1 )概述 在 completeUnitOfWork 当中,有一步比较重要的一个操作,就是重置 childExpirationTimechildExpirationTime 是非常重要的一个时间节点,它用来记录某一个节点的子树当中,目前优先级最…

C++提高编程——STL:string容器、vector容器

本专栏记录C学习过程包括C基础以及数据结构和算法,其中第一部分计划时间一个月,主要跟着黑马视频教程,学习路线如下,不定时更新,欢迎关注。 当前章节处于: ---------第1阶段-C基础入门 ---------第2阶段实战…

数据结构:堆与堆排序

目录 堆的定义: 堆的实现: 堆的元素插入: 堆元素删除: 堆初始化与销毁: 堆排序: 堆的定义: 堆是一种完全二叉树,完全二叉树定义如下: 一棵深度为k的有n个结点的二…

ffmpeg使用及java操作

1.文档 官网: FFmpeg 官方使用文档: ffmpeg Documentation 中文简介: https://www.cnblogs.com/leisure_chn/p/10297002.html 函数及时间: ffmpeg日记1011-过滤器-语法高阶,逻辑,函数使用_ffmpeg gte(t,2)-CSDN博客 java集成ffmpeg: SpringBoot集成f…

科技云报道:金融大模型落地,还需跨越几重山?

科技云报道原创。 时至今日,大模型的狂欢盛宴仍在持续,而金融行业得益于数据密集且有强劲的数字化基础,从一众场景中脱颖而出。 越来越多的公司开始布局金融行业大模型,无论是乐信、奇富科技、度小满、蚂蚁这样的金融科技公司&a…

深度学习如何弄懂那些难懂的数学公式?是否需要学习数学?

经过1~2年的学习,我觉得还是需要数学有一定认识,重新捡起高等数学、概率与数理、线代等这几本,起码基本微分方程、求导、对数、最小损失等等还是会用到。 下面给出几个链接,可以用于平时充电学习。 知乎上的: 机器学…

计算机毕业设计 基于SpringBoot的律师事务所案件管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

git merge和git rebase区别

具体详情 具体常见如下,假设有master和change分支,从同一个节点分裂,随后各自进行了两次提交commit以及修改。随后即为change想合并到master分支中,但是直接git commit和git push是不成功的,因为分支冲突了【master以…

上位机图像处理和嵌入式模块部署(流程)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面我们说过,传统图像处理的方法,一般就是pccamera的处理方式。camera本身只是提供基本的raw data数据,所有的…

基于ADAS的车道线检测算法matlab仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 图像预处理 4.2 车道线特征提取 4.3 车道线跟踪 5.完整工程文件 1.课题概述 基于ADAS的车道线检测算法,通过hough变换和边缘检测方法提取视频样板中的车道线,然后根据车道线的弯曲情况…