Group Query Attention (GQA) 机制详解以及手动实现计算

Group Query Attention (GQA) 机制详解

1. GQA的定义

Grouped-Query Attention (GQA) 是对 Multi-Head Attention (MHA) 和 Multi-Query Attention (MQA) 的扩展。通过提供计算效率和模型表达能力之间的灵活权衡,实现了查询头的分组。GQA将查询头分成了G个组,每个组共享一个公共的键(K)和值(V)投影。

2. GQA的变体

GQA有三种变体:

  • GQA-1:一个单独的组,等同于 Multi-Query Attention (MQA)。
  • GQA-H:组数等于头数,基本上与 Multi-Head Attention (MHA) 相同。
  • GQA-G:一个中间配置,具有G个组,平衡了效率和表达能力。
3. GQA的优势

使用G个组可以减少存储每个头的键和值所需的内存开销,特别是在具有大的上下文窗口或批次大小的情况下。GQA提供了对模型质量和效率的细致控制。

4. GQA的实现

GQA的最简形式可以通过实现 GroupedQueryAttention 类来实现。GroupedQueryAttention 类继承自 Attention 类,重写了 forward 方法,其中使用了 MultiQueryAttention 类的实例来处理每个组的查询。通过将每个组的结果拼接起来,然后与投影矩阵进行矩阵乘法运算,最终得到 GQA 的输出。[1]

pytorch 示例实现:

假设我们有以下初始化的query, key, value:

# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 2, 64)
value = torch.randn(1, 256, 2, 64)
1. 确定分组数量

首先,我们需要确定将查询头分为多少组。在这个例子中,我们有8个查询头,而键和值的头数为2,所以我们可以将查询头分为4组,每组有2个查询头。

2. 对查询进行分组

然后,我们将查询头分组。我们可以使用 torch.chunk 函数将查询张量沿着头维度分割成4个组,每个组有2个头。

query_groups = torch.chunk(query, 4, dim=2)  # shape of each group: (1, 256, 2, 64)
3. 计算注意力分数

对于每一个查询组,我们计算它与键的注意力分数。我们首先计算查询组和键的点积,然后通过 torch.softmax 函数得到注意力分数。

attention_scores = []
for query_group in query_groups:score = torch.matmul(query_group, key.transpose(-2, -1))  # shape: (1, 256, 2, 256)score = torch.softmax(score, dim=-1)attention_scores.append(score)
4. 计算注意力输出

接下来,我们使用注意力分数对值进行加权求和,得到每一个查询组的注意力输出。

attention_scores = []
for query_group in query_groups:score = torch.matmul(query_group, key.transpose(-2, -1))  # shape: (1, 256, 2, 256)score = torch.softmax(score, dim=-1)attention_scores.append(score)
5. 拼接输出

最后,我们将所有查询组的注意力输出拼接起来,得到最终的 Grouped Query Attention 的输出。

attention_outputs = []
for score in attention_scores:output = torch.matmul(score, value)  # shape: (1, 256, 2, 64)attention_outputs.append(output)

这就是 Grouped Query Attention 的实现过程。在这个过程中,我们将查询头分组,然后对每一个查询组分别计算注意力分数和输出,最后将所有查询组的输出拼接起来。这样可以减少存储每个头的键和值所需的内存开销,特别是在具有大的上下文窗口或批次大小的情况下。


  1. Grouped-Query Attention (GQA) - The Large Language Model Playbook

  2. 安全验证 - 知乎
  3. 安全验证 - 知乎
  4. 安全验证 - 知乎
  5. Grouped-Query Attention (GQA) - The Large Language Model Playbook

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/681.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

事务的传播行为介绍和事务失效

常用的就下图介绍的这两种,REQUIRED 支持当前事务,如果不存在,就新建一个,EQUIRES_NEW 如果有事务存在,挂起当前事务,创建一个新的事务 同一个service中必须用代理对象调用,否则失效

使用go和消息队列优化投票功能

文章目录 1、优化方案与主要实现代码1.1、原系统的技术架构1.2、新系统的技术架构1.3、查看和投票接口实现1.4、数据入库MySQL协程实现1.5、路由配置1.6、启动程序入口实现 2、压测结果2.1、设置Jmeter线程组2.2、Jmeter聚合报告结果,支持11240/秒吞吐量2.3、Jmeter…

【情侣博客网站】

效果图 PC端 建塔教程 第一步:下载网站源码(在文章下方有下载链接) 第二步:上传到服务器或虚拟主机,解压。 第三步:这一步很关键,数据库进行连接,看图 admin/connect.php就是这…

el-menu 该有的页面显示不出来第一个应该想到的问题首先就算检查是否多写了一个 , 导致显示不出来原有的页面

问题描述 el-menu 该有的页面显示不出来第一个应该想到的问题首先就算检查是否多写了一个 , 导致显示不出来原有的页面 如图所示多写了一个,就会导致该有的页面显示不出来。

【nginx代理和tengine的启动-重启等命令】

在nginx成功启动后[任务管理器有nginx.exe进程],运行vue项目,在浏览器访问http://localhost:10001/,提示:访问拒绝(调试中network某些地址403); 解决方案: localhost改为ip&#xff…

【Flutter】GetX状态管理及路由管理用法

目录 一、安装二、使用1.安装GetX插件,快捷生成模版代码2.主入口MaterialApp改成GetMaterialApp3.定义路由常量RoutePath类、别名映射页面RoutePages类4. 初始initialRoute,getPages。5.调用 总结 一、安装 dependencies: get: ^4.6.6二、使用 1.安装G…

