transformer系列5---transformer显存占用分析

Transformer显存占用分析

  • 1 影响因素概述
  • 2 前向计算临时Tensor显存占用
    • 2.1 self-attention显存占用
    • 2.2 MLP显存占用
  • 3 梯度和优化器显存占用
    • 3.1 模型训练过程两者显存占用
    • 3.2 模型推理过程两者显存占用

1 影响因素概述

  1. 模型训练框架:例如pytorch框架的cuda context会占用大约几百MB显存,与版本有关;
  2. 模型参数大小,比如7B的模型以FP16格式要占用14GB显存;
  3. 前向计算过程中产生的临时Tensor:这部分Tensor需要被临时保存,以便在反向传播计算梯度时使用
  4. 反向传播计算得到的梯度:
  5. 优化器状态:全量微调的情况下,梯度与参数一样大,普通SGD没有动量,一阶动量优化器的自身参数大小与模型大小一样,比如momentum-SGD,二阶动量优化器一般为模型大小的两倍,比如Adam, transformer系列的大模型最常用的是Adam优化器

2 前向计算临时Tensor显存占用

2.1 self-attention显存占用

这部分Tensor的大小和模型的每一层结构形状有关(必须根据具体模型的每层形状来计算)也和具体的batch_size大小以及输入数据input_data的大小有关。

  1. 输入矩阵I:首先计算 Q = I ∗ W q Q =I * W^{q} Q=IWq K = I ∗ W k K = I * W^{k} K=IWk V = I ∗ W v V = I * W^{v} V=IWv,输入I是临时Tensor,假设输入I的形状为 [b, s, d],元素个数为 bsd,占用显存大小为2bytes*bsd=2bsd bytes.
  2. Q K T QK^{T} QKT:Q和K是临时Tensor,假设形状为 [b, s, d],元素个数为 bsd,占用显存大小为22bytesbsd=4bsd bytes。
  3. softmax: A = Q K T A=QK^{T} A=QKT,输入形状[b, h, s, d] × [b, h, s, d],A矩阵输出形状为 [b, h, s, s],h是头个数。保存A矩阵占用的显存大小为=2bytes* b h s 2 bhs^{2} bhs2= 2 b h s 2 2bhs^{2} 2bhs2 bytes。
  4. dropout:需要保存一个mask矩阵,mask矩阵的形状与A相同,mask矩阵的元素为0或1,用1个byte表示,占用显存大小为 b h s 2 bhs^{2} bhs2 bytes。
  5. score* V加权:score矩阵的形状与A相同,占用显存大小为 2 b h s 2 2bhs^{2} 2bhs2 bytes。V矩阵形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。该步骤占用显存大小为 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd bytes。
  6. W O W^{O} WO输出映射:需要临时保存输入矩阵,形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。
  7. dropout:需要保存一个mask矩阵,mask矩阵的形状为上一步输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。
    综上步骤,self-attention块的占用显存大小为2bsd+4bsd+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd+2bsd+2bsd=11bsd+ 5 b h s 2 5bhs^{2} 5bhs2

2.2 MLP显存占用

  1. 第一个线性层需要保存其输入,输入形状为[b, s, d],占用显存大小为 2bytes*bsd=2bsd bytes。
  2. 激活函数需要保存其输入,为第一步的输出形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
  3. 第二个线性层需要保存其输入,输入形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
  4. 最后有一个dropout操作,需要保存mask矩阵,形状是上一步的输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。

综上步骤,MLP的占用显存大小为2bsd+8bsd+8bsd+bsd=19bsd.

3 梯度和优化器显存占用

3.1 模型训练过程两者显存占用

参数占用显存 = 参数数目 × n
n = 2 : float16
n = 4 : float32
n = 8 : double64
其中,float32是最常用的类型,n是数据类型占用的bytes。
训练过程通常为模型参数前向传播,反向传播计算梯度,优化器更新,以Adam优化器为例分析,假如模型参数量为P:

  1. 混合精度训练:
    1)使用float16的模型参数进行前向传递和反向传播,计算得到float16的梯度;
    2)在优化器更新模型参数时,使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。
    3)对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是2bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是2bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。

