Group Query Attention (GQA) 机制详解以及手动实现计算

Group Query Attention (GQA) 机制详解

1. GQA的定义

Grouped-Query Attention (GQA) 是对 Multi-Head Attention (MHA) 和 Multi-Query Attention (MQA) 的扩展。通过提供计算效率和模型表达能力之间的灵活权衡,实现了查询头的分组。GQA将查询头分成了G个组,每个组共享一个公共的键(K)和值(V)投影。

2. GQA的变体

GQA有三种变体:

  • GQA-1:一个单独的组,等同于 Multi-Query Attention (MQA)。
  • GQA-H:组数等于头数,基本上与 Multi-Head Attention (MHA) 相同。
  • GQA-G:一个中间配置,具有G个组,平衡了效率和表达能力。
3. GQA的优势

使用G个组可以减少存储每个头的键和值所需的内存开销,特别是在具有大的上下文窗口或批次大小的情况下。GQA提供了对模型质量和效率的细致控制。

4. GQA的实现

GQA的最简形式可以通过实现 GroupedQueryAttention 类来实现。GroupedQueryAttention 类继承自 Attention 类,重写了 forward 方法,其中使用了 MultiQueryAttention 类的实例来处理每个组的查询。通过将每个组的结果拼接起来,然后与投影矩阵进行矩阵乘法运算,最终得到 GQA 的输出。[1]

pytorch 示例实现:

假设我们有以下初始化的query, key, value:

# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 2, 64)
value = torch.randn(1, 256, 2, 64)
1. 确定分组数量

首先,我们需要确定将查询头分为多少组。在这个例子中,我们有8个查询头,而键和值的头数为2,所以我们可以将查询头分为4组,每组有2个查询头。

2. 对查询进行分组

然后,我们将查询头分组。我们可以使用 torch.chunk 函数将查询张量沿着头维度分割成4个组,每个组有2个头。

query_groups = torch.chunk(query, 4, dim=2)  # shape of each group: (1, 256, 2, 64)
3. 计算注意力分数

对于每一个查询组,我们计算它与键的注意力分数。我们首先计算查询组和键的点积,然后通过 torch.softmax 函数得到注意力分数。

attention_scores = []
for query_group in query_groups:score = torch.matmul(query_group, key.transpose(-2, -1))  # shape: (1, 256, 2, 256)score = torch.softmax(score, dim=-1)attention_scores.append(score)
4. 计算注意力输出

接下来,我们使用注意力分数对值进行加权求和,得到每一个查询组的注意力输出。

attention_scores = []
for query_group in query_groups:score = torch.matmul(query_group, key.transpose(-2, -1))  # shape: (1, 256, 2, 256)score = torch.softmax(score, dim=-1)attention_scores.append(score)
5. 拼接输出

最后,我们将所有查询组的注意力输出拼接起来,得到最终的 Grouped Query Attention 的输出。

attention_outputs = []
for score in attention_scores:output = torch.matmul(score, value)  # shape: (1, 256, 2, 64)attention_outputs.append(output)

这就是 Grouped Query Attention 的实现过程。在这个过程中,我们将查询头分组,然后对每一个查询组分别计算注意力分数和输出,最后将所有查询组的输出拼接起来。这样可以减少存储每个头的键和值所需的内存开销,特别是在具有大的上下文窗口或批次大小的情况下。


  1. Grouped-Query Attention (GQA) - The Large Language Model Playbook

  2. 安全验证 - 知乎
  3. 安全验证 - 知乎
  4. 安全验证 - 知乎
  5. Grouped-Query Attention (GQA) - The Large Language Model Playbook

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/681.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

事务的传播行为介绍和事务失效

常用的就下图介绍的这两种,REQUIRED 支持当前事务,如果不存在,就新建一个,EQUIRES_NEW 如果有事务存在,挂起当前事务,创建一个新的事务 同一个service中必须用代理对象调用,否则失效

使用go和消息队列优化投票功能

文章目录 1、优化方案与主要实现代码1.1、原系统的技术架构1.2、新系统的技术架构1.3、查看和投票接口实现1.4、数据入库MySQL协程实现1.5、路由配置1.6、启动程序入口实现 2、压测结果2.1、设置Jmeter线程组2.2、Jmeter聚合报告结果,支持11240/秒吞吐量2.3、Jmeter…

java中MD5加密

MD5加密 MD5加密是不可逆的加密方式,A可以根据MD5加密转换成B,但是B不能再转换成A String passwordDigestUtils.md5DigestAsHex(password.getBytes());完成密码的加密

【情侣博客网站】

效果图 PC端 建塔教程 第一步:下载网站源码(在文章下方有下载链接) 第二步:上传到服务器或虚拟主机,解压。 第三步:这一步很关键,数据库进行连接,看图 admin/connect.php就是这…

[Android]Jetpack Compose设置颜色

在 Kotlin 和 Jetpack Compose 中设置颜色是一个非常直接的过程,涉及到使用 Color 类来定义和使用颜色。 Jetpack Compose 提供了多种方式来定义和应用颜色,包括预定义颜色、RGB 值、十六进制值等。下面是一些常用的设置颜色的方法: 1. 使用…

