AI大模型的推理显存占用分析

了解Transformer架构的AI大模型显存占用是非常重要的,特别是在训练和推理过程中。以下是详细解释和分析这些组成部分及其影响的专业描述:

1 显存占用

1.1 模型本身参数

模型的参数包括所有的权重和偏置项,这些参数需要存储在显存中,以便在训练和推理过程中进行计算。

  • 占用字节:每个FP32参数占用4个字节,每个FP16参数占用2个字节。
  • 计算:模型参数数量(例如,BERT-base模型大约有110M参数)。如果使用FP32表示,则总显存占用为 110M * 4 bytes

1.2 模型的梯度动态值

在训练过程中,每个模型参数都有对应的梯度值,这些梯度用于更新模型参数。梯度存储同样需要显存。

  • 占用字节:梯度和模型参数类型相同,所以FP32梯度占用4个字节,FP16梯度占用2个字节。
  • 计算:梯度存储显存占用与模型参数相同,例如,如果模型参数使用FP32,则梯度显存占用为 参数数量 * 4 bytes

1.3 优化器参数

优化器(如Adam)在训练过程中需要存储额外的参数,如一阶动量和二阶动量。这些参数也需要显存来存储。

  • Adam优化器:存储m和v两个参数,即需要2倍的模型参数量。
  • 占用字节:每个FP32参数占用4个字节,每个FP16参数占用2个字节。
  • 计算:例如,使用Adam优化器和FP32表示,则优化器参数显存占用为 2 * 参数数量 * 4 bytes

1.4 模型的中间计算结果

在前向传播和反向传播过程中,需要存储每一层的中间计算结果,这些结果用于反向传播的求导。这些中间结果的显存占用与批量大小(batch size)、序列长度(sequence length)和每层的输出维度(hidden size)有关。

  • 前向传播:每一层的输入x和输出都需要存储。
  • 反向传播:中间结果的计算图不会被释放,以便计算梯度。
  • 占用字节:这部分的显存占用难以精确计算,但可以通过调整batch size和sequence length来估算显存差值。
  • 计算方法:常用的方法是实验性地调整batch size和sequence length,观察显存变化来估算中间结果的显存占用。

1.5 KV Cache

在推理过程中,尤其是在自回归模型(如GPT)中,需要缓存先前计算的键和值(Key和Value)以加速计算。这些缓存需要显存来存储。

  • 占用字节:这部分的显存占用与输入的序列长度、批量大小和注意力头数有关。
  • 计算方法:具体计算公式取决于模型的架构和缓存策略。

不同的参数类型所占的字节对比表

类型所占字节
FP324
FP162
INT81

2 具体示例

假设我们有一个Transformer模型,其架构和超参数如下:

层数(layers):12
隐藏层大小(hidden_size):768
注意力头数(num_heads):12
词汇表大小(vocab_size):30522
最大序列长度(sequence_length):512
批量大小(batch_size):1
数据类型:FP32(每个参数4字节)

为了具体计算一个具有上述参数的Transformer模型在推理时的显存占用,我们需要考虑以下几个部分:

  1. 模型本身的参数
  2. 输入和输出激活值
  3. 中间计算结果
  4. KV Cache

2.1 模型本身的参数

嵌入层
  • 词嵌入矩阵vocab_size * hidden_size
    [
    30522 \times 768 = 23440896 \text{ 个参数}
    ]
  • 位置嵌入矩阵sequence_length * hidden_size
    [
    512 \times 768 = 393216 \text{ 个参数}
    ]

嵌入层总参数:
[
23440896 + 393216 = 23834112 \text{ 个参数}
]

Transformer 层

每层的主要参数包括:

  • 注意力层的 Q, K, V 权重和偏置
    [
    3 \times (hidden_size \times hidden_size) = 3 \times (768 \times 768) = 1769472 \text{ 个参数}
    ]
  • 输出权重和偏置
    [
    hidden_size \times hidden_size = 768 \times 768 = 589824 \text{ 个参数}
    ]
  • 前馈网络(两层)
    [
    2 \times (hidden_size \times 4 \times hidden_size) = 2 \times (768 \times 4 \times 768) = 4718592 \text{ 个参数}
    ]

每层总参数:
[
1769472 + 589824 + 4718592 = 7077888 \text{ 个参数}
]

12层总参数:
[
12 \times 7077888 = 84934656 \text{ 个参数}
]

总参数数量

模型总参数数量:
[
23834112 + 84934656 = 108768768 \text{ 个参数}
]

每个FP32参数占用4个字节:
[
108768768 \times 4 = 435075072 \text{ 字节} = 435.08 \text{ MB}
]