每个参数占用(2+4)+(2+4)+8 = 20bytes。模型参数量M时总计20P bytes。

  1. 普通训练:
    上述步骤1)2)均使用float32类型。对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是4bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是4bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。

每个参数占用(4+4)+(4+4)+8 = 24bytes,模型参数量M时总计24P bytes。

3.2 模型推理过程两者显存占用

推理占用显存主要是模型参数,假如模型参数量为P,使用float16来进行推理,推理阶段模型参数占用的显存约2P bytes,使用float32来进行推理,推理阶段模型参数占用的显存约 4P bytes。

参考文章:https://zhuanlan.zhihu.com/p/624740065?utm_id=0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/98613.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入了解归并排序:原理、性能分析与 Java 实现

归并排序(Merge Sort)是一种高效且稳定的排序算法,其优雅的分治策略使它成为排序领域的一颗明珠。它的核心思想是将一个未排序的数组分割成两个子数组,然后递归地对子数组进行排序,最后将这些排好序的子数组合并起来。…

一个开源的安卓相机:OpenCamera

原网址 Open Camera download | SourceForge.net 我也上传了一个 https://github.com/quantum6/Android-OpenCamera

TensorFlow入门(十四、数据读取机制(1))

TensorFlow的数据读取方式 TensorFlow的数据读取方式共有三种,分别是: ①预加载数据(Preloaded data) 预加载数据的方式,其实就是静态图(Graph)的模式。即将数据直接内嵌到Graph中,再把Graph传入Session中运行。 示例代码如下: import tensorflow.compat.v1 as tf tf.disabl…

符合 EN55022B 规格、LTM4613EY、LTM4613MPV直流µModule稳压器【RG500Q 5G Sub-6 GHz 模块】

一、LTM4613,符合 EN55022B 规格的 36VIN、15VOUT、8A、DC/DC Module 稳压器 (简介)LTM4613 是一款完整、超低噪声、8A 开关模式 DC/DC 电源。封装中内置了开关控制器、功率 FET、电感器和所有的支持元件。LTM4613 的工作输入电压范围为 5V 至…

基于maven的项目搭建(已跑通)

