探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

Grouped-query Attention,简称GQA

分组查询注意力(Grouped-query Attention,简称GQA)是多查询和多头注意力的插值。它在保持与多查询注意力相当的处理速度的同时,实现了与多头注意力相似的质量。

在这里插入图片描述

自回归解码的标准做法是缓存序列中先前标记的键和值,以加快注意力计算的速度。

  • 然而,随着上下文窗口或批量大小的增加,多头注意力(Multi-Head Attention,简称MHA)模型中键值缓存(Key-Value Cache,简称KV Cache)的大小所关联的内存成本显著增加。

  • 多查询注意力(Multi-Query Attention,简称MQA)是一种机制,它对多个查询仅使用单个键值头,这可以节省内存并大幅加快解码器的推理速度。

  • Llama(一种模型)整合了GQA,以解决在Transformer模型自回归解码期间的内存带宽挑战。主要问题源于GPU进行计算的速度比它们将数据移入内存的速度快。在每个阶段都需要加载解码器权重和注意力键,这消耗了大量的内存。

在这里插入图片描述
在这里插入图片描述

class SelfAttention(nn.Module): def  __init__(self, args: ModelArgs):super().__init__()self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads# Indicates the number of heads for the queriesself.n_heads_q = args.n_heads# Indiates how many times the heads of keys and value should be repeated to match the head of the Queryself.n_rep = self.n_heads_q // self.n_kv_heads# Indicates the dimentiona of each headself.head_dim = args.dim // args.n_headsself.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))def forward(self, x: torch.Tensor, start_pos: int, freq_complex: torch.Tensor):batch_size, seq_len, _ = x.shape #(B, 1, dim)# Apply the wq, wk, wv matrices to query, key and value# (B, 1, dim) -> (B, 1, H_q * head_dim)xq = self.wq(x)# (B, 1, dim) -> (B, 1, H_kv * head_dim)xk = self.wk(x)xv = self.wv(x)# (B, 1, H_q * head_dim) -> (B, 1, H_q, head_dim)xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)# (B, 1, H_kv * head_dim) -> (B, 1, H_kv, head_dim)xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)# Apply the rotary embeddings to the keys and values# Does not chnage the shape of the tensor# (B, 1, H_kv, head_dim) -> (B, 1, H_kv, head_dim)xq = apply_rotary_embeddings(xq, freq_complex, device=x.device)xk = apply_rotary_embeddings(xk, freq_complex, device=x.device)# Replace the enty in the cache for this tokenself.cache_k[:batch_size, start_pos:start_pos + seq_len] = xkself.cache_v[:batch_size, start_pos:start_pos + seq_len] = xv# Retrive all the cached keys and values so far# (B, seq_len_kv, H_kv, head_dim)keys = self.cache_k[:batch_size, 0:start_pos + seq_len]values = self.cache_v[:batch_size, 0:start_pos+seq_len] # Repeat the heads of the K and V to reach the number of heads of the querieskeys = repeat_kv(keys, self.n_rep)values = repeat_kv(values, self.n_rep)# (B, 1, h_q, head_dim) --> (b, h_q, 1, head_dim)xq = xq.transpose(1, 2)keys = keys.transpose(1, 2)values = values.transpose(1, 2)# (B, h_q, 1, head_dim) @ (B, h_kv, seq_len-kv, head_dim) -> (B, h_q, 1, seq_len-kv)scores = torch.matmul(xq, keys.transpose(2,3)) / math.sqrt(self.head_dim)scores = F.softmax(scores.float(), dim=-1).type_as(xq)# (B, h_q, 1, seq_len) @ (B, h_q, seq_len-kv, head_dim) --> (b, h-q, q, head_dim)output = torch.matmul(scores, values)# (B, h_q, 1, head_dim) -> (B, 1, h_q, head_dim) -> ()output = (output.transpose(1,2).contiguous().view(batch_size, seq_len, -1))return self.wo(output) # (B, 1, dim) -> (B, 1, dim)

