一训练就显存爆炸?Facebook 推出 8 比特优化器,两行代码拯救你的显存!

9771ea9b5a85e28abc680a178cf23925.png

文 | jxyxiangyu
编 | 小轶

“小夕,小夕!又出来了个 SOTA 模型!赶紧 follow !”

小夕看了看新模型的参数量, 然后看了看实验室服务器的几张小破卡。

小夕,陷入了沉默。

自从人们发现越大的模型性能越好后,神经网络模型的参数量就在越来越大的道路上一去不复返了。从XX-large到GPT3,再到5300亿参数的Megatron Turing-NLG,深度学习越来越像是只有财大气粗的大公司才能玩得起的玩具。如果,我们想要在实验室“简陋”的环境下,尝试更大的模型,有什么行之有效的方法呢?

最近,Facebook 推出了支持 pytorch 的 8 位优化器在减小内存占用的同时,竟然还能保持和32位优化器相当的准确性。不得不说 facebook yyds。那么,下面就让我们一起来看看具体是怎么做的吧。

论文题目:
8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION

论文链接:
https://arxiv-download.xixiaoyao.cn/pdf/2110.02861.pdf

开源链接:
https://github.com/facebookresearch/bitsandbytes

d4bdb0483fb314d23eb9b4e42fea75fb.png量化50372d1bfc20b28e12965ef6c8c3cc82.png

在介绍论文作者的解决方法之前,先补充一点关于量化的基本概念。通常意义上来说,量化是指将信号的连续取值近似为有限多个离散值的过程。具体到计算机系统,指的是将浮点数值映射到低bit数值的操作[1]

一般来说,我们可以通过以下手段应用量化

  1. 量化模型参数来压缩模型;

  2. 量化模型某些层的激活值来减少内存占用*;

(注:参数和梯度也会占用一定的内存空间,但相对于激活值而言,占用比例不大,一般来说,量化参数和梯度带来的内存收益没有量化激活值的大)

bb50f0e0c49956114edd13e60405706f.png
▲量化示意图

上图是 Song Han 在 ICLR'2016 上提出的量化方法。将模型参数分别聚类到几个质心,并将参数量化到对应的质心,在更新参数时,是将同一质心对应的梯度累加用于更新该质心对应的参数

可以看到,量化通过将参数(浮点值)映射到二值、三值或线性量化到一个区间(一般是低比特数值)的方式,减小了模型大小,在某些硬件上面,低比特数值运算速度高于浮点数值,一定程度上可以加速模型的训练和预测;除此之外,模型在训练和预测的时候,模型参数本身只占用了内存的一小部分,大部分存储了模型的激活值,如果将量化应用到激活值上,一定程度也减小了内存占用,这样,我们就可以尝试更大的模型和设置更大的mini-batch了。

当然,量化这么好,也不是没有缺点的,量化后的模型或多或少会引入精度损失;并且目前学术界多采用 pytorch 框架,好巧不巧的是 pytorch 框架对量化的支持没有 tensorflow 好,这总不能为了体验大模型的快感再转到 tensorflow 上面去吧,想想 tensorflow 混乱的 api 就头疼(╯°Д°)╯ ┻━┻

8fed30e9ced9615977e6c2634f9d465f.png

最近,Facebook 推出了支持 pytorch 的 8 位优化器,在减小内存占用的同时,竟然还能保持和32位优化器相当的准确性。

状态优化器

再简单介绍下带有状态的优化器(stateful optimizer)。和普通的随机梯度下降(SGD)相比,为了加速优化而提出的带有梯度统计信息的优化器,就是状态优化器。常见的例如带动量的 SGD 和 Adam 。计算公式如下:4e30e714a56a84b9eb6f3c657a905c9d.png其中, 和 是平滑因子, 是非常小的常量, 是学习率。

作者认为,状态优化器会维护历史梯度数据,一定程度上占用了内存。通过量化这些梯度,可以有效地降低内存占用