1、直接选择archetype-webapp即可 (这里很多人会觉得很慢–解决方案:https://blog.csdn.net/qq_45591895/article/details/133705674?spm1001.2014.3001.5501) 2、手动添加一个java目录即可。 3、添加Tomcat 3、这就跑通了,可以…

Python 樱花

Python实现樱花 效果图 (源码在下面) 源码: from turtle import * from random import * from math import *def tree(n, l):pd() # 下笔# 阴影效果t cos(radians(heading() 45)) / 8 0.25pencolor(t, t, t)pensize(n / 3)forward(l…

dp=[[0]*n]*m 和dp = [[0] * n for _ in range(m)]的区别是什么?

在力扣刷题需要二维数组时,我选的方法是dp[[0]*n]*m,结果想修改一行的值,全二维数组都被更改了,不懂,查chat: 1.dp[[0]*n]*m: 这种方式创建了一个包含 m 行和 n 列的二维列表 dp,并将…

基于 FPGA 的机器博弈五子棋游戏

基于 FPGA 的机器博弈五子棋游戏 一,设计目的 五子棋是一种深受大众喜爱的游戏,其规则简单,变化多端,非常富有趣味性 和消遣性。棋类游戏在具备娱乐性、益智性的同时也因为其载体大多是手机, 电脑等移动互联网设备导致现代社会低头族等现象更加严重,危害青少年的身 体健康…

ThreeJS-3D教学五-材质

我们在ThreeJS-3D教学二&#xff1a;基础形状展示中有简单介绍过一些常用的材质&#xff0c;这次我们举例来具体看下效果&#xff1a; 代码是这样的&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8">&…

记一次生产环境cdh6.3.2集群yarn组件nodemanager节点down掉的事故分析

有关2023.10.2日发现的yarn部分nodeManager组件节点不可用的原因分析 yarn组件异常情况始于2023.09.30日06时00分&#xff0c;恢复于2023.10.02日10点35分。每日凌晨6点&#xff0c;大数据定时任务&#xff1a;task1启动&#xff0c;该任务持续时长1小时20~25分钟左右&#xf…

计算机的进制转换

复习一下二进制的理论知识。 计算机为什么要用二进制表示一切数据&#xff1f; 因为2个数可以表示一切&#xff0c;而且电极高低2种对硬件人员制作来说比较友好。 于是0就表示为空&#xff0c;1就表示有东西。 十进制转换为二进制 二进制是由0和1组成的&#xff0c;如01,000…

java多线程卖电影票的三种实现方式

java多线程卖电影票的三种实现方式 一、需求描述二、实现方式1、继承Thread类的方式2、实现Runnable接口的方式3、使用Lock锁的方式 一、需求描述 某电影院目前正在上映国产大片&#xff0c;共有1000张票&#xff0c;而它有2个窗口卖票&#xff0c;请设计一个程序模拟该电影院…

【ARM CoreLink 系列 1 -- SoC 片上互联介绍】

文章目录 概述1.1 片上互连架构的发展1.1.1 BUS 共享总线结构1.1.2 Crossbar 结构1.1.3 Ring 结构1.1.4 Mesh 网格结构 1.2 ARM 总线互联特点小结1.2.1 NOC 总线互联的特点 下篇文章&#xff1a;【ARM CoreLink 系列 1.1 – CoreLink 系列 产品介绍】 概述 在摩尔定律的推动下…

【Linux基础】Linux的基本指令使用(超详细解析,小白必看系列)

&#x1f449;系列专栏&#xff1a;【Linux基础】 &#x1f648;个人主页&#xff1a;sunnyll 目录 &#x1f4a6; ls 指令 &#x1f4a6; pwd指令 &#x1f4a6;cd指令 &#x1f4a6;touch指令 &#x1f4a6;mkdir指令&#xff08;重要&#xff09; &#x1f4a6;rmdir指令…

智慧工地:数字革命下的建筑业新趋势

在当今建筑领域&#xff0c;智慧工地正迅速崭露头角。这个概念不仅代表了技术进步&#xff0c;还预示着建筑行业的数字化和智能化未来。从多个角度来看&#xff0c;智慧工地都具有深远的意义&#xff0c;它正在改变着我们建筑的方式和未来。 提高工程效率 智慧工地利用物联网&…

c 语言基础题目:PTA L1-030 一帮一

“一帮一学习小组”是中小学中常见的学习组织方式&#xff0c;老师把学习成绩靠前的学生跟学习成绩靠后的学生排在一组。本题就请你编写程序帮助老师自动完成这个分配工作&#xff0c;即在得到全班学生的排名后&#xff0c;在当前尚未分组的学生中&#xff0c;将名次最靠前的学…

正点原子嵌入式linux驱动开发——Linux内核启动流程

上一篇笔记学习了Linux内核的顶层Makefile&#xff0c;现在来看Linux内核的大致启动流程&#xff0c;Linux内核的启 动流程要比uboot复杂的多&#xff0c;涉及到的内容也更多&#xff0c;因此本章就大致的了解一Linux内核的启动流程。 链接脚本vmlinux.lds 要分析Linux启动流…

区块链(10):java区块链项目的Web服务整体实现

根据上篇文章的HttpServer进行修改。 1 区块链的查询服务的web实现 public class BlocksServlet extends HttpServlet {@Overrideprotected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {resp.setCharacterEncoding(…

2023年10月9日历史上的今天大事件早读

1740年10月09日红溪惨案 1874年10月09日万国邮政联盟成立 1912年10月09日第一次巴尔干战争爆发 1913年10月09日武昌起义元勋蒋翊武被害 1924年10月09日近代翻译家林纾(林琴南)逝世 1934年10月09日南斯拉夫国王遇刺身亡 1936年10月09日红军三大主力会师 1941年10月09日第…

电子招标投标系统 —采购招投标管理一体化系统-

项目说明 随着公司的快速发展&#xff0c;企业人员和经营规模不断壮大&#xff0c;公司对内部招采管理的提升提出了更高的要求。在企业里建立一个公平、公开、公正的采购环境&#xff0c;最大限度控制采购成本至关重要。符合国家电子招投标法律法规及相关规范&#xff0c;以及审…