不到1000行代码,PyTorch团队让Llama 7B提速10倍

在过去的一年里,生成式 AI 发展迅猛,在这当中,文本生成一直是一个特别受欢迎的领域,很多开源项目如 llama.cpp、vLLM 、 MLC-LLM 等,为了取得更好的效果,都在进行不停的优化。

作为机器学习社区中最受欢迎框架之一的 PyTorch,自然也是抓住了这一新的机遇,不断优化。为此让大家更好的了解这些创新,PyTorch 团队专门设置了系列博客,重点介绍如何使用纯原生 PyTorch 加速生成式 AI 模型。

图片

代码地址:https://github.com/pytorch-labs/gpt-fast

在第一篇博客中,PyTorch 团队展示了仅使用纯原生 PyTorch 重写 Segment Anything(SAM)模型,比原始实现快 8 倍。在本博客中,他们又为我们带来了新的内容,即如何加快 LLM 推理。

我们先来看看结果,该团队重写 LLM,推理速度比基线足足快了 10 倍,并且没有损失准确率,只用了不到 1000 行的纯原生 PyTorch 代码!

图片

所有基准测试都在 A100-80GB 上运行的,功率限制在 330W。

这些优化包括:

  • Torch.compile:PyTorch 模型编译器, PyTorch 2.0 加入了一个新的函数,叫做 torch.compile (),能够通过一行代码对已有的模型进行加速;

  • GPU 量化:通过降低运算精度来加速模型;

  • Speculative Decoding:一种大模型推理加速方法,使用一个小的「draft」模型来预测大的「目标」模型的输出;

  • 张量并行:通过在多个设备上运行模型来加速模型推理。

接下来,我们看看每一步都是如何实现的。

6 步加快大模型推理

该研究表示,在没有优化之前,大模型的推理性能为 25.5 tok/s,效果不是很好:

图片

经过一番探索后终于找到了原因:CPU 开销过大。然后就有了下面的 6 步优化过程。

图片

第一步:通过 Torch.compile 和静态 KV 缓存减少 CPU 开销,实现 107.0 TOK/S

torch.compile 允许用户将更大的区域捕获到单个编译区域中,特别是在 mode=”reduce-overhead” 时(参考下面的代码),这一功能对于减少 CPU 开销非常有效,除此以外,本文还指定 fullgraph=True,用来验证模型中没有「图形中断」(即 torch.compile 无法编译的部分)。

图片

然而,即使有 torch.compile 的加持,还是会遇到一些障碍。

第一个障碍是 kv 缓存。即当用户生成更多的 token 时, kv 缓存的「逻辑长度(logical length)」会增长。出现这种问题有两个原因:一是每次缓存增长时重新分配(和复制)kv 缓存的成本非常高;其次,这种动态分配使得减少开销变得更加困难。

为了解决这个问题,本文使用静态 KV 缓存,静态分配 KV 缓存的大小,然后屏蔽掉注意力机制中未使用的值。

图片

第二个障碍是 prefill 阶段。用 Transformer 进行文本生成可视为一个两阶段过程:1. 用来处理整个提示的 prefill 阶段 2. 解码 token.

尽管 kv 缓存被设置为静态化,但由于提示长度可变 ,prefill 阶段仍然需要更多的动态性。因此,需要使用单独的编译策略来编译这两个阶段。

图片

虽然这些细节有点棘手,但实现起来并不困难,而且性能的提升是巨大的。这一通操作下来,性能提高了 4 倍多,从 25 tok/s 提高到 107 tok/s。

图片

第二步:通过 int8 权重量化缓解内存带宽瓶颈,实现 157.4 tok /s

通过上文,我们已经看到应用 torch.compile 、静态 kv 缓存等带来的巨大加速,但 PyTorch 团队并不满足于此,他们又找了其他角度进行优化。

他们认为加速生成式 AI 训练的最大瓶颈是将权重从 GPU 全局内存加载到寄存器的代价。换句话说,每次前向传播都需要「接触(touch)」GPU 上的每个参数。那么,理论上我们能够以多快的速度「接触」模型中的每个参数?

图片

为了衡量这一点,本文使用模型带宽利用率(MBU),计算它非常简单,如下所示:

图片

举例来说,对于一个 7B 参数模型,每个参数都存储在 fp16 中(每个参数 2 字节),可以实现 107 tokens/s。A100-80GB 理论上有 2 TB/s 的内存带宽。

如下图所示,将上述公式带入具体的数值,可以得到 MBU 为 72%!这个结果是相当不错的,因为很多研究很难突破 85%。