2.2 输入和输出激活值

假设模型在推理时的输入和输出激活值为 batch_size * sequence_length * hidden_size,对于每个层的激活值也相同。

每层激活值:
[
batch_size \times sequence_length \times hidden_size = 1 \times 512 \times 768 = 393216 \text{ 个元素}
]

每个FP32激活值占用4个字节:
[
393216 \times 4 = 1572864 \text{ 字节} = 1.57 \text{ MB}
]

2.3 中间计算结果

由于反向传播不需要考虑推理时的显存占用,我们可以忽略这部分。

2.4 KV Cache

在推理过程中,需要缓存每一层的键和值(Key和Value):

每层的KV Cache占用:
[
2 \times batch_size \times sequence_length \times hidden_size = 2 \times 1 \times 512 \times 768 = 786432 \text{ 个元素}
]

每个FP32值占用4个字节:
[
786432 \times 4 = 3145728 \text{ 字节} = 3.14 \text{ MB}
]

12层的KV Cache总占用:
[
12 \times 3.14 \text{ MB} = 37.68 \text{ MB}
]

2.5 总显存占用

[
\text{模型参数} + \text{输入和输出激活值} + \text{KV Cache}
]

显存占用计算:

  • 模型参数:435.08 MB
  • 激活值:1.57 MB(每层)× 12层 = 18.84 MB
  • KV Cache:37.68 MB

[
\text{总显存占用} = 435.08 \text{ MB} + 18.84 \text{ MB} + 37.68 \text{ MB} = 491.60 \text{ MB}
]

在推理过程中,一个具有上述配置的Transformer模型大约需要491.60 MB的显存。这一估算没有包括额外的显存开销,例如模型加载时的一些临时数据结构和框架本身的开销。实际使用中,可能还需要一些额外的显存来处理这些开销。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/17884.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

四川景源畅信:新人做抖店的成本很高吗?

随着社交媒体的兴起,抖音成为了一个新兴的电商平台——抖店。不少创业者和商家看中了其庞大的用户基础,想要通过开设抖店来拓展销路。然而,对于刚入行的新手来说,成本问题总是让人犹豫不决。究竟新人做抖店的成本高不高?本文将围…

ML307R OpenCPU TCP使用

