LLM各层参数详细分析(以LLaMA为例)

网上大多分析LLM参数的文章都比较粗粒度,对于LLM的精确部署不太友好,在这里记录一下分析LLM参数的过程。


首先看QKV。先上transformer原文
在这里插入图片描述
也就是说,当h(heads) = 1时,在默认情况下, W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV都是2维方阵,方阵维度是 d m o d e l × d m o d e l d_{model} \times d_{model} dmodel×dmodel.

结合llama源码 (https://github.com/facebookresearch/llama/blob/main/llama/model.py)

class ModelArgs:dim: int = 4096n_layers: int = 32n_heads: int = 32n_kv_heads: Optional[int] = Nonevocab_size: int = -1  # defined later by tokenizermultiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2ffn_dim_multiplier: Optional[float] = Nonenorm_eps: float = 1e-5max_batch_size: int = 32max_seq_len: int = 2048
# ...class Attention(nn.Module):"""Multi-head attention module."""def __init__(self, args: ModelArgs):"""Initialize the Attention module.Args:args (ModelArgs): Model configuration parameters.Attributes:n_kv_heads (int): Number of key and value heads.n_local_heads (int): Number of local query heads.n_local_kv_heads (int): Number of local key and value heads.n_rep (int): Number of repetitions for local heads.head_dim (int): Dimension size of each attention head.wq (ColumnParallelLinear): Linear transformation for queries.wk (ColumnParallelLinear): Linear transformation for keys.wv (ColumnParallelLinear): Linear transformation for values.wo (RowParallelLinear): Linear transformation for output.cache_k (torch.Tensor): Cached keys for attention.cache_v (torch.Tensor): Cached values for attention."""super().__init__()self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_headsmodel_parallel_size = fs_init.get_model_parallel_world_size()self.n_local_heads = args.n_heads // model_parallel_sizeself.n_local_kv_heads = self.n_kv_heads // model_parallel_sizeself.n_rep = self.n_local_heads // self.n_local_kv_headsself.head_dim = args.dim // args.n_heads

计算出
self.n_kv_heads = h = 32
self.head_dim = 4096/32=128
所以 W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV 大小都为(4096, 128).(在未拆分前 W Q W^Q WQ W K W^K WK W V W^V WV都是 ( d i m , d i m ) = ( 4096 , 4096 ) (dim, dim) = (4096,4096) (dim,dim)=(4096,4096)大小)

Q , K , V Q,K,V Q,K,V的大小都是 ( n c t x , d i m ) = ( 2048 , 4096 ) (n_{ctx}, dim) = (2048,4096) (nctx,dim)=(2048,4096)在多头公式里。在self-attention里,其实他们都是同一个值:输入X),所以 Q × W i Q Q×W_i^Q Q×WiQ K × W i K K×W_i^K K×WiK Q × W i Q Q×W_i^Q Q×WiQ 都是 ( n c t x , d k ) = ( 2048 , 128 ) (n_{ctx}, d_k)=(2048,128) (nctx,dk)=(2048,128)。带入原文attention公式后,大小为(2048, 128)不变。Attention不改变大小(在默认 d k = d v d_k=d_v dk=dv情况下)。
在这里插入图片描述

经过Cancat,分开的头又合并,大小变为(2048, 4096)矩阵,经过 W O W^O WO大小是(4096,4096))全连接,还是(2048, 4096)矩阵。


然后看Feed forward.根据源码,

class FeedForward(nn.Module):def __init__(self,dim: int,hidden_dim: int,multiple_of: int,ffn_dim_multiplier: Optional[float],):"""Initialize the FeedForward module.Args:dim (int): Input dimension.hidden_dim (int): Hidden dimension of the feedforward layer.multiple_of (int): Value to ensure hidden dimension is a multiple of this value.ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.Attributes:w1 (ColumnParallelLinear): Linear transformation for the first layer.w2 (RowParallelLinear): Linear transformation for the second layer.w3 (ColumnParallelLinear): Linear transformation for the third layer."""super().__init__()hidden_dim = int(2 * hidden_dim / 3)# custom dim factor multiplierif ffn_dim_multiplier is not None:hidden_dim = int(ffn_dim_multiplier * hidden_dim)hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)def forward(self, x):return self.w2(F.silu(self.w1(x)) * self.w3(x))

