Transformer 相关模型的参数量计算

如何计算Transformer 相关模型的参数量呢?
先回忆一下Transformer模型论文《Attention is all your need》中的两个图。
在这里插入图片描述
在这里插入图片描述

设Transformer模型的层数为N,每个Transformer层主要由self-attention 和 Feed Forward组成。设self-attention模块的head个数为 n h e a d n_{head} nhead,每一个head对应的维度为 d h e a d d_{head} dhead,self-attention输出维度为 d m o d e l = n heads ⋅ d head d_{model}= n_\text{heads}\cdot d_\text{head} dmodel=nheadsdhead。我们可以得到一个Transformer层的参数量为 12 d m o d e l 2 + 13 d m o d e l 12 d_{model}^2 + 13 d_{model} 12dmodel2+13dmodel,具体如下:

  • self-attention块的模型参数有Q、K、V的权重矩阵 W Q 、 W K 、 W V W_Q、W_K 、W_V WQWKWV和偏置,输出矩阵 W O W_O WO及其偏置。这4个权重矩阵的大小为 [ d m o d e l , d m o d e l ] [d_{model}, d_{model}] [dmodel,dmodel],4个偏置的大小为 [ d m o d e l ] [d_{model}] [dmodel],所以self-attention块的参数量为 4 d m o d e l 2 + 4 d m o d e l 4 d_{model}^2 + 4 d_{model} 4dmodel2+4dmodel

  • Feed Forward块一般由2个线性层组成,第一个线性层将维度从 d m o d e l d_{model} dmodel 映射成 4 d m o d e l 4d_{model} 4dmodel, 其权重矩阵 W 1 W_1 W1的大小为 [ d m o d e l , 4 d m o d e l ] [d_{model}, 4d_{model}] [dmodel,4dmodel] ,其偏置的大小为 [ 4 d m o d e l ] [4d_{model}] [4dmodel]。 第二个线性层将维度从 4 d m o d e l 4d_{model} 4dmodel 映射成 d m o d e l d_{model} dmodel,其权重矩阵 W 2 W_2 W2的大小为 [ 4 d m o d e l , d m o d e l ] [4d_{model}, d_{model}] [4dmodel,dmodel] ,其偏置的大小为 [ d m o d e l ] [d_{model}] [dmodel]。所以Feed Forward的参数量为 8 d m o d e l 2 + 5 d m o d e l 8 d_{model}^2 + 5 d_{model} 8dmodel2+5dmodel

  • self-attention 和 Feed Forward都跟随着layer normalization,它有两个可训练模型参数,形状都是 [ d m o d e l ] [d_{model}] [dmodel]。所以2个layer normalization的参数量为 4 d m o d e l 4 d_{model} 4dmodel

除了Transformer层之外的参数有:

  • 词embedding矩阵的参数量,embedding的维度通常等于 d m o d e l d_{model} dmodel,设词表的大小为V,则词embedding的参数量为 V d m o d e l Vd_{model} Vdmodel
  • 位置向量相关,有些位置向量表示方式需要学习参数。

所以N层Transformer模型的可训练模型参数量为 N ( 12 d m o d e l 2 + 13 d m o d e l ) + V d m o d e l N(12 d_{model}^2 + 13 d_{model}) + Vd_{model} N(12dmodel2+13dmodel)+Vdmodel。当 d m o d e l d_{model} dmodel较大时,可以忽略一次项,模型参数量近似为 12 N d m o d e l 2 12 N d_{model}^2 12Ndmodel2

最后试验一下模型参数估计量与论文是否对的上,下表是GPT3和LLaMA的计算对比,可以发现数量级是可以对的上的,因为我们忽略了一次项,所以具体数据与论文不一致。

模型名实际参数量 n l a y e r n_{layer} nlayer d m o d e l d_{model} dmodel n h e a d n_{head} nhead d h e a d d_{head} dhead估计参数量
GPT-3175B961228896128173946175488
LLaMA 6.7B6.7B324096321286442450944
LLaMA 13.0B13.0B4051204012812582912000
LLaMA 32.5B32.5B6066565212831897681920
LLaMA 65.2B65.2B8081926412864424509440

参考资料

  1. Transformer 论文(模型图来自论文)、GPT3的论文等

  2. 整理过程中参考的blog: 1. 知乎用户回旋托马斯x 的文章,除了计算量外,还算了计算量、中间激活等 , 2 transformer 参数量计算, 3 flops 计算, 4 transformers 参数量计算公式

  3. transfomers 库如何得到参数量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/43645.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux系统部署jenkins详细教程

一、Linux环境 1、下载war包 官网下载地址: https://get.jenkins.io/war-stable/2.332.4/jenkins.war 2、将war包上传至服务器 创建目录/home/ubuntu/jenkins 上传war包至该目录 3、将jenkins添加到环境变量 进入环境变量文件 vim /etc/profile # 文件下方追加…

【3Ds Max】图形合并命令的简单使用

示例(将文字设置在球体上) 1. 首先这里创建一个球体和一个文本 2. 选中球体,在复合对象中点击图形合并按钮 点击“拾取图形”按钮,然后选中文本,此时可以看到球体上已经投射出文本 3. 接下来是一些常用参数的介绍 当…

从零实战SLAM-第八课(非特征点的视觉里程计)

在七月算法报的班,老师讲的蛮好。好记性不如烂笔头,关键内容还是记录一下吧,课程入口,感兴趣的同学可以学习一下。 --------------------------------------------------------------------------------------------------------…

centos下使用jemalloc解决Mysql内存泄漏问题

参考: MySQL bug:https://bugs.mysql.com/bug.php?id83047&tdsourcetags_pcqq_aiomsg https://github.com/jemalloc/jemalloc/blob/dev/INSTALL.md (1)ptmalloc 是glibc的内存分配管理 (2)tcmalloc…