非线性量化

前述已经介绍了量化就是将信号的连续取值近似为有限多个离散值的过程,在降低模型参数量的同时,也会带来一定的精度损失,为减小精度损失,多采用非线性的量化方式,大致可以归纳为三个步骤:

  1. 对于输入张量,计算归一化因子;

  2. 将张量通过归一化后,找到在量化空间中距离最近的值;

  3. 将量化后的张量的每个元素的索引存储下来

那么,我们就可以遍历存储的索引并通过下式得到反量化张量:,其中,是反量化映射

为了使不同元素值量级一致,一般会将张量归一化到的区间范围,这时,取的是输入张量中绝对值的最大值,即,然后通过二分查找的方式找到量化空间中距离该值最近的量化值a5f024e5985f3480aa114fccb197fdec.png

动态树量化

上一节看到,非线性量化在归一化时会严重依赖输入张量中的最值,像某些特别大或特别小的异常值,对量化会产生较大的精度影响。动态树量化(dynamic tree quantization)就是一种以较低的量化精度损失处理这种情况的方法。e71a7f605527968a3e81c2758c4fed73.png

与浮点数的存储方式类似,动态树量化以这类方式解释存储在内存中的数值,以此实现量化,具体由四部分组成:

  1. 首位是符号位

  2. 符号位后连续的0的数量表示指数大小

  3. 再之后的第一位是指示位,如果指示位为1表示后续剩余的位为线性量化区域

  4. 线性量化区域

其中,指示位是可以动态移动的。通过移动指示位,可以表示指数为或者精度为的数值,表示范围为

0e52e725e8505e837aab66394661f394.png8位优化器71591bb3331e39c9e0047b2dde3de6f9.png

有了前面的知识铺垫,下面就可以详细地说明作者提出的8位优化器了。该8位优化器由三部分构成:

  1. 逐块量化(block-wise quantization)

  2. 动态量化(dynamic quantization)

  3. 稳定的词嵌入层(stable embedding layer)

应用上述组件,将8位优化器的状态反量化为32位并更新状态和参数,然后将这些状态量化回8位进行存储。由于是在寄存器中进行8位和32位的转换,一定程度上可以减小内存占用并加速训练。

逐块量化

常见的量化需要将原始的张量在张量级别归一化,这样可能会引入核之间的多次信息通信和同步,造成额外的时间开销,而逐块量化则是将张量分成多个小块并在块级别归一化,减小了核之间的通信开销,除此之外,还可以将张量元素中的异常值的影响限制在单个块中。假设为有个元素的张量,分成每个大小为的块,那么,可以分成个块,对每个块分别做归一化,归一化因子为,每个块分别通过下式进行量化操作:其中,为块索引,为块中元素的索引

动态量化

8位优化器的动态量化部分是对前面提到的动态树量化的扩展,对于像的第二个状态这种严格为正的数值,符号位就显得有些多余,而在语言模型的训练过程中,作者发现的变化范围在3~5个数量级,小于动态树量化的7个数量级,因此,可以用固定的位将只会用到的位划分开,进一步减小内存占用。对于其他带符号的状态张量,则继续使用动态树量化。

稳定的词嵌入层

为了确保nlp任务中模型的稳定训练,作者还添加了稳定的词嵌入层。作者使用Xavier uniform对词嵌入层进行初始化,并且在与位置向量合并前进行层归一化操作,这样可以使参数在初始化和训练期间保持1左右的方差。词嵌入层的优化器状态用32位存储,权重和梯度用16位存储。

a202bf69891a718662f2c82236f03c4d.png8位优化器 vs 32位优化器62a64fb77fd8d8b5eedfe3d7ab2b2efe.png

作者在多个任务(包括机器翻译、大规模语言模型的预训练以及微调、图像分类和图像预训练以及微调)上比较了8位优化器和32位优化器的性能,比较的优化器包括、或,实验中,除了将32位优化器替换为8位优化器外,没有改动超参和权重、梯度以及激活值的精度。

