硬核推导Google AdaFactor:一个省显存的宝藏优化器

一只小狐狸带你解锁炼丹术&NLP秘籍

作者:苏剑林(来自追一科技,人称“苏神”)

前言

自从GPT、BERT等预训练模型流行起来后,其中一个明显的趋势是模型越做越大,因为更大的模型配合更充分的预训练通常能更有效地刷榜。不过,理想可以无限远,现实通常很局促,有时候模型太大了,大到哪怕你拥有了大显存的GPU甚至TPU,依然会感到很绝望。比如GPT2最大的版本有15亿参数,最大版本的T5模型参数量甚至去到了110亿,这等规模的模型,哪怕在TPU集群上也没法跑到多大的batch size。

这时候通常要往优化过程着手,比如使用混合精度训练(tensorflow下还可以使用一种叫做bfloat16的新型浮点格式),即省显存又加速训练;又或者使用更省显存的优化器,比如RMSProp就比Adam更省显存。本文则介绍AdaFactor,一个由Google提出来的新型优化器,首发论文为《Adafactor: Adaptive Learning Rates with Sublinear Memory Cost》。

AdaFactor具有自适应学习率的特性,但比RMSProp还要省显存,并且还针对性地解决了Adam的一些缺陷。

Adam

首先我们来回顾一下常用的Adam优化器的更新过程。设为迭代步数,为当前学习率,是损失函数,是待优化参数,则是防止溢出的小正数,那么Adam的更新过程为

要省显存,就首先得知道显存花在哪里的。首先,计算量和显存的大头肯定都是,也就是说,计算梯度是很费资源的,这也是为啥“ALBERT相比BERT参数量虽然少了那么多,但训练速度也没见快多少”的原因了;除此之外,显存的消耗主要是了,我们要维护两组缓存变量,来滑动计算梯度的前两阶矩(也就是),用以计算参数的更新量。这两组变量每一组都跟训练参数本身一样大,因此对于参数比较多的模型,两组缓存变量所消耗的显存也不少。

AdaFactor

在这一节中,我们会相对详细地介绍一些AdaFactor优化器,介绍中会设计比较多的公式和推导。如果只求一个大致了解的读者,可以自行跳过部分数学内容~

抛弃动量

我们知道,CV模型很多时候要靠“SGD+动量”来炼出最优效果来,自适应学习率优化器通常训练不出最好的效果。但对于NLP模型来说,情况有点相反,自适应学习率显得更重要一些,很少听到由纯靠SGD调NLP模型的案例。因此,作为省显存的第一步,我们可以抛弃Adam里边的动量,这样就少一组缓存参数了,自然也就省了显存:

这其实就是RMSProp的变种,比RMSProp多了这一步。

低秩分解

去掉之后,缓存变量直接减少了一半,但AdaFactor还不满意,它希望保留自适应学习率功能,但把缓存变量的参数量再压一压。这一次,它用到了矩阵的低秩分解。

广义KL散度

在SGD中,所有参数都是共用一个标量学习率;在Adam中,则是每一个参数都有自己的学习率。我们知道通过精调学习率,SGD其实也能有不错的效果,这表明“每一个参数都有自己的学习率”这件事情都不是特别重要,或者换一种说法,就是“精调每一个参数自己的学习率”并不是特别重要。

这启发我们,将换一种参数更少的近似可能也就足够了。而“参数更少的近似”,我们就不难想到低秩分解了。对于的矩阵,我们希望找到的矩阵的矩阵,使得

足够小时,的参数总量就小于的参数量。为了“省”到极致,AdaFactor直接让,即寻找,使得

既然要近似,就要有一个度量的标准。很容易想到的标准是欧氏距离,即

但在这个距离之下,并没有解析解;此外,在优化过程中(即)是非负的,而通过上述目标优化出来的无法保证非负,因此很可能扰乱优化过程。原论文的作者们很机智地换了一个度量标准,使得有解析解。具体来说,它使用了“广义KL散度”,又称“I散度”,其形式为:

这个度量源自不等式,当且仅当时等号成立。所以代入,然后两端乘以,我们有

