谈谈神经网络的大规模训练优化

文 | 立交桥跳水冠军

源 | 知乎


大规模神经网络训练一般会涉及到几百个分布式节点同时工作,模型的参数量以及运算量往往很大,作者认为在这个task下当前的工作主要归结为以下三种:对通信本身的优化,神经网络训练通信的优化,大规模下如何保持精度。

之前一段时间接触了大规模神经网络训练,看了不少优秀的工作,在这里当做笔记记下来。同时也希望可以抛砖引玉,和各位大佬交流一下这方面的现有工作以及未来的方向(1)大规模训练工作的几种类型大规模训练和普通分布式训练还是有区别的,主要体现在这个字上面。一般来说会涉及到几百个分布式节点同时工作,模型的参数量以及运算量往往很大(比如BERT,GPT3等等)我认为在这个task下当前的工作主要归结为以下三种:

  1. 对通信本身的优化

  2. 神经网络训练通信的优化

  3. 大规模下如何保持精度

其中1主要是通信库的优化,严格来说和神经网络本身并没有关系,这里面比较优秀的工作有经典的ring-base all-reduce(最先在百度的工作中被用于神经网络训练baidu-research/baidu-allreduce

https://github.com/baidu-research/baidu-allreduce

腾讯的分层通信:

https://arxiv.org/abs/1807.11205

以及sony的2D all-reduce(Massively Distributed SGD: ImageNet/ResNet-50 Training in a Flash:

https://arxiv.org/abs/1811.05233

而第2部分的工作都针对于如何在神经网络这个训练模式下做通信优化。这方面的思路很广,比如商汤提出的稀疏通信

https://arxiv.org/abs/1902.06855

杜克大学提出的TernGrad (TernGrad: Ternary Gradients to Reduce Communication in Distributed Deep Learning:

https://arxiv.org/abs/1705.07878

第三部分和前两个不同,主要关注点在于精度而非性能。在大规模训练的情况下,一种常见的做法是做数据并行,即把batch size设的很大,那么原来跑90个epoch需要迭代1000次的话,把batch size扩大10倍,就只需要迭代100次,即参数的更新次数减少了很多。如何在这种情况下收敛到小batch size也是一个棘手的问题。在这个领域比较好的工作有face book的线性倍增学习率(https://arxiv.org/pdf/1706.02677.pdf)以及伯克利尤洋的LAR算法(https://arxiv.org/pdf/1709.05011.pdf)。

对通信本身的优化

(懒得写了,偷个懒)我对这方面了解十分有限,推荐大家读腾讯团队写的介绍(兰瑞Frank:腾讯机智团队分享--AllReduce算法的前世今生:
https://zhuanlan.zhihu.com/p/79030485

神经网络的通信优化

分布式神经网络训练目前主要有两种模式:数据并行和模型并行。

数据并行比较简单,下面这张图是经典的数据并行的同步训练的场景:所有节点(即图中的GPU0-GPU3)都保存整个模型(粉色的Params),每次迭代,不同的节点会得到不同的数据,每个节点用得到的数据做正向和反向计算,得到每个参数的梯度。之后整个分布式系统会同步所有节点的梯度,即每个节点的local gradient做一次all reduce操作,得到全局的global gradient(最下面蓝色的Gradients)。每个节点用这个global gradient更新参数。

显而易见,数据并行基于一个假设:每个节点都可以放下整个模型。这个假设在如今某些模型上(说的就是你,GPT3!!!)是不合理的,因此我们还需要模型并行,即不同节点负责计算神经网络模型的不同部分(比如有一个100层的网络,那么我们可以让第一个节点存储前50层的参数,并负责计算前50层,另一个网络则负责后面50层)。

下面这张图摘自英伟达的Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism:
https://arxiv.org/abs/1909.08053

在这里演示了如何用两个节点去算连续的两个矩阵乘法。

我们要做的操作是首先算出Y=GeLU(XA),再算Z=Dropoug(YB)。其中,X,A,B都是矩阵,而且矩阵规模都很大。

假设我们希望用两个分布式节点完成这个计算,那么我们可以把矩阵A按colum切成A1,A2两份,分别存到节点0和节点1中。同时我们也把矩阵B按行切成B1,B2两份,分别存到节点0和节点1中。然后我们将X做一个broadcast(图中f部分),分别发送到两个节点上,算得Z1和Z2,在做一次all reduce(图中g部)将Z1和Z2相加,得到最终的Z。

这里面有一个很巧(也很绕)的地方,那就是为什么A要按列切,B要按行切?我们可不可以把它们反过来?答案是:最好不要,因为如果反过来,的确计算上可行,但是我们就会增加一次通信(即算Y=XA的时候我们就要做一次通信),这样显然速度会变慢。

展开来讲,数据并行和模型并行也可以细分。数据并行可以分为:

  • 同步式数据并行

  • 异步式数据并行

同步式比较简单,就是我最上面那张图演示的。

异步式复杂一些:我们很容易发现,最后全局all reduce gradient的时候会耗时比较多,分布式系统越大,消耗越大,而且这样做还有一个隐藏的假设:分布式系统是homogeneous的,即每个分布式节点不会差的很多。举个例子,如果每个节点实力相当,那么都会算10s就可以结束一个iteration,那么我们10s之后就可以开始一次通信。然而如果有一个节点(害群之马)需要算100s,那么其他节点算完之后就得干等它90s才能做通信,那么是对资源的极大浪费。

想想看,你的老板绝对不允许你(打工人)干坐着什么事都不干,只因为你的进度被别的同事block了。研究员也是如此,于是为了解决上面的问题,引入了异步式通信。简单来说就是如果遭遇了上面的情况,快的节点等一会儿就不等了,他们之间做一次通信然后接着算下一轮。这个节点什么时候算好什么时候再和其他人一起all reduce梯度。

这样做快是快了,但引入了另一个问题,那就是每个人的参数都不一样了,那么他们根据不同的参数算得的梯度再去做all reduce就有一些不合理,就会导致神经网络精度受损。

有很多工作尝试解决异步并行带来的精度损失,不过据我所知并没有特别general的方法,因此异步并行如今也很少被使用了。模型并行可以分为:

  • 粗粒度并行

  • 细粒度并行

它们的区别在于并行的层级:粗粒度每个节点会算不同的layer,而细粒度会将layer也做拆。

分粗粒度并行比较优秀的工作有google的GPipe(https://arxiv.org/pdf/1811.06965.pdf)

在粗粒度并行中,每个节点负责不同的layer,但是layer之间是存在数据依赖的,这就导致在之前的节点算的时候,后面的节点干等着。GPipe提出把数据按照batch纬度做切分得到多个micro batch,这样第一个节点先算第一个micro batch(图中F[0,0]),把算到的结果发给第二个节点去算,于是下一个时刻第二个节点在算第一个micro batch(F[1,0]),而第一个节点开始算第二个micro batch(F[0,1])。

细粒度并行比较好的工作除了我之前介绍的Megatron之外,还有GShard(GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding(https://arxiv.org/abs/2006.16668)

这个工作主要的贡献在于提供了一套原语,允许最高层的开发者(写python的人)通过简单的方式指导代码生成(即编译器)生成对应的模型并行的代码。

后台回复关键词【入群

加入卖萌屋NLP/IR/Rec与求职讨论群

后台回复关键词【顶会

获取ACL、CIKM等各大顶会论文集!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/479470.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 1108. IP 地址无效化

文章目录1. 题目2. 解题1. 题目 给你一个有效的 IPv4 地址 address,返回这个 IP 地址的无效化版本。 所谓无效化 IP 地址,其实就是用 “[.]” 代替了每个 “.”。 示例 1:输入:address "1.1.1.1" 输出:&…

Android NDK开发入门学习笔记(图文教程,极其详尽)

以前也简单用过JNI,但是只是简单用一下,好多都不明白。最近在看源码部分,有涉及到JNI调用的,所以这次打算彻底把它搞定。 先普及一下JNI的调用关系:JAVA------------------------>JNI----------------------------…

论文浅尝 | 利用问题生成提升知识图谱问答

论文笔记整理:谭亦鸣,东南大学博士生,研究方向为知识库问答。来源:NLPCC2019链接:http://tcci.ccf.org.cn/conference/2019/papers/183.pdf本文提出了一种利用问题生成提升知识图谱问答模型性能的方法(一个…

顶会论文:基于神经网络StarNet的行人轨迹交互预测算法

1.背景 民以食为天,如何提升超大规模配送网络的整体配送效率,改善数亿消费者在”吃“方面的体验,是一项极具挑战的技术难题。面向未来,美团正在积极研发无人配送机器人,建立无人配送开放平台,与产学研各方共…

python操作mysql数据库实现增删改查

python操作mysql数据库实现增删改查 Python 标准数据库接口为 Python DB-API,Python DB-API为开发人员提供了数据库应用编程接口。 Python 数据库接口支持非常多的数据库,你可以选择适合你项目的数据库: GadFlymSQLMySQLPostgreSQLMicrosoft …

LeetCode 654. 最大二叉树(递归)

文章目录1. 题目2. 解题1. 题目 给定一个不含重复元素的整数数组。一个以此数组构建的最大二叉树定义如下: 二叉树的根是数组中的最大元素。 左子树是通过数组中最大值左边部分构造出的最大二叉树。 右子树是通过数组中最大值右边部分构造出的最大二叉树。 通过给…

Probe:Android线上OOM问题定位组件

配送骑手端App是骑手用于完成配送履约的应用,帮助骑手完成接单、到店、取货及送达,提供各种不同的运力服务,也是整个外卖闭环中的重要节点。由于配送业务的特性,骑手App对于应用稳定性的要求非常高,体现App稳定性的一个…

Android中使用官方提供好的功能使用说明(比如系统图库获取),也作为延生学习的学习文档

这篇文章最核心的就是去学习如何学习Android,如何去使用Android文档。 我们一般在刚开始接触开发的时候,如果遇到无法解决的问题,常常会百度,或者google去寻找答案,比如有个需求是获取系统中的图片,你可能…

再介绍一篇Contrastive Self-supervised Learning综述论文

文 | 黄浴源 | 知乎之前已经介绍过三篇自监督学习的综述:《怎样缓解灾难性遗忘?持续学习最新综述三篇!》。这是最近2020年10月arXiv上的又一篇论文"A Survey On Contrastive Self-supervised Learning"。论文地址:https…

GCN-Based User Representation Learning for Unifying Robust Recommendation and Fraudster Detection

GCN-Based User Representation Learning for Unifying Robust Recommendation and Fraudster Detection 点击率预测:其主要思想是根据用户的历史行为对一组未评级的项目进行评级预测,然后从预测评级最高的项目中选择个性化推荐。 欺诈检测:…

公开课 | 知识图谱构建与应用概述

本文转载自公众号:博文视点Broadview。 AI是新的生产力,知识图谱是AI进步的阶梯。随着近年来人工智能的进一步发展,知识图谱也取得了一系列新的进展,并在各个行业中落地应用。知识图谱的相关技术已经在搜索引擎、智能问答、…

LeetCode 217. 存在重复元素(哈希)

文章目录1. 题目2. 解题1. 题目 给定一个整数数组,判断是否存在重复元素。 如果任何值在数组中出现至少两次,函数返回 true。如果数组中每个元素都不相同,则返回 false。 示例 1:输入: [1,2,3,1] 输出: true 示例 2:输入: [1,2,3,4] 输出:…

美团BERT的探索和实践

2018年,自然语言处理(Natural Language Processing,NLP)领域最激动人心的进展莫过于预训练语言模型,包括基于RNN的ELMo[1]和ULMFiT[2],基于Transformer[3]的OpenAI GPT[4]及Google BERT[5]等。下图1回顾了近…

论文浅尝 | 探索将预训练语言模型用于事件抽取和事件生成

论文笔记整理:郝凯龙,南京大学硕士链接:https://www.aclweb.org/anthology/P19-1522.pdf动机传统的 ACE 事件抽取任务依赖于人工标注的数据,耗费大量的人力并且数据量有限,数据量不足给事件抽取带来了阻碍。传统的事件…

谷歌、CMU发文:别压榨单模型了!集成+级联上分效率更高!

文 | Sherry 不是小哀集成模型(Ensemble)可以提升模型的精度,但往往面临提升计算量的困境,用级联模型(Cascade)在预测时提前中断则可解决计算量的问题。最近,谷歌和CMU的研究者对此进行了深入的…

LeetCode 219. 存在重复元素 II(哈希)

文章目录1. 题目2. 解题1. 题目 给定数组nums和常数k&#xff0c;存在不同的i、j使得nums[i] nums[j]&#xff0c;且abs(i-j) < k。 输入: nums [1,2,3,1], k 3 输出: true 示例 2:输入: nums [1,0,1,1], k 1 输出: true 示例 3:输入: nums [1,2,3,1,2,3], k 2 输出…

Android静态代码扫描效率优化与实践

背景与问题 DevOps实践中&#xff0c;我们在CI(Continuous Integration)持续集成过程主要包含了代码提交、静态检测、单元测试、编译打包环节。其中静态代码检测可以在编码规范&#xff0c;代码缺陷&#xff0c;性能等问题上提前预知&#xff0c;从而保证项目的交付质量。Andro…

还在用[CLS]?从BERT得到最强句子Embedding的打开方式!

文&#xff1a;涅生编&#xff1a;兔子酱你有尝试从 BERT 提取编码后的 sentence embedding 吗&#xff1f;很多小伙伴的第一反应是&#xff1a;不就是直接取顶层的[CLS] token的embedding作为句子表示嘛&#xff0c;难道还有其他套路不成&#xff1f;nono&#xff0c;你知道这…

论文浅尝 | BERT:Pre-training of Deep Bidirectional Transformers

论文笔记整理&#xff1a;王春培&#xff0c;天津大学硕士。链接&#xff1a;https://arxiv.org/pdf/1810.04805.pdf动机将预训练语言表示应用于下有任务现有两种策略&#xff1a;基于特征的和基于微调的。文章认为当前技术限制了预训练的能力&#xff0c;尤其是基于微调的方法…

欺诈检测相关论文

欺诈检测相关论文一、分类1、GEM2、HACUD3、MAHINDER4、Semi-GNN5、MvMoE6、AMG-DP7、AddGraph8、NetWalk9、DOMINANT10、GraphConsis11、PC-GNN12、TRUST二、类别不平衡一、分类 1、GEM 来自蚂蚁金服的论文&#xff0c;他们提出GEM模型&#xff0c;是一个异质图神经网络方法&a…