除了GLUE任务之外,其余的NLP任务均使用了作者提出的稳定词嵌入层。为确保实验结果的可信度,还在不同随机数种子下多次实验,选取了实验结果的中位数作为最终性能。实验结果如下所示:de56f82da2dd72746126a8aee69ee2cc.png

可以看到,8位优化器在多个任务上均达到甚至是超过了32位优化器的性能,与此同时,还能大幅减小内存开销加速训练。此外,作者还列出了在同等显存大小的条件下,使用8位优化器和32位优化器可以支持训练的模型。可以说,非常贴心了ヾ(●゜ⅴ゜)ノ03505ccc8ddf9851d9d000eae748ae09.png

dac9f003bf2f326909ecabe3cff6ed67.png消融研究e0d33ed365528aa15c454a01c544f4b3.png

作者基于语料库训练了多个模型,用于研究8位优化器中各个组件的影响。实验结果如下:90dde65ddd7d1fcabc3ea3d72dce9bac.png其中,32位优化器(baseline)采用的是线性量化。

为测试优化器的稳定性,对于小规模的模型,作者分别训练了不同的超参数下的模型,超参数为 {1e-8, 1e-7, 1e-6}, {0.90, 0.87, 0.93}, {0.999, 0.99, 0.98}以及学习率方面的改动,而对于超过1B的大规模模型,则是在相同超参下采用不同的随机数种子多次运行。所有的结果均是选择可以成功训练完(没有因梯度爆炸或弥散而无法训练)的模型性能的中值。

可以看出,逐块量化、动态量化和稳定的词嵌入层对结果都有正向影响

此外,作者还对比了32位优化器和8位优化器对超参的敏感程度,比较了32位和8位优化器在、、和的变化下的走势4d1d8be47516fe494dda8432132ddb5d.png

可以看到,8位优化器和32位相比,困惑度走势基本一致,表明对超参不敏感,在将32位优化器替换为8位优化器后,超参不需要进一步的调整

f91e6f830338363b6ed8062561397706.png局限性1a5faed5ded027ad1cf97867bc4ac7a5.png

从实验结果可以看出,8位优化器完全可以作为32位优化器的替代品。当然,8位优化器也存在一些局限性:

  1. 8位优化器需要稳定的词嵌入层来达到32位优化器的性能;

  2. 8位优化器减小内存的大小与模型参数量成正比,对于像cnn这种激活值比参数占内存多得多的模型,8位优化器并没有特别明显的内存减小,反而更适合transformer这种架构的大规模模型

d0b53f44f4f5a5ce3fb091e72855e379.png总结5811043a2fc7745c7427e51fb538152f.png

不得不说,Facebook的8位优化器简直是我等“穷困”炼丹党的福音。现在,8位优化器已经开源,开源地址已经在文章开头提到。目前,8位优化器已经支持Adam, AdamW, RMSProp, LARS, LAMB优化器。使用时,需要安装并导入包bitsandbytes-cudaXXX,其中,XXX是本地环境的cuda工具包版本号,注释掉原有的优化器,调用8位优化器就可以了。

import bitsandbytes as bnb# adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer
adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer
adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8) # equivalenttorch.nn.Embedding(...) ->  bnb.nn.StableEmbedding(...) # recommended for NLP models

据官网描述,仅仅需要改动两行代码,就可以节省75%的内存!小伙伴们,还不想抓紧时间上车体验一下嘛?

970967267d1544988180a0db3d23658c.png
▲没时间解释了,快上车

0f514e4371daccdace52d0eb85cb6246.png后台回复关键词【入群

加入卖萌屋NLP/IR/Rec与求职讨论群

后台回复关键词【顶会

获取ACL、CIKM等各大顶会论文集!