当且仅当成立,如果有多个分量,那么对多个分量的结果求和即可,这就得到了度量。显然,广义KL散度是概率的KL散度的自然推广,但它不要求满足归一化,只要求它们非负,这正好对应了AdaFactor的场景。而且巧妙的是,这种情形配上这个目标,刚好有解析解:

其实这个解析解也很形象,就是行、列分别求和,然后相乘,再除以全体的和。

推导过程

直接对求偏导数并让偏导数等于0,得

整理得

注意到如果是一组最优解,那么也是,说白了,所有的乘以一个常数,所有的也除以这个常数,是不变的。那么我们就可以随意指定,因为它们就只是一个缩放标量而已。不失一般性,我们指定,那么就解得

直观理解

我们也可以从另一个角度理解结果。由于是非负的,我们可以将它归一化,变成具有概率分布的特性,即,然后我们试图完成分解,由于现在相当于一个二元联合概率分布,那么就相当于它们的边缘分布,即

现在还需要乘上一个,我们可以把它乘到中,不失一般性,我们假设乘到上,那么就得到

AdaFactor雏形

有了结果后,我们就可以用它来构建更省内存的优化器了,这就是AdaFactor的雏形。简单来说,当参数是普通一维向量时,优化过程保持不变;但的矩阵时,算出来的梯度也是矩阵,从而也是矩阵,这时候我们对做低秩分解,然后维护两组缓存变量,分别滑动平均低秩分解后的结果,最后用共同调整学习率:

(把加到上去而不是上去,这是AdaFactor整出来的形式,不是笔者的锅~).

滑动权重

在Adam以及上述AdaFactor雏形中,滑动权重都是恒为常数,AdaFactor指出这是不科学的,并提出新的策略。

等价形式

为了认识到这一点,我们重写一下Adam的的更新过程:

所以如果设,那么更新公式就是

问题是这个够不够合理呢?答案是可能不大够。当,这时候就是,也就是用实时梯度来校正学习率,这时候校正力度最大;当时,,这时候是累积梯度平方与当前梯度平方的加权平均,由于,所以意味着当前梯度的权重不为0,这可能导致训练不稳定,因为训练后期梯度变小,训练本身趋于稳定,校正学习率的意义就不大了,因此学习率的校正力度应该变小,并且,学习率最好恒定为常数(这时候相当于退化为SGD),这就要求时,

新的衰减策略

为了达到这个目的,AdaFactor采用如下的衰减策略

它满足。但即便如此,也不是任何都适合,必须有好理解,那为什么要呢?原论文包含了对它的分析,大家可以去读读,但笔者觉得原论文的推导过于晦涩,所以这里给出自己的理解。

首先,对于来说,一个很容易想到的方案是所有梯度平方的平均,即:

所以这等价于让。这个方案美中不足的一点是,每一步梯度都是平权的,这不符合直觉,因为正常来说越久远的梯度应该越不重要才对,所以应该适当降低历史部分权重,而当时,,因此一个简洁的方案是在式中取,AdaFactor默认的

层自适应

最后,我们还可以进一步根据参数的模长来校正更新量,这个思路来自LAMB优化器,在之前的文章《6个派生优化器的简单介绍及其实现》中也介绍过。简单来说,它就是将最后的更新量标准化,然后乘以参数的模长,说白了,就是不管你怎么折腾,最后的更新量我只要你的方向,而大小由参数本身的模长和预先设置学习率共同决定,使得所有层所有参数的相对变化程度保持一致。

AdaFactor完整版

至此,我们终于可以写出完整版AdaFactor的更新过程了:

其中是模长的变种,这一步相当于做了个截断,即时才执行归一化。原论文中的默认参数为

如果参数是一维向量而不是矩阵,那么使用普通的更新公式就行了。此外,论文还提出如果没有传入学习率,那么可以使用为默认学习率,但笔者看源码的时候发现这个默认学习率很少使用,基本上还是需要自己传入学习率的。

开源实现

为了方便大家使用,笔者开源了自己实现的AdaFactor:

https://github.com/bojone/adafactor

开源包括纯keras版和tf.keras版,使用方法跟普通keras优化器一样,tf.keras版也可以当做一个普通的tensorflow优化器使用。开源实现参考了mesh_tensorflow版的源码,在此表示感谢。优化器也已经内置在bert4keras中,方便大家调用。

