LLM各层参数详细分析(以LLaMA为例)

网上大多分析LLM参数的文章都比较粗粒度,对于LLM的精确部署不太友好,在这里记录一下分析LLM参数的过程。


首先看QKV。先上transformer原文
在这里插入图片描述
也就是说,当h(heads) = 1时,在默认情况下, W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV都是2维方阵,方阵维度是 d m o d e l × d m o d e l d_{model} \times d_{model} dmodel×dmodel.

结合llama源码 (https://github.com/facebookresearch/llama/blob/main/llama/model.py)

class ModelArgs:dim: int = 4096n_layers: int = 32n_heads: int = 32n_kv_heads: Optional[int] = Nonevocab_size: int = -1  # defined later by tokenizermultiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2ffn_dim_multiplier: Optional[float] = Nonenorm_eps: float = 1e-5max_batch_size: int = 32max_seq_len: int = 2048
# ...class Attention(nn.Module):"""Multi-head attention module."""def __init__(self, args: ModelArgs):"""Initialize the Attention module.Args:args (ModelArgs): Model configuration parameters.Attributes:n_kv_heads (int): Number of key and value heads.n_local_heads (int): Number of local query heads.n_local_kv_heads (int): Number of local key and value heads.n_rep (int): Number of repetitions for local heads.head_dim (int): Dimension size of each attention head.wq (ColumnParallelLinear): Linear transformation for queries.wk (ColumnParallelLinear): Linear transformation for keys.wv (ColumnParallelLinear): Linear transformation for values.wo (RowParallelLinear): Linear transformation for output.cache_k (torch.Tensor): Cached keys for attention.cache_v (torch.Tensor): Cached values for attention."""super().__init__()self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_headsmodel_parallel_size = fs_init.get_model_parallel_world_size()self.n_local_heads = args.n_heads // model_parallel_sizeself.n_local_kv_heads = self.n_kv_heads // model_parallel_sizeself.n_rep = self.n_local_heads // self.n_local_kv_headsself.head_dim = args.dim // args.n_heads

计算出
self.n_kv_heads = h = 32
self.head_dim = 4096/32=128
所以 W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV 大小都为(4096, 128).(在未拆分前 W Q W^Q WQ W K W^K WK W V W^V WV都是 ( d i m , d i m ) = ( 4096 , 4096 ) (dim, dim) = (4096,4096) (dim,dim)=(4096,4096)大小)

Q , K , V Q,K,V Q,K,V的大小都是 ( n c t x , d i m ) = ( 2048 , 4096 ) (n_{ctx}, dim) = (2048,4096) (nctx,dim)=(2048,4096)在多头公式里。在self-attention里,其实他们都是同一个值:输入X),所以 Q × W i Q Q×W_i^Q Q×WiQ K × W i K K×W_i^K K×WiK Q × W i Q Q×W_i^Q Q×WiQ 都是 ( n c t x , d k ) = ( 2048 , 128 ) (n_{ctx}, d_k)=(2048,128) (nctx,dk)=(2048,128)。带入原文attention公式后,大小为(2048, 128)不变。Attention不改变大小(在默认 d k = d v d_k=d_v dk=dv情况下)。
在这里插入图片描述

经过Cancat,分开的头又合并,大小变为(2048, 4096)矩阵,经过 W O W^O WO大小是(4096,4096))全连接,还是(2048, 4096)矩阵。


然后看Feed forward.根据源码,

class FeedForward(nn.Module):def __init__(self,dim: int,hidden_dim: int,multiple_of: int,ffn_dim_multiplier: Optional[float],):"""Initialize the FeedForward module.Args:dim (int): Input dimension.hidden_dim (int): Hidden dimension of the feedforward layer.multiple_of (int): Value to ensure hidden dimension is a multiple of this value.ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.Attributes:w1 (ColumnParallelLinear): Linear transformation for the first layer.w2 (RowParallelLinear): Linear transformation for the second layer.w3 (ColumnParallelLinear): Linear transformation for the third layer."""super().__init__()hidden_dim = int(2 * hidden_dim / 3)# custom dim factor multiplierif ffn_dim_multiplier is not None:hidden_dim = int(ffn_dim_multiplier * hidden_dim)hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)def forward(self, x):return self.w2(F.silu(self.w1(x)) * self.w3(x))