一、TCP通信流程 二、示例 三、TCP通信代码 一、TCP通信流程 ML307R TCP 是使用LWIP的标准的socket通信,具体TCP流程可以自行百度 二、示例 实验目的:实现把接收的数据再发送到服务端 测试网址:TCP电脑端测试网址 因为是4G,所以必须用外网的 /* 测试前请先补充如下…

Flutter 中的 CupertinoDatePicker 小部件:全面指南

Flutter 中的 CupertinoDatePicker 小部件:全面指南 在 Flutter 中,CupertinoDatePicker 是 Cupertino 组件库的一部分,它提供了一个 iOS 风格的日期选择器。这个选择器允许用户选择日期和时间,非常适合需要符合 iOS 设计指南的应…

YOLOv10:实时端到端目标检测

Ao Wang Hui Chen∗  Lihao Liu Kai Chen Zijia Lin  Jungong Han Guiguang Ding Tsinghua University Corresponding Author. 文献来源:中英文对照阅读 摘要 在过去的几年里,YOLO 因其在计算成本和检测性能之间的有效平衡而成为实时目标检测领…

纯干货:做好数据库防泄密的关键

在当今数字化时代,数据库的安全与保密性对于企业和个人来说至关重要。数据库防泄密工作涉及到多种技术和策略,其中沙盒技术作为一种强大的安全机制,为数据库防泄密提供了新的可能性。那么,我们是否可以通过沙盒来实现数据库防泄密…

2024年5月22日 (周三) 叶子游戏新闻

《奇星协力》Steam抢先体验开启 求生城市建造Leikir Studio工作室开发的一款求生城市建造新游《奇星协力》Steam抢先体验开启,限时九折优惠,本作支持中文,感兴趣的玩家可以关注下了。 《原神》预告4.7版本前瞻特别节目 5月24日播出5月22日&am…

Qt 控件提升

什么是控件提升(Widget Promotion) 控件提升是一个在Qt编程中常见但容易被忽视的概念。简单来说,控件提升就是将一个基础控件(Base Widget)转换为一个更特定、更复杂的自定义控件(Custom Widget)。这样做的目的是为了在设计界面时能够使用更多高级功能,而不仅仅是Qt库提…

基于FPGA实现LED的闪烁——HLS

基于FPGA实现LED的闪烁——HLS 引言: ​ 随着电子技术的飞速发展,硬件设计和开发的速度与效率成为了衡量一个项目成功与否的关键因素。在传统的硬件开发流程中,工程师通常需要使用VHDL或Verilog等硬件描述语言来编写底层的硬件逻辑&#xff0…

springboot517基于SpringBoot+Vue的高校线上心理咨询室的设计与实现-手把手调试搭建

springboot517基于SpringBootVue的高校线上心理咨询室的设计与实现-手把手调试搭建 springboot517基于SpringBootVue的高校线上心理咨询室的设计与实现-手把手调试搭建-2024-3-17

基于Python实现可视化分析中国500强排行榜数据的设计与实现

基于Python实现可视化分析中国500强排行榜数据的设计与实现 “Design and Implementation of Visual Analysis for China’s Top 500 Companies Ranking Data using Python” 完整下载链接:基于Python实现可视化分析中国500强排行榜数据的设计与实现 文章目录 基于Python实现…

Docker 基础使用 (1)

文章目录 Docker 软件安装Docker 镜像仓库Docker 仓库指令Docker 镜像指令Docker 容器指令Docker 使用实例 —— 搭建 nginx 服务nginx 概念nginx 使用用 docker 启动 nginx 侧重对docker基本使用的概览。 Docker 软件安装 Linux Ubuntu 依次执行以下指令即可 # 更新软件包列…

第十二周 5.20 面向对象的三大特性(封装、继承、多态)(一)

一、封装 1.目前的程序无法保证数据的安全性、容易造成业务数据的错误 2.private:私有的,被private修饰的内容只能在本类中访问 3.为私有化的属性提供公开的get和set方法 (1)get方法,获取私有化属性的值: public 返回值类型 get属性名…

[SWPUCTF 2022 新生赛]奇妙的MD5... ...

目录 [SWPUCTF 2022 新生赛]奇妙的MD5 [GDOUCTF 2023]受不了一点 [LitCTF 2023]作业管理系统 注入点一:文件上传 注入点二:创建文件直接写一句话木马 注入点三:获取数据库备份文件 [LitCTF 2023]1zjs [SWPUCTF 2022 新生赛]奇妙的MD5 …

生成式AI的GPU网络技术架构

生成式AI的GPU网络 引言:超大规模企业竞相部署拥有64K GPU的大型集群,以支撑各种生成式AI训练需求。尽管庞大Transformer模型与数据集需数千GPU,但实现GPU间任意非阻塞连接或显冗余。如何高效利用资源,成为业界关注焦点。 张量并…

单调栈--

1.每日温度 那么单调栈的原理是什么呢?为什么时间复杂度是O(n)就可以找到每一个元素的右边第一个比它大的元素位置呢? 单调栈的本质是空间换时间,因为在遍历的过程中需要用一个栈来记录右边第一个比当前元素高的元素,优点是整个数…

利用迭代方法求解线性方程组(Matlab)

一、问题描述 利用迭代方法求解线性方程组。 二、实验目的 掌握Jacobi 方法和Gauss-Seidel 方法的原理,能够编写代码实现两种迭代方法;能够利用代码分析线性方程组求解中的误差情况。 三、实验内容及要求 用代码实现:对下列方程中重新组织…

基于盲源分离和半盲源分离的心电信号伪影消除方法(MATLAB 2018)

心电信号是通过测量放置在人体皮肤上的电极之间的电位差来获取的,其本身具有信号微弱、频段低、不稳定等特性。因此ECG信号在实际采集时极易受到不同噪声的影响,这会造成心电图本身的波形形态特征的失真,从而导致错误诊断和对患者的不当治疗。…

2024年5月软考成绩什么时候出?附查询方式

2024年5月软考成绩查询时间及查询方式: 查询时间:预计在2024年7月上旬进行。 查询方式: 方式一:登陆中国计算机技术职业资格网(www.ruankao.org.cn),点击报名系统,输入注册账号和…

echart图表legend每列固定宽度

修改前: 修改后: 关键代码: 设置一个背景并使之透明,否则宽度不生效,配合formatter使用 formatter: {a|{name}},rich:{a: {width: 48,fontSize: 12,backgroundColor: "rgba(11, 39, 52, 0)" // 关键代码&a…

C++语法|thread_local详解

文章内容全部来自: 【C入门到进阶 多线程 thread_local 关键字】 【CPU眼里的:thread_local】 简介 thread_local 是一个关键字,它用来修饰变量,并被他修饰的变量有以下特征: 它指示对象拥有线程静态存储期 线程存储…