系列博客

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(一)
https://duanzhihua.blog.csdn.net/article/details/138208650
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(二)
https://duanzhihua.blog.csdn.net/article/details/138212328

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(三)KV缓存
https://duanzhihua.blog.csdn.net/article/details/138213306
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/3960.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Blender基础操作

1.移动物体: 选中一个物体,按G,之后可以任意移动 若再按X,则只沿X轴移动,同理可按Y与Z 2.旋转物体: 选中一个物体,按R,之后可以任意旋转 若再按X,则只绕X轴旋转&…

Python自学之路--003:PyCharm新建工程之后安装的Python第三方库找不到问题

目录 1、概述 2、问题原因 3、解决办法 3.1、.py文件通过.bat不能调用 3.2、通过调用之前PyCharm工程的解释器找到库 3.3、重新安装一遍或将库Copy到新工程的.venv里面 1、概述 通过PyCharm新建一个工程的时候发现,之前安装的python库没了,如下图。…

【Linux】:文件查看 stat、cat、more、less、head、tail、uniq、wc

🎥 屿小夏 : 个人主页 🔥个人专栏 : Linux深造日志 🌄 莫道桑榆晚,为霞尚满天! 文章目录 📑前言一、stat(查看文件详细属性信息)1.1 内容解析:1.2…

【linux高性能服务器编程】项目实战——仿QQ聊天程序源码剖析

hello !大家好呀! 欢迎大家来到我的Linux高性能服务器编程系列之项目实战——仿QQ聊天程序源码剖析,在这篇文章中,你将会学习到如何利用Linux网络编程技术来实现一个简单的聊天程序,并且我会给出源码进行剖析&#xff…

远程控制安卓手机:便捷、高效与安全的方法

在移动设备的领域里,远程控制安卓手机的能力也变得越来越重要。这种技术可以让我们在远程地点方便地操作手机,无论是处理紧急事务、帮助他人解决问题,还是仅仅为了享受科技带来的便利。本文将为你介绍2种便捷、高效且安全的方法,让…