需要提醒的是,用AdaFactor的时候,batch_size最好大一些,因为本身低秩分解会带来误差,而如果batch_size过小,那么梯度估算本身也带来较大的误差,两者叠加优化过程可能还不收敛。对于预训练模型来说,batch_size通常还是很大的,所以现在不少预训练模型开始用AdaFactor优化器了;对于普通的下游任务来说,AdaFactor也可以尝试,但可能需要多炼炼丹,才能搞出由于无脑Adam的效果。

文章小结

本文介绍了Google提出来的AdaFactor优化器,一个旨在减少显存占用的优化器,并且针对性地分析并解决了Adam的一些缺陷。笔者认为,AdaFactor针对Adam所做的分析相当经典,值得我们认真琢磨体味,对有兴趣研究优化问题的读者来说,更是一个不可多得的分析案例。

当然,没有什么绝对能有效的方法,有的只是方法虽好,要想实际有效,依然要用心炼丹。 

夕小瑶的卖萌屋

_

关注&星标小夕,带你解锁AI秘籍

订阅号主页下方「撩一下」有惊喜哦

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/480627.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

领域应用 | 用知识图谱玩唐诗,“唐诗别苑”附庸端午节的别样风雅!

本文转载自公众号:互联网教育国家工程实验室 。 端午节在每年的农历五月初五,又称端阳节、午日节、五月节等。端午节起源于中国,是古代百越一带崇拜龙图腾的部族举行图腾祭祀的节日。五月初五也是缅…

参加完阿里蚂蚁金服Java中间件6轮面试题!6点血泪总结~

蚂蚁金服一面:分布式架构 50分钟 1、个人介绍加项目介绍20分钟 2、微服务架构是什么,它的优缺点? 3、ACID CAP BASE理论 4、分布式一致性协议,二段、三段、TCC,优缺点 5、RPC过程 6、服务注册中心宕机了怎么办&am…

特定热点事件监控与分析项目

EventMonitor Event monitor based on online news corpus built by Baidu search enginee using event keyword for event storyline and analysis,基于给定事件关键词,采集事件资讯,对事件进行挖掘和分析。 项目地址:https://g…

深度好文:2018 年 NLP 应用和商业化调查报告

、 深度好文:2018 年 NLP 应用和商业化调查报告 Debra 阅读数:7650 2019 年 1 月 11 日近年来,自然语言处理技术已经取得了长足进步,成为应用范围最广泛,也是最为成熟的 AI 技术之一。但实际上,自然语言处理…

论文浅尝 | 通过多原型实体指称向量关联文本和实体

Cao Y,Huang L, Ji H, et al. Bridge Text and Knowledge by Learning Multi-Prototype Entity Mention Embedding[C]// Meeting of the Association for Computational Linguistics. 2017:1623-1633.导读:学术界近两年来十分关注如何将文本等非结构化数据和知识库等…

【面试必备】奉上最通俗易懂的XGBoost、LightGBM、BERT、XLNet原理解析

一只小狐狸带你解锁 炼丹术&NLP 秘籍在非深度学习的机器学习模型中,基于GBDT算法的XGBoost、LightGBM等有着非常优秀的性能,校招算法岗面试中“出镜率”非常高。这些经典的机器学习算法不仅是数据科学竞赛神器,在工业界中也被广泛地使用。…

2019手把手教你Java面试通关BAT

金三银四俗称跳槽黄金季,很多同学都想趁着这段时间拿高薪,去更牛逼的公司工作,认识更多大牛,提升自己的职场竞争力。 那怎样才能通过BAT面试官的考核?怎样成为一名Offer收割机? 之前讲过收割Offer有一个最…

特定领域因果事件图谱构建项目

CausalityEventExtraction self complement of templated based causality event extraction 基于因果关系知识库的因果事件图谱构建demo 项目地址:https://github.com/liuhuanyong/CausalityEventGraph 项目介绍 现实社会是个逻辑社会,大量的逻辑即逻…

斯坦福李纪为博士毕业论文:让机器像人一样交流

https://cloud.tencent.com/developer/article/1120019 选自GitHub机器之心编译自然语言处理(NLP)是人工智能领域下的一个庞大分支,其中面临很多机遇与挑战。斯坦福大学李纪为博士在他的毕业论文《Teaching Machines to Converse》中对 NLP 领…