MDK-ARM Keil5.38 下载安装环境搭建

一、keil软件介绍 KEIL是公司的名称,有时候也指KEIL公司的所有软件开发工具,目前2005年Keil由ARM公司收购,成为ARM的公司之一。 MDK(Microcontroller Development Kit) 也称MDK-ARM、KEIL MDK、RealView MDK、KEIL For…

[最新]访问/加速StackOverFlow的方法

但是有很多问题都是在StackOverFlow上有现成的解决方案,而某度搜索引擎…前一页的回答互相抄袭,看着实在胀眼睛。 话不多说,解决办法: 直接访问插件商店下载插件(最快捷方便,点点就行)&#x…

Python中的迭代器:深入理解与实用指南

文章目录 1. 迭代器的基本概念2. Python中的迭代器实例3. 自定义迭代器3.1 例子3.2 详细过程 4. 迭代器的高级应用5. 常见问题与解答 迭代器是Python中非常核心的概念之一,在面试中也会被问到。下面我会详细介绍什么是迭代器,使用方法,以及使…

怎么转行做产品经理?

小白转产品经理第一点要先学基础理论知识,学了理论再去实践,转行,跳槽! 学理论比较好的就是去报NPDP的系统班,考后也会有面试指导课、职场晋升课程,对小白来说非常合适了~(B站:不爱…

探索 IntelliJ IDEA 2024.1最新变化:全面升级助力编码效率

探索 IntelliJ IDEA 2024.1最新变化:全面升级助力编码效率 文章目录 探索 IntelliJ IDEA 2024.1最新变化:全面升级助力编码效率摘要引言 IntelliJ IDEA 2024.1 最新变化关键亮点全行代码补全 Ultimate对 Java 22 功能的支持新终端 Beta编辑器中的粘性行 …

『FPGA通信接口』串行通信接口-IIC(2)EEPROM读写控制器

文章目录 1.EEPROM简介2.AT24C04简介3.逻辑框架设计4.随机读写时序5.仿真代码与仿真结果分析6.注意事项7.效果8.传送门 1.EEPROM简介 EEPROM (Electrically Erasable Programmable read only memory) 是指带电可擦可编程只读存储器。是一种掉电后数据不丢失的存储芯片。在嵌入…

uniapp项目中表单输入完整项之后提交按钮颜色高亮显示并触发点击事件

1.效果图&#xff1a; 2.html <view class"add" :style"{background: dynamicBackgroundColor, border-color: white}" click"handleClick">添加新地址 </view> 3.js formData: {name: ,phoneNumber: ,addressDetail: }//利用com…

JVM 性能调优命令(jps,jinfo,jstat,jstack,jmap)

常用命令&#xff1a;jps、jinfo、jstat、jstack、jmap jps jps查看java进程及相关信息 jps -l 输出jar包路径&#xff0c;类全名 jps -m 输出main参数 jps -v 输出JVM参数jps命令示例 显示本机的Java虚拟机进程&#xff1a; # jps 15729 jar 92153 Jps 90267 Jstat显示主类…

嵌入式Linux:Linux系统文件目录说明

在Linux系统中&#xff0c;系统文件和目录按照一定的约定被组织和分配到不同的位置。这些文件和目录通常用于存储系统配置、程序文件、库文件等。 以下是一些常见的系统文件目录及其用途的详细说明&#xff1a; /bin:存放系统中最基本的命令&#xff08;二进制文件&#xff09;…

YOLOv9改进策略 | Neck篇 | 2024.1最新MFDS-DETR的HS-FPN改进特征融合层(轻量化Neck、全网独家首发)

一、本文介绍 本文给大家带来的改进机制是最近这几天最新发布的改进机制MFDS-DETR提出的一种HS-FPN结构&#xff0c;其是一种为白细胞检测设计的网络结构&#xff0c;主要用于解决白细胞数据集中的多尺度挑战。它的基本原理包括两个关键部分&#xff1a;特征选择模块和特征融合…

【单调栈】力扣85.最大矩形

好久没更新了 ~ 我又回来啦&#xff01; 两个好消息&#xff1a; 我考上研了&#xff0c;收到拟录取通知啦&#xff01;开放 留言功能 了&#xff0c;小伙伴对于内容有什么疑问可以在文章底部评论&#xff0c;看到之后会及时回复大家的&#xff01; 前面更新过的算法&#x…

《QT实用小工具·三十二》九宫格炫酷主界面

1、概述 源码放在文章末尾 项目实现了九宫格炫酷主界面&#xff0c;下面是项目demo演示&#xff1a; 项目部分代码如下&#xff1a; #pragma execution_character_set("utf-8")#include "frmmain.h" #include "ui_frmmain.h"frmMain::frmMain…

噪声系数测试之增益法

提到增益法测试噪声系数,大家并不陌生,这是一种简洁的测试方法,精度不如Y因子法,但是在某些测试场合,比如只有频谱仪而没有噪声头时,且待测件具有非常高的增益时,就可以使用增益法测试噪声系数。 增益法测试噪声系数的连接示意图如图1所示,其思路为:DUT输入端端接50 …

jsoup接收429,404错误用来接收json格式

1.代码用例 try { // 拿到当前剩余余下的钱Document doc Jsoup.connect(url).header("Authorization", "Bearer " apiKey).header("Content-Type", "application/json").header("Connection", "keep-aliv…