python-基础(4)-list

python专栏地址 上一篇:python-基础(3)-字符串操作 List结构 本节将学习以下内容 list初识list的操作 一、List初识 创建 通过[]/list([])创建 ,两者的区别可以参考python中用list和中括号创建列表有什么区别?(在创建时相同,但一个的实质时…

Spring Data Jpa的save方法更新未传值的字段被更新为空的处理方法

Spring Data Jpa的save()方法通过主键是否为空来判断insert或是update操作,但更新方法和以往使用的mybatis-plus存在一定的差异,特别记录处理方法。 Resourceprivate Dao dao;/*** 更新操作* param data 前端传入存在更新的字段值的对象*/public void up…

pat乙-1020月饼

贪心:既然有存货量一定,利润要最高; 这个贪心就在于我看“单价”最高,这个单价也是要把存货量算进去的,所以按“单价”排序,再遍历,优选选择“单价”最高的,不够的再补,…

el-menu 该有的页面显示不出来第一个应该想到的问题首先就算检查是否多写了一个 , 导致显示不出来原有的页面

问题描述 el-menu 该有的页面显示不出来第一个应该想到的问题首先就算检查是否多写了一个 , 导致显示不出来原有的页面 如图所示多写了一个,就会导致该有的页面显示不出来。

Python 天气预测

Python天气预测通常涉及到数据采集、数据预处理、选择和训练模型、以及预测和可视化等步骤。以下是使用Python进行天气预测的一般流程: 数据采集 使用爬虫技术从天气网站(如Weather Underground、中国天气网等)爬取历史天气数据&#xff0c…

my.cnf配置文件调优

mysql数据库的性能调优首先要考虑的就是表结构设计,一个糟糕的设计模式即使在性能强劲的服务器上运行时,也会表现得很差。与设计模式相似,查询语句也会影响mysql的性能,应该避免写出低效的sql查询语句。最后要考虑的就是参数优化,mysql数据库默认设置的性能非常差,只能起…

LeetCode 628. 三个数的最大乘积 java版

1. 官网: . - 力扣(LeetCode) 2. 题目: 给你一个整型数组 nums ,在数组中找出由三个数组成的最大乘积,并输出这个乘积。 示例 1: 输入:nums [1,2,3] 输出:6示例 2&…

Java中的递归方法:初学者的简明指南

Java中的递归方法:初学者的简明指南 递归是编程中的一个重要概念,它指的是一个方法直接或间接地调用自身。递归方法在处理某些问题时,特别是那些可以分解为更小、更简单的子问题时,非常有用。虽然递归的概念初看起来可能有些复杂…

QEMU_v8搭建OP-TEE运行环境

文章目录 一、依赖下载二、设置网络三、安装下载四、运行OP-TEE 一、依赖下载 更新依赖包,下载一系列依赖。比如Python需要Python3.x版本,需要配置git的用户名和邮箱等。这里不详细展开了,很多博客都有涉及到。 二、设置网络 这一点非常重…

【nginx代理和tengine的启动-重启等命令】

在nginx成功启动后[任务管理器有nginx.exe进程],运行vue项目,在浏览器访问http://localhost:10001/,提示:访问拒绝(调试中network某些地址403); 解决方案: localhost改为ip&#xff…

已解决:用 Pyinstaller 模块将整个GUI项目打包成 单独exe 文件、移到没有python的环境依旧可以运行(全网完美解决的方法)

1.环境 我的项目是pyside2实现的GUI项目,涉及文件读写、panda操作等,项目包括ui文件、ico文件、项目模块(项目中有多个不同的模块、每个模块里面有代码)。项目内容是可以使用该GUI框架对别的数据文件进行读取、加工处理、保存等一系列操作。win11+python3.8+pyside2+其它库…

【Flutter】GetX状态管理及路由管理用法

目录 一、安装二、使用1.安装GetX插件,快捷生成模版代码2.主入口MaterialApp改成GetMaterialApp3.定义路由常量RoutePath类、别名映射页面RoutePages类4. 初始initialRoute,getPages。5.调用 总结 一、安装 dependencies: get: ^4.6.6二、使用 1.安装G…

MDK-ARM Keil5.38 下载安装环境搭建

一、keil软件介绍 KEIL是公司的名称,有时候也指KEIL公司的所有软件开发工具,目前2005年Keil由ARM公司收购,成为ARM的公司之一。 MDK(Microcontroller Development Kit) 也称MDK-ARM、KEIL MDK、RealView MDK、KEIL For…

[最新]访问/加速StackOverFlow的方法

但是有很多问题都是在StackOverFlow上有现成的解决方案,而某度搜索引擎…前一页的回答互相抄袭,看着实在胀眼睛。 话不多说,解决办法: 直接访问插件商店下载插件(最快捷方便,点点就行)&#x…

Python中的迭代器:深入理解与实用指南

文章目录 1. 迭代器的基本概念2. Python中的迭代器实例3. 自定义迭代器3.1 例子3.2 详细过程 4. 迭代器的高级应用5. 常见问题与解答 迭代器是Python中非常核心的概念之一,在面试中也会被问到。下面我会详细介绍什么是迭代器,使用方法,以及使…