Multi-Query Attention (MQA) PyTorch 实现

和多头注意力机制的唯一区别:K、V在不同的head之间实现了复用,而对于不同的头,Q依然不同。

因此这里的代码和标准多头注意力的实现也是几乎完全一样:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.scale = self.head_dim ** -0.5# 查询、键、值投影self.q_proj = nn.Linear(embed_dim, embed_dim)  # 多头查询self.k_proj = nn.Linear(embed_dim, self.head_dim)  # 单头键self.v_proj = nn.Linear(embed_dim, self.head_dim)  # 单头值self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):batch_size, seq_len, embed_dim = x.shape# 投影q = self.q_proj(x)  # (batch, seq_len, embed_dim)k = self.k_proj(x)  # (batch, seq_len, head_dim)v = self.v_proj(x)  # (batch, seq_len, head_dim)# 重塑查询为多头q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# (batch, num_heads, seq_len, head_dim)# 键和值保持单头,扩展到多头维度k = k.unsqueeze(1)  # (batch, 1, seq_len, head_dim)v = v.unsqueeze(1)  # (batch, 1, seq_len, head_dim)# 注意力计算scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale# (batch, num_heads, seq_len, seq_len)attn = F.softmax(scores, dim=-1)out = torch.matmul(attn, v)  # (batch, num_heads, seq_len, head_dim)# 合并多头out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)out = self.out_proj(out)  # (batch, seq_len, embed_dim)return out# 示例用法
embed_dim = 64
num_heads = 8
model = MultiQueryAttention(embed_dim, num_heads)
x = torch.randn(2, 10, embed_dim)  # (batch, seq_len, embed_dim)
output = model(x)
print(output.shape)  # torch.Size([2, 10, 64])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/902162.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

visual studio无法跳转到函数定义、变量定义、跳转函数位置不准问题解决

参考:https://blog.csdn.net/snakehacker/article/details/135438353 程序有时会出现大部分函数都不能准确的从头文件中正确定位到函数定位,这是因为数据库错乱造成的,可以通过重构数据库来解决,操作方法如下: 菜单栏:工具——选项 文本编辑…

Java优雅实现判空方法

在 Java 开发中,频繁的 if (obj ! null) 判空代码会导致代码冗余、可读性差,且容易遗漏判空导致 NullPointerException。以下从 语言特性、设计模式、工具类 和 编码规范 四个维度,结合实际案例,详解如何优雅处理空值问题。 一、…

京东百亿补贴杀入外卖市场:一场关乎即时零售未来的攻防战

当美团和饿了么在外卖市场双雄争霸十余年之际,京东突然以"百亿补贴免佣金"的组合拳高调入场。这场看似跨界的外卖大战,实则是互联网巨头对万亿级即时零售市场的生死争夺。 外卖只是表象,即时零售才是终极战场 京东黑板报4月10日官…

UNION和UNION ALL的主要区别

UNION和UNION ALL的主要区别在于处理重复数据和排序的方式。 UNION和UNION ALL都是SQL语言中用于合并两个或多个SELECT语句结果集的关键字。它们的主要区别如下: 1、对重复结果的处理:UNION在进行表链接后会筛选掉重复的记录,而UNION ALL不会…

七段码 路径压缩 并查集 dfs