图片

但 PyTorch 团队还想将这个数值在提高一些。他们发现无法改变模型中参数的数量,也无法改变 GPU 的内存带宽。但他们发现可以更改每个参数存储的字节数!

图片

因此,他们打算用 int8 量化。 

图片

请注意,这仅是量化权重,计算本身仍然在 bf16 中完成。此外,有了 torch.compile,可以轻松生成 int8 量化的高效代码。

图片

图片

就像上图所展示的,从深蓝色线(torch.compile + int8)可以看出,使用 torch.compile + int8 仅权重量化时,性能有显着提升。

将 int8 量化应用于 Llama-7B 模型,性能提高了约 50%,达到 157.4 tokens/s。

图片

第三步:使用 Speculative Decoding

即使在使用了 int8 量化等技术之后,该团队仍然面临着另一个问题,即为了生成 100 个 token,必须加载权重 100 次。

图片

即使权重被量化,一遍又一遍地加载权重也避免不了,这种问题该如何解决呢?事实证明,利用 speculative decoding 能够打破这种严格的串行依赖性并获得加速。

图片

该研究使用草稿(draft)模型生成 8 个 token,然后使用验证器模型并行处理,丢弃不匹配的 token。这一过程打破了串行依赖。整个实现过程大约 50 行原生 PyTorch 代码。

图片

第四步:使用 int4 量化和 GPTQ 方法进一步减小权重,实现 202.1 tok/s

本文发现,当权重为 4-bits 时,模型的准确率开始下降。

图片

为了解决这个问题,本文使用两个技巧来解决:第一个是拥有更细粒度的缩放因子;另一种是使用更先进的量化策略。将这些操作组合在一起,得到如下:

图片

第五步:将所有内容组合在一起,得到 244.7 tok/s

最后,将所有技术组合在一起以获得更好的性能,得到 244.7 tok/s。

图片

第六步:张量并行性

到目前为止,本文一直是在单个 GPU 上最大限度地减少延迟。其实,使用多个 GPU 也是可以的,这样一来,延迟现象会得到进一步改善。

非常庆幸的是,PyTorch 团队提供了张量并行的低级工具,只需 150 行代码,并且不需要任何模型更改。

图片

前面提到的所有优化都可以继续与张量并行性组合,将这些组合在一起,能以 55 tokens/s 的速度为 Llama-70B 模型提供 int8 量化。

图片

最后,简单总结一下文章主要内容。在 Llama-7B 上,本文使用「compile + int4 quant + speculative decoding」这一套组合拳,实现 240+ tok/s。在 Llama-70B,本文还通过引入张量并行性以达到约 80 tok/s,这些都接近或超过 SOTA 性能。

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/199925.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试就是这么简单,offer拿到手软(四)—— 常见java152道基础面试题

面试就是这么简单,offer拿到手软(一)—— 常见非技术问题回答思路 面试就是这么简单,offer拿到手软(二)—— 常见65道非技术面试问题 面试就是这么简单,offer拿到手软(三&#xff…

WIN10下解决HIVE 初始化MYSQL表报错:Unknown version specified for initialization

今天本地WINDOWS装HIVE,走到最后一步初始化数据库死活不通过: D:\hive\hive-rel-release-3.1.3\bin\ext>hive --service schematool -dbType mysql -initSchema --verbose SLF4J: Class path contains multiple SLF4J bindings. SLF4J: Found bind…

flask 上传文件

from flask import Flask, request, render_template,redirect, url_for from werkzeug.utils import secure_filename import os from flask import send_from_directory # send_from_directory可以从目录加载文件app Flask(__name__)#UPLOAD_FOLDER media # 注意&#xff…

深入理解强化学习——马尔可夫决策过程:占用度量-[代码实现]

分类目录:《深入理解强化学习》总目录 在文章《深入理解强化学习——马尔可夫决策过程:占用度量-[基础知识]》我们介绍了占用度量的基础知识,本文我们编写代码来近似估计占用度量。这里我们采用近似估计,即设置一个较大的采样轨迹…

会声会影2024购买多少钱 会声会影在哪里购买

掌握视频编辑技术,能为我们的工作和生活带来很多帮助。例如:将我们精心编辑的视频,上传到抖音、快手等平台进行变现;通过天马行空的视频创意,摇身一变成为B站up主。因此,拥有一款像会声会影这样的视频编辑软…

信号可靠性剖析

问题 基于信号发送的进程间通信方式可靠吗??? 信号查看(kill -l) 信号的分类 不可靠信号 (传统信号) 信号值在 [1, 31] 之间的所有信号 可靠信号 (实时信号) 信号值在 [SIGRTMIN,SIGRTMAX],即:[34&…

计算机组成原理学习-总线总结

复习本章时,思考以下问题: 1)引入总线结构有什么好处?2)引入总线结构会导致什么问题?如何解决?