f3fe03552ce4deaf19e8def01ad759bc.gif 266c7e30d54fca3929e7f1ae278c0699.png

[1] 商汤科技SenseTime, 模型量化了解一下?(https://zhuanlan.zhihu.com/p/132561405)

[2] Song, H. ,  H. Mao , and  W. J. Dally . "Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding." ICLR 2016. (https://arxiv-download.xixiaoyao.cn/pdf/1510.00149.pdf)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/477890.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文浅尝 | 基于正交普鲁克分析的高效知识图嵌入学习

笔记整理:朱渝珊,浙江大学在读博士,研究方向为快速知识图谱的表示学习,多模态知识图谱。1.Motivation知识图谱是许多NLP任务和下游应用的核心,如问答、对话代理、搜索引擎和推荐系统。知识图中存储的事实总是以元组的形…

LeetCode 979. 在二叉树中分配硬币(DFS)

文章目录1. 题目2. DFS 解题1. 题目 给定一个有 N 个结点的二叉树的根结点 root,树中的每个结点上都对应有 node.val 枚硬币,并且总共有 N 枚硬币。 在一次移动中,我们可以选择两个相邻的结点,然后将一枚硬币从其中一个结点移动…

有福利! 好书推荐:从《实用推荐系统》学习寻找用户行为之法

大多数关于推荐系统的图书都讲述了算法及其优化方法。这些书都认为你已经有了一个大的数据集来供算法使用。数据集不会像变魔术那样凭空出现。要想收集到正确的用户偏好数据,就需要投入精力和进行思考。它会成就你的系统,或者搞砸你的系统。“垃圾进&…

灵活强大的构建系统Gradle

前言 构建,软件生命周期中重要的一环,在现代软件开发过程中,起着越来越重要的作用。过去在Java或类Java的世界里,Ant、Maven再熟悉不过了,Maven凭借其强大的依赖配置战胜Ant,基本上成为了Java构建的标准。而…

LeetCode 791. 自定义字符串排序(map)

1. 题目 字符串S和 T 只包含小写字符。在S中,所有字符只会出现一次。 S 已经根据某种规则进行了排序。我们要根据S中的字符顺序对T进行排序。更具体地说,如果S中x在y之前出现,那么返回的字符串中x也应出现在y之前。 返回任意一种符合条件的…

6万字解决算法面试中的深度学习基础问题

文 | 清卢雨源 | 对白的算法屋前言真的是千呼万唤始出来emmmm,去年春招结束写了篇面试的经验分享。在文中提到和小伙伴整理了算法岗面试时遇到的常见知识点及回答,本想着授人以渔,但没想到大家都看上了我家的 !但因本人执行力不足…

OpenKG开源系列 | 海洋鱼类百科知识图谱(浙江大学)

OpenKG地址:http://openkg.cn/dataset/ocean开放许可协议:CC BY-SA 4.0贡献者:浙江大学(徐雅静、邓鸿杰、唐坤、郑国轴)1、背景海洋是生命的摇篮,是人类文明的重要发祥地,在人类社会发展的进程中起着举足轻重的作用。海…

Presto实现原理和美团的使用实践

Facebook的数据仓库存储在少量大型Hadoop/HDFS集群。Hive是Facebook在几年前专为Hadoop打造的一款数据仓库工具。在以前,Facebook的科学家和分析师一直依靠Hive来做数据分析。但Hive使用MapReduce作为底层计算框架,是专为批处理设计的。但随着数据越来越…

图谱实战 | 徐美兰:深度应用驱动的医学知识图谱构建

转载公众号 | DataFunSummit分享嘉宾:徐美兰 浙江数字医疗卫生技术研究院 数字医学知识中心主任编辑整理:李杰 京东出品平台:DataFunTalk导读:数研院这些年在知识图谱建设上取得了丰硕成果,今天我们将图谱构建过程中的…

6 年大厂面试官,谈谈我对算法岗面试的一些看法

文 | 不敢透露姓名的 Severus 和小轶面试官坐在那撇着大嘴的,“咳,给你一机会,最短的时间内让我记住你。”这个我会,我抡圆了“啪!”,扭头我就走。我刚到家,录取通知书就来了,请你务…

美团Android自动化之旅—生成渠道包

每当发新版本时,美团团购Android客户端会被分发到各个应用市场,比如豌豆荚,360手机助手等。为了统计这些市场的效果(活跃数,下单数等),需要有一种方法来唯一标识它们。 团购客户端目前通过渠道号…

开源开放 | 细粒度可循证医学文档知识融合表示和推理(CCKS2021)

OpenKG地址:http://openkg.cn/dataset/mdo-dataset开放许可协议:GPL 3.0贡献者:武汉科技大学(高峰、龚珊珊、顾进广、徐芳芳)摘要本开放资源在医学文档知识的基础上,使用知识图谱相关技术,解决了…

图灵奖大佬 Lecun 发表对比学习新作,比 SimCLR 更好用!

文 | Rukawa_Y编 | 智商掉了一地,Sheryc_王苏比 SimCLR 更好用的 Self-Supervised Learning,一起来看看吧!Self-Supervised Learning作为深度学习中的独孤九剑,当融汇贯通灵活应用之后,也能打败声名在外的武当太极剑。…

5whys分析法在美团工程师中的实践

前言 网站的质量和稳定性对于用户和公司来说至关重要,但是在网站的快速发展过程中,由于各种原因导致事故不可避免的发生,这些大大小小的事故对公司难免会造成一些负面的影响,为了避免同类事故的再次发生,美团的工程师们…

LeetCode 382. 链表随机节点(概率)

1. 题目 给定一个单链表,随机选择链表的一个节点,并返回相应的节点值。保证每个节点被选的概率一样。 进阶: 如果链表十分大且长度未知,如何解决这个问题?你能否使用常数级空间复杂度实现? 来源:力扣&am…

图谱实战 | 斯坦福黄柯鑫:图机器学习在生物图上的应用

转载公众号 | DataFunSummit分享嘉宾:黄柯鑫 斯坦福大学 博士生编辑整理:元玉蒲 西北大学出品平台:DataFunTalk导读:大家好,我叫黄柯鑫。我现在是斯坦福大学的计算机科学博士第一年级,研究方向是机器学习在…

排得更好VS估得更准VS搜的更全「推荐、广告、搜索」算法间到底有什么区别?...

文 | 王喆源 | 王喆的机器学习笔记作为互联网的核心应用“搜广推”,三个方向基本都是互联网公司的标配。各头部公司的搜广推系统也都各自发展成了集成了多种模型、算法、策略的庞然大物,想一口气讲清楚三者的区别并不容易。不过万事总有一个头绪&#xf…

Solr Facet技术的应用与研究

问题背景 在《搜索引擎关键字智能提示的一种实现》一文中介绍过,美团的CRM系统负责管理销售人员的门店(POI)和项目(DEAL)信息,提供统一的检索功能,其索引层采用的是SolrCloud。在用户搜索时,如果能直观地给出每个品类的POI数目&am…

LeetCode 129. 求根到叶子节点数字之和(DFS)

1. 题目 给定一个二叉树,它的每个结点都存放一个 0-9 的数字,每条从根到叶子节点的路径都代表一个数字。 例如,从根到叶子节点路径 1->2->3 代表数字 123。 计算从根到叶子节点生成的所有数字之和。 说明: 叶子节点是指没有子节点的…

推荐精排之锋:FM的一小步,泛化的一大步

文 | 水哥源 | 知乎1.如果说LR是复读机,那么FM可以算作是电子词典2.泛化就是我没见过你,我也能懂你,但是泛化有时候和个性化有点矛盾,属于此消彼长的关系3.实践中的泛化往往来源于拆解,没见过组成的产品,但…