multiattention layer过后,经过加法和normlayer(RMS norm),进入feed_forward前馈网络。注意这里的前馈网络其中一个维度会有8/3≈2.7的放缩,然后multiple_of又保证必须是256的倍数,所以这里算出来hidden_dim是256的倍数中与8/3*4096最接近的,是11008。以这里的w1,w3大小为(4096,11008),w2大小为(11008,4096). 输出结果大小

整个decode layer计算如图所示,
在这里插入图片描述

来源:https://github.com/microsoft/Llama-2-Onnx/blob/main/Images/DecoderLayer.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/87198.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实在智能携手40+央企,探索财务大模型及数智化实践与应用

“这次培训给我一个最大的感触就是,过去以为AI智能化、大模型技术是很高深的事情。但现在,我们通过RPA等数字化工具,自主根据自己的工作岗位,完成业务自动化流程的开发和设计。AI技术没有想象中的那么难入门。” 这是一位参加了“…

持续集成Jenkins安装部署

Jenkins是一个在DevOps领域中、支持CI/CD(持续集成/持续交付)过程域的开源项目,其提供可扩展插件的支持,以自动化的机制对项目工程执行打包、编译、构建、测试以及最终发布到目的地服务器并成功部署运行,本文主要描述J…

如何在批发零售业运用IPD?

批发零售业指购进商品后,再向其他批发或零售单位(含个体经营者)及其他企事业单位、机关团体等批量销售生活用品、生产资料的活动,以及从事进出口贸易和贸易经纪与代理的活动,包括拥有货物所有权,并以本单位…

左神高级提升班2 约瑟夫环结构

目录 【案例1】 【题目描述】 【输入描述:】 【输出描述:】 【输入】 【输出】 【思路解析】 【代码实现】 【案例1】 【题目描述】 某公司招聘,有n个人入围,HR在黑板上依次写下m个正整数A1、A2、……、Am,然后…

如何在JoySSL上申请免费的SSL证书

1,前往 JoySSL 的官方网站注册页面,创建一个账号并登录您的 JoySSL 账户。 扫码注册账号申请免费证书https://www.joyssl.com/certificate/select/free.html?nid52,找到并选择你需要的 SSL 证书相关的功能或选项。 3,提供您的域…

springboot实现发送邮箱验证码

准备工作 在邮箱官网开放SMTP授权,获取相应密钥,才可以进行发送邮件 这里以网易163邮箱为例,登录邮箱后,依次点击“设置-POP3/SMTP/IMAP” ,然后开启SMTP服务。这时候会提示一个授权码,例如:H…

全流程ARCGIS Pro技术应用教程

详情点击公众号链接:全流程ARCGIS Pro技术应用教程 前沿 GIS是利用电子计算机及其外部设备,采集、存储、分析和描述整个或部分地球表面与空间信息系统。简单地讲,它是在一定的地域内,将地理空间信息和 一些与该地域地理信息相关…

【Ambari】银河麒麟V10 ARM64架构_安装Ambari2.7.6HDP3.3.1(HiDataPlus)

🍁 博主 "开着拖拉机回家"带您 Go to New World.✨🍁 🦄 个人主页——🎐开着拖拉机回家_大数据运维-CSDN博客 🎐✨🍁 🪁🍁 希望本文能够给您带来一定的帮助🌸文…

Matlab中(:,1)和(:,end)和[~, A]的含义与用法

背景 阅读Moses Chong-ook Nah的DMP-MATLAB程序记录。 github链接:https://github.com/mosesnah-shared/DMP-MATLAB 如果不知道某个函数或变量的作用,直接打印出来,看看输出是什么。不知道matlab如何打印?程序后面的分号;去掉就…