Squid安装与配置(ip代理)

继前面一篇Tinyproxy安装与配置(ip代理),在实际使用中会发现在请求一些网站时会被拒绝,那是因为Tinyproxy其实不支持所谓的高匿代理。所以这次用功能更加丰富的squid试试。 1、安装 yum install squid -y yum install httpd-tools -y2、配置 1、备份原…

变电站设计综合应用软件

产品概述 变电站设计综合应用软件,以下称为软件,是一款面向智能变电站虚拟二次回路设计和光纤回路设计的单机版桌面应用软件。软件为用户提供了直观易用、一键安装、功能齐全的轻量级的设计支撑。像常规的工具化软件一样,该软件在开始设计时需要通过新建一个项目,开启一段…

Oracle初始化参数文件pfile和spfile

pfile :Oracle 9i之前,ORACLE一直采用PFILE方式存储初始化参数,该文件为文本文件,可以在操作系统级别修改。当spfile文件修改出现错误导致oracle无法启动时,可以使用 pfile文件启动数据库 spfile:从Oracle…

Python 元组详解(tuple)

文章目录 1 概述1.1 性质1.2 下标1.3 切片 2 常用方法2.1 访问:迭代、根据下标2.2 删除:del2.3 运算符:、*2.4 计算元组中元素个数:len()2.5 返回元组中元素最大值:max()2.6 返回元组中元素最小值:min()2.7…

Linux快速给用户改密码

由于巡检过程中需要修改部分用户名密码,这些强密码包含大小写、数字和特殊符号,完全没有规律,让我手动输是不可能的,于是使用以下命令来输入,但是为了不在history里面留下痕迹,所以先关闭了历史命令功能&am…

【C++】const关键字的详解!!

💐 🌸 🌷 🍀 🌹 🌻 🌺 🍁 🍃 🍂 🌿 🍄🍝 🍛 🍤 📃个人主页 :阿然成长日记 …

Android MVVM+coroutine+retrofit+flow+hilt

文章目录 Android MVVMcoroutineretrofitflowhilt概述依赖注入层数据层视图层模型视图层代码下载 Android MVVMcoroutineretrofitflowhilt 概述 代码结构: 依赖注入层 数据库: Module InstallIn(SingletonComponent::class) class DBModule {Singleto…

数据结构与算法-D2D3线性表之顺序表

线性表:包含若干数据元素的一个线性序列,特征如下: 1)对非空表,a0是表头,无前驱; 2)an-1是表尾,无后继; 3)其他元素仅且仅有一个前驱,…

【PTA题目】6-1 猴子吃桃-递归 分数 10

6-1 猴子吃桃-递归 分数 10 全屏浏览题目 切换布局 作者 ZZULI 单位 郑州轻工业大学 小猴子第一天摘下桃子若干,当即吃掉一半,还不过瘾,又多吃一个,第二天又将剩下的桃子吃掉一半多一个,以后每天吃掉前一天剩下的一…

完美洗牌问题学习笔记

问题背景 完美洗牌:一副 52 张的排序好的扑克牌,从中间分为两半,每部分各 26 张。假设每次都分为左右两部分,然后将 右部分的牌和左部分的牌按顺序交错穿插,每张左部分牌后面加入一张右部分的,依序加入所有…

基于YOLOv8深度学习的120种犬类检测与识别系统【python源码+Pyqt5界面+数据集+训练代码】目标检测、深度学习实战、狗类检测、犬种识别

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…

【“C++ 精妙之道:解锁模板奇谭与STL精粹之门“】

【本节目标】 1. 泛型编程 2. 函数模板 3. 类模板 4. 什么是STL 5. STL的版本 6. STL的六大组件 7. STL的重要性 8. 如何学习STL 9.STL的缺陷 1. 泛型编程 如何实现一个通用的交换函数呢? void Swap(int& left, int& right) {int temp left;lef…

odoo自定义提示性校验

背景: 在odoo16的原生的代码里,可以给按钮添加一个 confirm属性,从而达到 提示性校验的效果。 问题: 这个属性加了之后一定会弹出提示性校验的对话框,于是如何根据我们的实际业务,从后端返回提示性信息,…