Django

一 django 安装 1. **安装 Django:** 首先,确保您已经安装了 Python 和 pip(Python 包管理器)。然后,在命令行中运行以下命令来安装 Django: bashpip install Django 2. **创建项目:** …

Electron-builder打包和自动更新

前言 文本主要讲述如何为 electron 打包出来软件配置安装引导和结合 github 的 release 配置自动更新。 electron-builder 是将 Electron 工程打包成相应平台的软件的工具,我的工程是使用 electron-vite 构建的,其默认集成了 electron-builder &#x…

中大型无人机远程VHF语音电台系统方案

方案背景 中大型无人机在执行飞行任务时,特别是在管制空域飞行时地面航管人员需要通过语音与无人机通信。按《无人驾驶航空器飞行管理暂行条例》规定,中大型无人机应当进行适航管理。物流无人机和载人eVTOL都将进行适航管理,所以无人机也要有…

Unity 工具 之 Azure 微软SSML语音合成TTS流式获取音频数据的简单整理

Unity 工具 之 Azure 微软SSML语音合成TTS流式获取音频数据的简单整理 目录 Unity 工具 之 Azure 微软SSML语音合成TTS流式获取音频数据的简单整理 一、简单介绍 二、实现原理 三、实现步骤 四、关键代码 一、简单介绍 Unity 工具类,自己整理的一些游戏开发可…

Qt creator之对齐参考线——新增可视化缩进功能

Qt creator随着官方越来越重视,更新频率也在不断加快,今天无意中发现qt creator新版有了对齐参考线,也称可视化缩进Visualize Indent,默认为启用状态。 下图为旧版Qt Creator显示设置栏: 下图为新版本Qt Creator显示设…

Day14 01-Shell脚本编程详解

文章目录 第一章 Shell编程【重点】1.1. Shell的概念介绍1.1.1. 命令解释器4.1.1.2. Shell脚本 1.2. Shell编程规范1.2.1. 脚本文件的结构1.2.2. 脚本文件的执行 1.3. Shell的变量1.3.1. 变量的用法1.3.2. 变量的分类1.3.3. 局部变量1.3.4. 环境变量1.3.5. 位置参数变量1.3.6. …

Python入门【内存管理机制、Python缓存机制、垃圾回收机制、分代回收机制】(三十二)

👏作者简介:大家好,我是爱敲代码的小王,CSDN博客博主,Python小白 📕系列专栏:python入门到实战、Python爬虫开发、Python办公自动化、Python数据分析、Python前后端开发 📧如果文章知识点有错误…

LeetCode150道面试经典题-- 存在重复元素 II(简单)

1.题目 给你一个整数数组 nums 和一个整数 k &#xff0c;判断数组中是否存在两个 不同的索引 i 和 j &#xff0c;满足 nums[i] nums[j] 且 abs(i - j) < k 。如果存在&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 2.示例 示例 1&#xff1a; 输…

CSS中的字体属性有哪些值,并分别描述它们的作用。

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ font-style⭐ font-weight⭐ font-size⭐ font-family⭐ font-variant⭐ line-height⭐ letter-spacing⭐ word-spacing⭐ font⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专…

JS中对象数组深拷贝方法

structuredClone() JavaScript 中提供了一个原生 API 来执行对象的深拷贝&#xff1a;structuredClone。它可以通过结构化克隆算法创建一个给定值的深拷贝&#xff0c;并且还可以传输原始值的可转移对象。 当对象中存在循环引用时&#xff0c;仍然可以通过 structuredClone()…

【声波】声波在硼酸、硫酸镁 (MgSO4) 和纯水中的吸收研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

c++ | 字节转换 | 字长 | 机器位数

为什么有的时候脑子转不过来&#xff1f;&#xff1f; 为什么要对字节、机器长啊、位啊都要门清 位数 一般的就是指计算机的位数&#xff0c;比如64位/32位&#xff0c;更简单的理解&#xff0c;计算机就是在不停的做二进制的计算&#xff0c;比如32位计算机&#xff0c;在长…

[保研/考研机试] KY26 10进制 VS 2进制 清华大学复试上机题 C++实现

题目链接&#xff1a; 10进制 VS 2进制http://www.nowcoder.com/share/jump/437195121691738172415 描述 对于一个十进制数A&#xff0c;将A转换为二进制数&#xff0c;然后按位逆序排列&#xff0c;再转换为十进制数B&#xff0c;我们称B为A的二进制逆序数。 例如对于十进制…

4.物联网LWIP之C/S编程

LWIP配置 服务器端实现 客户端实现 错误分析 一。LWIP配置&#xff08;FREERTOS配置&#xff0c;ETH配置&#xff0c;LWIP配置&#xff09; 1.FREERTOS配置 为什么要修改定时源为Tim1&#xff1f;不用systick&#xff1f; 原因&#xff1a;HAL库与FREERTOS都需要使用systi…

C语言好题解析(三)

目录 选择题一选择题二选择题三选择题四编程题一编程题二 选择题一 以下程序段的输出结果是&#xff08;&#xff09;#include<stdio.h> int main() { char s[] "\\123456\123456\t"; printf("%d\n", strlen(s)); return 0; }A: 12 B: 13 …

Lnton羚通关于【PyTorch】教程:torchvision 目标检测微调

torchvision 目标检测微调 本教程将使用Penn-Fudan Database for Pedestrian Detection and Segmentation 微调 预训练的Mask R-CNN 模型。 它包含 170 张图片&#xff0c;345 个行人实例。 定义数据集 用于训练目标检测、实例分割和人物关键点检测的参考脚本允许轻松支持添加…