【智能算法】向日葵优化算法(SFO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2019年,GF Gomes等人受到自然界向日葵运动行为启发,提出了向日葵优化算法(Sunflower Optimization, SFO)。 2.算法原理 2.1算法思想 SFO模拟向日葵行…

【服务器部署篇】Linux下Ansible安装和配置

作者介绍:本人笔名姑苏老陈,从事JAVA开发工作十多年了,带过刚毕业的实习生,也带过技术团队。最近有个朋友的表弟,马上要大学毕业了,想从事JAVA开发工作,但不知道从何处入手。于是,产…

vue3【详解】vue3 比 vue2 升级了哪些重要的功能?

改用 createApp 初始化实例 vue2 使用 new Vue() 初始化实例 vue3 使用 Vue.createApp() 初始化实例 新增 emits 选项 vue3 选项式API中新增了emits 选项,用于显示声明组件中的自定义事件,自定义事件的名称,需用 on 开头。 export default {…

如何在vue3+vite中优雅的使用iconify图标

前言 从Vue2迁移到Vue3,在使用上有着很大的差别。本文的话主要是针对图标的使用差别上进行分析,同时给出基于iconify图标库中unplugin-icons的用法。这里特殊说明一下:其实element-plus中用到的图标也是基于iconify图标库的,在我们…

LT9611UXC双端口 MIPI DSI/CSI 转 HDMI2.0,带音频

1. 说明 LT9611UXC 是一款高性能 MIPI DSI/CSI 至 HDMI2.0 转换器。MIPI DSI/CSI 输入具有可配置的单端口或双端口,具有 1 个高速时钟通道和 1~4 个高速数据通道,工作速率最高为 2Gbps/通道,可支持高达 16Gbps 的总带宽。 LT9611UXC 支持突发…

13 c++版本的五子棋

前言 呵呵 这大概是 大学里面的 c 五子棋了吧 有一些 面向对象的理解, 但是不多 这里 具体的实现 就不赘述, 仅仅是 发一下代码 以及 具体的使用 然后 貌似 放在 win10 上面执行 还有一些问题, 渲染的, 应该很好调整 五子棋 #include<Windows.h> #include<io…

STM32、GD32驱动SHT30温湿度传感器源码分享

一、SHT30介绍 1、简介 SHT30是一种数字湿度和温度传感器&#xff0c;由Sensirion公司生产。它是基于物理蒸发原理的湿度传感器&#xff0c;具有高精度和长期稳定性。SHT30采用I2C数字接口&#xff0c;可以直接与微控制器或其他设备连接。该传感器具有低功耗和快速响应的特点…

树莓派4-通过IIC实现图片循环播放

一、环境 1、树莓派4&#xff1b; 2、串口连接电脑&#xff1b; 3、树莓派由杜邦线连接0.96寸OLED1306协议 4、树莓派能够联网&#xff0c;便于安装环境。离线情况也可以安装&#xff0c;相对麻烦&#xff1b; 二、目标 1、树莓派可以开启IIC并识别已连接的IIC&#xff1b; …

Web3解密:理解去中心化应用的核心原理

引言 在当前数字化时代&#xff0c;去中心化技术和应用正在逐渐引起人们的关注和兴趣。Web3技术作为去中心化应用&#xff08;DApps&#xff09;的基础&#xff0c;为我们提供了一个全新的互联网体验。但是&#xff0c;对于许多人来说&#xff0c;这个复杂的概念仍然充满了神秘…

MongoDB基础操作

文章目录 一、什么是MongoDB二、MongoDB 与关系型数据库对比三、数据类型四、部署MongoDB1、下载二进制包2、下载安装包并解压3、创建用于存放数据和日志的目录&#xff0c;并修改权限4、启动MongoDB4.1前台启动4.2后台启动4.3、配置文件启动服务4.4、配置systemd服务4.5、syst…

RabbitMQ发布确认和消息回退(6)

概念 发布确认原理 生产者将信道设置成 confirm 模式&#xff0c;一旦信道进入 confirm 模式&#xff0c;所有在该信道上面发布的消息都将会被指派一个唯一的 ID(从 1 开始)&#xff0c;一旦消息被投递到所有匹配的队列之后&#xff0c;broker就会发送一个确认给生产者(包含消…

qt实现方框调整

效果 在四周调整 代码 #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QWidget>class MainWindow : public QWidget {Q_OBJECT public:explicit MainWindow(QWidget *parent 0);~MainWindow();void paintEvent(QPaintEvent *event);void updateRect();void re…

Restful API 具体设计规范(概述)

协议 https 域名 https://www.baidu.com/api 版本 https://www.baidu.com/v1 路径 https://www.baidu.com/v1/blogs 方法 数据过滤 状态码返回结果 返回的数据格式 尽量使用 JSON&#xff0c;避免使用 XML。 总结&#xff1a; 看 url 就知道要什么看 http method 就知道干…

Linux进阶篇:CentOS7搭建NFS文件共享服务

CentOS7搭建NFS文件共享服务 一、NFS介绍 NFS(Network File System)意为网络文件系统&#xff0c;它最大的功能就是可以通过网络&#xff0c;让不同的机器不同的操作系统可以共享彼此的文件。简单的讲就是可以挂载远程主机的共享目录到本地&#xff0c;就像操作本地磁盘一样&…

Docker——数据管理和网络通信

目录 一、Docker的数据管理 1.数据卷 2.数据卷容器 3.容器互联 二、Docker镜像的创建 1.基于现有镜像创建 2.基于本地模板创建 3.基于Dockerfile 创建 3.1联合文件系统&#xff08;UnionFS&#xff09; 3.2镜像加载原理 3.3为什么Docker里的Centos大小才200M 4.Dcok…