multiattention layer过后,经过加法和normlayer(RMS norm),进入feed_forward前馈网络。注意这里的前馈网络其中一个维度会有8/3≈2.7的放缩,然后multiple_of又保证必须是256的倍数,所以这里算出来hidden_dim是256的倍数中与8/3*4096最接近的,是11008。以这里的w1,w3大小为(4096,11008),w2大小为(11008,4096). 输出结果大小

整个decode layer计算如图所示,
在这里插入图片描述

来源:https://github.com/microsoft/Llama-2-Onnx/blob/main/Images/DecoderLayer.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/87198.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JDK 8-Lambda】3.1 Java高级核心玩转 JDK8 Lambda 表达式

一、 什么是函数式编程 ? 二、 什么是lambda表达式? 1. 先看两个示例 A.【创建线程】 B.【数组排序-降序】 2. lambda表达式特性 A. 使用场景(前提): B. 语法 (params) -> expression C. 参数列表 D. 方法体 F. 好处 一、 什么是函数式编…

Wrapper可以构造的复杂查询条件汇总

目录 概述1、eq:等于条件2、ne:不等于条件3、gt:大于条件4、ge:大于等于条件5、lt:小于条件6、le:小于等于条件7、like:模糊查询条件8、notLike:不包含关键字的模糊查询条件9、左模糊…

实在智能携手40+央企,探索财务大模型及数智化实践与应用

“这次培训给我一个最大的感触就是,过去以为AI智能化、大模型技术是很高深的事情。但现在,我们通过RPA等数字化工具,自主根据自己的工作岗位,完成业务自动化流程的开发和设计。AI技术没有想象中的那么难入门。” 这是一位参加了“…

持续集成Jenkins安装部署

Jenkins是一个在DevOps领域中、支持CI/CD(持续集成/持续交付)过程域的开源项目,其提供可扩展插件的支持,以自动化的机制对项目工程执行打包、编译、构建、测试以及最终发布到目的地服务器并成功部署运行,本文主要描述J…

如何在批发零售业运用IPD?

批发零售业指购进商品后,再向其他批发或零售单位(含个体经营者)及其他企事业单位、机关团体等批量销售生活用品、生产资料的活动,以及从事进出口贸易和贸易经纪与代理的活动,包括拥有货物所有权,并以本单位…

左神高级提升班2 约瑟夫环结构

目录 【案例1】 【题目描述】 【输入描述:】 【输出描述:】 【输入】 【输出】 【思路解析】 【代码实现】 【案例1】 【题目描述】 某公司招聘,有n个人入围,HR在黑板上依次写下m个正整数A1、A2、……、Am,然后…

如何在JoySSL上申请免费的SSL证书

1,前往 JoySSL 的官方网站注册页面,创建一个账号并登录您的 JoySSL 账户。 扫码注册账号申请免费证书https://www.joyssl.com/certificate/select/free.html?nid52,找到并选择你需要的 SSL 证书相关的功能或选项。 3,提供您的域…

springboot实现发送邮箱验证码

准备工作 在邮箱官网开放SMTP授权,获取相应密钥,才可以进行发送邮件 这里以网易163邮箱为例,登录邮箱后,依次点击“设置-POP3/SMTP/IMAP” ,然后开启SMTP服务。这时候会提示一个授权码,例如:H…

全流程ARCGIS Pro技术应用教程

详情点击公众号链接:全流程ARCGIS Pro技术应用教程 前沿 GIS是利用电子计算机及其外部设备,采集、存储、分析和描述整个或部分地球表面与空间信息系统。简单地讲,它是在一定的地域内,将地理空间信息和 一些与该地域地理信息相关…