Day1-DeepWalk

论文《DeepWalk: Online Learning of Social Representations》 2014年发表在数据挖掘顶会ACM SIGKDD(KDD)上的论文 目的:学习节点表示 推动:将自然语言处理里面的无监督学习方法迁移至此 思路:将图结构序列化&#x…

深入电机控制基础知识(1)- 磁共能与电磁转矩

1.1 概述 打开任意一本电机学的教材,翻到电机基本概念的说明的位置,总能看到一句描述电机本质的话:电机是一种机电能量转化的装置。 机电能量转化,很生动形象的说明电机的工作原理。对于电动机而言,吸收电能&#xff…

优思学院|如何解读Minitab中测量系统分析(MSA GRR)的结果?

在现代制造和质量控制过程中,精确的测量是至关重要的。为了确保我们的测量工具可靠,我们需要评估其重复性与再现性。这就是测量系统分析(Measurement System Analysis,简称MSA)的关键目标之一。以下将介绍如何使用Mini…

前端进阶--深入理解JavaScript

1、JS的作用域和作用域链 作用域链的作用是保证对执行环境有权访问的所有变量和函数的有序访问,通过作用域链,我们可以访问到外层环境的变量和函数。作用域链的本质上是一个指向变量对象的指针列表。变量对象是一个包含了执行环境中所有变量和函数的对象…

IDEA下使用Spring MVC

<?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://ma…

【GO】LGTM_Grafana_gozero_配置trace(4)_代码实操及追踪

最近在尝试用 LGTM 来实现 Go 微服务的可观测性&#xff0c;就顺便整理一下文档。 Tempo 会分为 4 篇文章&#xff1a; Tempo 的架构官网测试实操跑通gin 框架发送 trace 数据到 tempogo-zero 微服务框架发送数据到 tempo 本文就是写一下如何在 go-zero 微服务框架里面配置 t…

暗月中秋靶场活动writeup

前言 暗月在中秋节搞了个靶场活动&#xff0c;一共有4个flag&#xff0c;本着增长经验的想法参加了本次活动&#xff0c;最终在活动结束的时候拿到了3个flag&#xff0c;后面看了其他人的wp也复现拿到第四个flag。过程比较曲折&#xff0c;所以记录一下。 靶场地址 103.108.…

Chinese-LLaMA-AIpaca

文章目录 关于 Chinese-LLaMA-Alpaca一、LLaMA模型 --> HF格式二、合并LoRA权重,生成全量模型权重方式1:单LoRA权重合并方式2:多LoRA权重合并(适用于Chinese-Alpaca-Plus )三、使用 Transformers 进行推理四、使用 webui 搭建界面1、克隆text-generation-webui并安装必…

基础排序算法

插入排序&#xff08;insertion sort&#xff09; 插入排序每次循环将一个元素放置在适当的位置。像抓牌一样。手里的排是有序的&#xff0c;新拿一张牌&#xff0c;与手里的牌进行比较将其放在合适的位置。 插入排序要将待排序的数据分成两部分&#xff0c;一部分有序&#…

Leetcode 409. 最长回文串

文章目录 题目代码&#xff08;9.24 首刷自解&#xff09; 题目 Leetcode 409. 最长回文串 代码&#xff08;9.24 首刷自解&#xff09; class Solution { public:int longestPalindrome(string s) {unordered_map<char, int> mp;for(char c : s) mp[c];int res 0;int…

什么是Selenium?使用Selenium进行自动化测试!

你知道什么是 Selenium 吗&#xff1f;你知道为什么要使用它吗&#xff1f;答案就在本文中&#xff0c;很高兴能够与你共飧。 自动化测试正席卷全球&#xff0c;Selenium 认证是业界最抢手的技能之一。 什么是 Selenium&#xff1f; Selenium 是一种开源工具&#xff0c;用于…