12.七段码 - 蓝桥云课 将七个二极管映射为 1-7 开一个二维矩阵 为 相邻的边连上线 edge[1][2] edge[1][6] 1;edge[2][1] edge[2][3] edge[2][7] 1;edge[3][2] edge[3][4] edge[3][7] 1;edge[4][3] edge[4][5] 1;edge[5][4] edge[5][6] edge[5][7] 1;edge[6][1…

科技如何改变世界?

技术是我们日常生活中不可或缺的一部分,以至于我们常常忘记了它的重要性。如果你正在科技领域工作,或者希望进入该领域,你可能是众多有使命感的人之一,希望知道自己的日常工作能为社会或地球的长远利益做出贡献。 别再四处寻找了…

抽象的https原理简介

前言 小明和小美是一对好朋友,他们分隔两地,平时经常写信沟通,但是偶然被小明发现他回给小美的信好像被人拆开看过,甚至偷偷被篡改过。 对称加密算法 开头的通信过程比较像HTTP服务器与客户端的通信过程,全明文传输…

高级java每日一道面试题-2025年4月13日-微服务篇[Nacos篇]-Nacos如何处理网络分区情况下的服务可用性问题?

如果有遗漏,评论区告诉我进行补充 面试官: Nacos如何处理网络分区情况下的服务可用性问题? 我回答: 在讨论 Nacos 如何处理网络分区情况下的服务可用性问题时,我们需要深入理解 CAP 理论以及 Nacos 在这方面的设计选择。Nacos 允许用户根据具体的应用…

python解压文件 zip tar.gz tar.xz

以下代码为解压zip包 tar包文件 zip_path:文件绝对路径 output_folder:文件解压后存放的文件夹路径 def extract_file(zip_path, output_folder):# 支持解压zip tar tar.gz tar.xz .tar.bz2# 确保输出文件夹存在os.makedirs(output_folder, exist_okT…

网络基础(协议,地址,OSI模型、Socket编程......)

目录 一、计算机网络发展 二、协议 1.认识协议 2.OSI七层模型 3.TCP/IP 五层(或四层)模型 4.协议本质 三、网络传输流程 1.MAC地址 2.协议栈 3.IP地址 IP地址 vs MAC地址 1. 核心区别 2. 具体通信过程类比 3. 关键总结 为什么需要两者? 4.协议栈图解…

生成式AI对话中提示词策略:明确问题、明确目标和提供背景信息是最有效的策略

生成式AI对话中提示词策略:明确问题、明确目标和提供背景信息是最有效的策略 最有效的提示词策略包括明确问题、明确目标和提供背景信息。普适性有效提示词策略可分为三类:明确需求与精确指引型、清晰解释与逻辑排序型、拆解任务与多样化表达型。[局限]数据来源于中国用户,…

AtCoder ABC402 ABCD

A - CBC 把大写字母按顺序连起来 B - Restaurant Queue 一眼队列,stl模拟就行 C - Dislike Foods 显然,每次克服暴力枚举每个菜肴会超时。 然而题目中给了每个菜肴的配菜个数,不妨换过来统计每个配菜用在了哪些菜肴。每次克服时&#x…

Transformer 架构 - 解码器 (Transformer Architecture - Decoder)

欢迎回到我们的 Transformer 系列教程!在上一篇中,我们详细探讨了 Transformer 的编码器,它负责将输入的源序列(比如源语言句子)转换为一系列包含丰富上下文信息的向量表示。 现在,我们将把目光投向 Transformer 的另一半——解码器 (Decoder)。解码器负责接收编码器的输…

神经网络与模型训练过程笔记

1.专有名词 ANN 人工神经网络,一种受生物神经元启发的监督学习算法。输入数据通过网络中的层级函数传递,激活特定神经元。函数复杂度越高,模型对数据的拟合能力越强,预测精度越高。 偏置项 其中x下表从1开始的是输入变量&#xf…

【计算机网络 | 第二篇】常见的通信协议(一)

HTTP和HTTPS有什么区别? 端口号:HTTP默认是80端口,HTTPS默认是443。 URL前缀:HTTPHTTP 的 URL 前缀是 http://,HTTPS 的 URL 前缀是 https://。 安全性和资源消耗:HTTP协议运行在TCP上,都是明…

【python实用小脚本系列】用 Python 自己手搓一个给视频“静音”的小脚本,批量处理,轻松高效制作“无声电影”!

嘿,小伙伴们!今天我来给大家介绍一个超实用的 Python 小工具——一个能给视频“静音”的“声音消除器”!是不是听起来很酷?想象一下,你可以把任何有声视频变成无声视频,是不是很有趣?接下来&…

【gpt生成-总览】怎样才算开发了一门编程语言,需要通过什么测试

开发一门真正的编程语言需要经历完整的设计、实现和验证过程,并通过系统的测试体系验证其完备性。以下是分阶段开发标准及测试方法: 一、语言开发核心阶段 1. 语言规范设计(ISO/IEC 标准级别) ​​语法规范​​:BNF/…

leetcode222 完全二叉树的节点个数

完全二叉树 的定义如下:在完全二叉树中,除了最底层节点可能没填满外,其余每层节点数都达到最大值,并且最下面一层的节点都集中在该层最左边的若干位置。若最底层为第 h 层(从第 0 层开始),则该层…

若依集成BladeX单点登录的令牌管理与api请求流程

目录 概述系统架构单点登录流程令牌管理机制接口调用流程关键代码实现数据结构安全性考虑常见问题与解决 概述 本文档详细说明若依系统如何实现与BladeX的单点登录集成,包括令牌管理和接口调用的完整流程。整个集成采用基于OAuth2的授权码流程,允许用…

《AI大模型应知应会100篇》第27篇:模型温度参数调节:控制创造性与确定性

第27篇:模型温度参数调节:控制创造性与确定性 摘要 在大语言模型的使用中,“温度”(Temperature)是一个关键参数,它决定了模型输出的创造性和确定性之间的平衡。通过调整温度参数,您可以根据任…