陈华钧 | 知识图谱构建,将成为智能金融的突破口

本文转载自公众号:恒生技术之眼。“ 我们太容易被机器下棋这样的事所吸引,以至于现在谈到人工智能就基本都是在说机器学习和深度学习,而相对忽视了与人工智能相关的另外一个重要的方向:知识图谱。——陈华钧”尽管人工智能依靠机器…

万字长文梳理CTR点击预估模型发展过程与关系图谱

背景在推荐、搜索、广告等领域,CTR(click-through rate)预估是一项非常核心的技术,这里引用阿里妈妈资深算法专家朱小强大佬的一句话:“它(CTR预估)是镶嵌在互联网技术上的明珠”。本篇文章主要…

基于携程游记的出行领域顺承事件图谱项目

EvolutionaryEventGraph 项目地址:https://github.com/liuhuanyong/SequentialEventExtration Evolutionary Event Graph based on Travel note crawled from XieCheng,基于50W携程出行攻略的顺承事件抽取与事件图谱构建. 项目来源 目前,以谓词性短语…

5步教你成功求职进入BAT

有读者朋友希望我能写一部分关于BAT内部的文章,比如,怎么进入BAT,BAT内部的项目的流程,有挑战性的项目实践,大概是怎么样的? 我希望用这篇文章开启整个进入BAT系列篇,让大家更好的了解BAT内部的…

机器阅读理解任务综述

http://forum.yige.ai/thread/27 2016年 <div class"markdown-body" id"emojify">作者&#xff1a;林鸿宇 韩先培 简介 自然语言处理的长期目标是让计算机能够阅读、处理文本&#xff0c;并且理解文本的内在含义。理解&#xff0c;意味着计算机在接…

论文浅尝 | 基于知识图谱子图匹配以回答自然语言问题

Citation: Hu,S., Zou, L., Yu, J. X., Wang, H., & Zhao, D. (2018). Answering natural language questions by subgraph matching over knowledge graphs. IEEE Transactions on Knowledge & Data Engineering, PP(99), 1-1.动机对于基于知识图谱的事实性问答&#…

新闻文本内容知识图谱表示项目

TextGrapher 项目地址&#xff1a;https://github.com/liuhuanyong/TextGrapher Text Content Grapher based on keyinfo extraction by NLP method。输入一篇文档&#xff0c;将文档进行关键信息提取&#xff0c;进行结构化&#xff0c;并最终组织成图谱组织形式&#xff0c;…

BAT Java面试完整汇总:面试准备(心态+简历)+面试题目+6条面试经验

今天分享的BAT面试完整内容主要包含&#xff1a; 面试前的心态准备&#xff08;3点建议&#xff09; 技术硬实力包含的范围&#xff08;50题目&#xff09; 个人简历突出和优化&#xff08;3点优化步骤&#xff09; 个人软实力的提升&#xff08;6点提升维度&#xff09; B…

算法工程师的效率神器——vim篇

一只小狐狸带你解锁炼丹术&NLP秘籍我相信&#xff0c;有很多小伙伴在看到这篇文章时就有了很多问号&#xff1a;用vim&#xff1f;疯了吧&#xff1f;sublime不香吗&#xff1f;pycharm不香吗&#xff1f;jupyter notebook不香吗&#xff1f;我这可是最新版的windows 100操…

论文浅尝 | 端到端神经视觉问答之上的显式推理

链接&#xff1a;http://www.public.asu.edu/~cbaral/papers/2018-aaai-psl.pdf概述视觉问答(Visual Question Answering)现有两大类主流的问题, 一是基于图片的视觉问答(ImageQuestion Answering), 二是基于视频的视觉问答( Video Question Answering).而后者在实际处理过程中…

机器阅读理解首次超越人类!云从刷新自然语言处理新纪录

媒体动态发展历程资质荣誉人才招聘机器阅读理解首次超越人类&#xff01;云从刷新自然语言处理新纪录 2019-03-11 10:06 浏览&#xff1a;454 近日&#xff0c;云从科技和上海交通大学在自然语言处理领域取得重大突破&#xff0c;在卡内基-梅隆大学发起的大型深层阅读理解任务数…