【Ambari】银河麒麟V10 ARM64架构_安装Ambari2.7.6HDP3.3.1(HiDataPlus)

🍁 博主 "开着拖拉机回家"带您 Go to New World.✨🍁 🦄 个人主页——🎐开着拖拉机回家_大数据运维-CSDN博客 🎐✨🍁 🪁🍁 希望本文能够给您带来一定的帮助🌸文…

Matlab中(:,1)和(:,end)和[~, A]的含义与用法

背景 阅读Moses Chong-ook Nah的DMP-MATLAB程序记录。 github链接:https://github.com/mosesnah-shared/DMP-MATLAB 如果不知道某个函数或变量的作用,直接打印出来,看看输出是什么。不知道matlab如何打印?程序后面的分号;去掉就…

Day1-DeepWalk

论文《DeepWalk: Online Learning of Social Representations》 2014年发表在数据挖掘顶会ACM SIGKDD(KDD)上的论文 目的:学习节点表示 推动:将自然语言处理里面的无监督学习方法迁移至此 思路:将图结构序列化&#x…

C#中实现校验是否包含中文与http接口地址中解析ip和端口号

场景 Winform/CSharp中实现对Http接口地址、IP地址字符串格式/合法性校验: Winform/CSharp中实现对Http接口地址、IP地址字符串格式/合法性校验_c# 检查ip格式_霸道流氓气质的博客-CSDN博客 在上面的基础上对某http接口地址(ip加端口号,示例http://12…

深入电机控制基础知识(1)- 磁共能与电磁转矩

1.1 概述 打开任意一本电机学的教材,翻到电机基本概念的说明的位置,总能看到一句描述电机本质的话:电机是一种机电能量转化的装置。 机电能量转化,很生动形象的说明电机的工作原理。对于电动机而言,吸收电能&#xff…

LuatOS-SOC接口文档(air780E)--errDump - 错误上报

示例 -- 基本用法, 10分钟上报一次,如果有的话 if errDump thenerrDump.config(true, 600) end-- 附开源服务器端: https://gitee.com/openLuat/luatos-devlogerrDump.dump(zbuff, type, isDelete) 手动读取异常日志,主要用于用户将日志发送给自己的服务器而不是I…

全网首发YOLOv5/YOLOv7暴力涨点:Gold-YOLO,遥遥领先,超越所有YOLO | 华为诺亚NeurIPS23

💡💡💡本文独家改进:提出了全新的信息聚集-分发(Gather-and-Distribute Mechanism)GD机制,Gold-YOLO,替换yolov5 head部分 实现暴力涨点 Gold-YOLO | 亲测在多个数据集能够实现大幅涨点 💡💡💡Yolov5/Yolov7魔术师,独家首发创新(原创),适用于Yolov5、…

优思学院|如何解读Minitab中测量系统分析(MSA GRR)的结果?

在现代制造和质量控制过程中,精确的测量是至关重要的。为了确保我们的测量工具可靠,我们需要评估其重复性与再现性。这就是测量系统分析(Measurement System Analysis,简称MSA)的关键目标之一。以下将介绍如何使用Mini…

Vue如何监听键盘事件

引言 在Web开发中,键盘事件是非常常见的交互方式之一。Vue作为一种流行的JavaScript框架,提供了一种简单而灵活的方式来监听键盘事件。本文将介绍如何在Vue中监听键盘事件,并展示一些实用的示例。 目录 Vue中监听键盘事件的基本用法监听特定…

前端进阶--深入理解JavaScript

1、JS的作用域和作用域链 作用域链的作用是保证对执行环境有权访问的所有变量和函数的有序访问,通过作用域链,我们可以访问到外层环境的变量和函数。作用域链的本质上是一个指向变量对象的指针列表。变量对象是一个包含了执行环境中所有变量和函数的对象…

IDEA下使用Spring MVC

<?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://ma…