LeetCode - Google 大模型校招10题 第1天 Attention 汇总 (3题)

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145368666


GQA
GroupQueryAttention(分组查询注意力机制) 和 KVCache(键值缓存) 是大语言模型中的常见架构,GroupQueryAttention 是注意力机制的变体,通过将查询(Query)分组,每组与相同的键(Key)值(Value)交互,优化计算效率和性能,保持模型对于输入信息有效关注,减少计算资源的消耗,适用于处理大规模数据和复杂任务的场景。KVCache 是缓存机制,用于存储和快速检索键值对(KV),当模型处理新的输入(Q)时,直接从缓存中读取KV数据,无需重新计算,显著提高模型的推理速度和效率。GQA 与 KVCache 在提升模型性能和优化资源利用方面,都发挥着重要作用,结合使用可以进一步增强模型在实际应用中的表现。

从 MHA 到 GQA,再到 GQA+KVCache,简单实现,参考:

  • GQA:从头实现 LLaMA3 网络与推理流程
  • KVCache:GPT(Decoder Only) 类模型的 KV Cache 公式与原理

Scaled Dot-Product Attention (缩放点积注意力机制),也称单头自注意力机制,公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ⊤ d k ) V Attention(Q,K,V)=softmax(\frac{QK^{\top}}{\sqrt{d_{k}}})V Attention(Q,K,V)=softmax(dk QK)V

1. MultiHeadAttention

MultiHeadAttention (多头注意力机制),合计 43 行:

  1. __init__ 初始化 (10行):
    • 输入:heads(头数)、d_model(维度)、dropout (用于 scores)
    • 计算 d_k 每个 Head 的维度,即 d m o d e l = h e a d s × d k d_{model} = heads \times d_{k} dmodel=heads×dk
    • 线性层是 QKVO,Dropout 层
  2. attention 注意力 (10行):
    • q q q 的维度 [bs,h,s,d],与 k ⊤ k^{\top} k[bs,h,d,s],mm 之后 scores 是 [bs,h,s,s]
    • mask 的维度是 [bs,s,s],使用 unsqueeze(1),转换成 [bs,1,s,s]
    • QKV 的计算,额外支持 Dropout
  3. forward 推理 (12行):
    • QKV Linear 转换成 [bs,s,h,dk],再转换 [bs,h,s,dk]
    • 计算 attn 的 [bs,h,s,dk]
    • 转换 [bs,s,h,dk],再 contiguous(),再 合并 h × d k = d h \times d_{k} = d h×dk=d
    • 再过 O
  4. 测试 (11行):
    • torch.randn 构建数据
    • Mask 的 torch.tril(torch.ones(bs, s, s))

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class MultiHeadAttention(nn.Module):"""多头自注意力机制 MultiHeadAttention"""def __init__(self, heads, d_model, dropout=0.1):  # 10行super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):  # 10行scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None):  # 12行bs = q.size(0)# 进行线性操作划分为成 h 个头k = self.k_linear(k).view(bs, -1, self.h, self.d_k)q = self.q_linear(q).view(bs, -1, self.h, self.d_k)v = self.v_linear(v).view(bs, -1, self.h, self.d_k)# 矩阵转置k = k.transpose(1, 2)  # [bs,h,s,d] = [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)print(f"[Info] attn: {attn.shape}")# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output
def main():# 设置超参数bs, s, h, d = 2, 10, 8, 512dropout_rate = 0.1# 创建 MultiHeadAttention 实例attention = MultiHeadAttention(h, d, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 测试无掩码的情况output_no_mask = attention(q, k, v)print("Output shape without mask:", output_no_mask.shape)# 测试有掩码的情况output_with_mask = attention(q, k, v, mask)print("Output shape with mask:", output_with_mask.shape)# 检查输出是否符合预期assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"print("Test passed!")
if __name__ == '__main__':main()

2. GroupQueryAttention

GroupQueryAttention (分组查询注意力机制),相比于 MHA,参考 torch.nn.functional.scaled_dot_product_attention

  1. __init__ :增加参数 kv_heads,即 KV Head 数量,KV 的 Linear 层输出维度(kv_heads * self.d_k)也需要修改。
  2. forward:使用 repeat_interleave 扩充 KV 维度,其他相同,增加 3 行。

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):"""分组查询注意力机制(Group Query Attention)"""def __init__(self, heads, d_model, kv_heads, dropout=0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.kv_heads = kv_headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):# [2, 8, 10, 64] x [2, 8, 64, 10] = [2, 8, 10, 10]scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None):bs = q.size(0)# 进行线性操作q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 10, 8, 64]k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 10, 4, 64]v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)# 复制键值头以匹配查询头的数量group = self.h // self.kv_headsk = k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]v = v.repeat_interleave(group, dim=2)# 矩阵转置, 将 head 在前k = k.transpose(1, 2)  # [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output
def main():# 设置超参数, GQA 8//4=2组bs, s, h, d, kv_heads = 2, 10, 8, 512, 4dropout_rate = 0.1# 创建 MultiHeadAttention 实例attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 测试无掩码的情况output_no_mask = attention(q, k, v)print("Output shape without mask:", output_no_mask.shape)# 测试有掩码的情况output_with_mask = attention(q, k, v, mask)print("Output shape with mask:", output_with_mask.shape)# 检查输出是否符合预期assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"print("Test passed!")
if __name__ == '__main__':main()

3. GQA + KVCache

GroupQueryAttention + KVCache,相比于 GQA,增加 KVCache:

  1. forward :增加参数 kv_cache,合并 [cached_k, new_k],同时返回 new_kv_cache,用于迭代,增加 5 行。
  2. 设置 cur_qkvcur_mask,迭代序列s维度,合计 8 行。

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):"""分组查询注意力机制(Group Query Attention)"""def __init__(self, heads, d_model, kv_heads, dropout=0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.kv_heads = kv_headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):# [2, 8, 1, 64] x [2, 8, 64, 10] = [2, 8, 1, 10]scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None, kv_cache=None):bs = q.size(0)# 进行线性操作q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 1, 8, 64]new_k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]new_v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]# 处理 KV Cacheif kv_cache is not None:cached_k, cached_v = kv_cachenew_k = torch.cat([cached_k, new_k], dim=1)new_v = torch.cat([cached_v, new_v], dim=1)# 复制键值头以匹配查询头的数量group = self.h // self.kv_headsk = new_k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]v = new_v.repeat_interleave(group, dim=2)# 矩阵转置, 将 head 在前# KV Cache 最后1轮: q—>[2, 8, 1, 64] k->[2, 8, 10, 64] v->[2, 8, 10, 64]k = k.transpose(1, 2)  # [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)  # [2, 8, 1, 64]print(f"[Info] attn: {attn.shape}")# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)# 更新 KV Cachenew_kv_cache = (new_k, new_v)  # 当前的 KV 缓存return output, new_kv_cache
def main():# 设置超参数bs, s, h, d, kv_heads = 2, 10, 8, 512, 4dropout_rate = 0.1# 创建 GroupQueryAttention 实例attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 模拟逐步生成序列,测试 KV Cacheprint("Testing KV Cache...")kv_cache, output = None, Nonefor i in range(s):cur_q = q[:, i:i+1, :]cur_k = k[:, i:i+1, :]cur_v = v[:, i:i+1, :]cur_mask = mask[:, i:i+1, :i+1]   # q是 i:i+1,k是 :i+1output, kv_cache = attention(cur_q, cur_k, cur_v, cur_mask, kv_cache)print(f"Output shape at step {i}:", output.shape)# 检查输出是否符合预期assert output.shape == (bs, 1, d), "Output shape is incorrect when using KV Cache"print("Test passed!")
if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/68275.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux之Tcp粘包笔记

目录 一.网络传输四层模型 二.数据传输中数据包传输的两个限制概念 三.数据传输的中粘包问题 四.数据组装的原因 Nagle算法原理: 五.关闭Nagle优化处理粘包问题吗? 六.粘包处理方法 a.设置消息边界: b.定义消息长度: 七.UDP是否会出…

websocket实现

由于安卓资源管理器展示的路径不尽相同,各种软件保存文件的位置也不一定一样.对于普通用户上传文件时,查找文件可能是一个麻烦的事情.后来想到了一个办法,使用pc端进行辅助上传. 文章目录 实现思路1.0 实现定义web与客户端通信数据类型和数据格式web端websocket实现web端对客户…

什么是反向海淘?如何入局反向海淘?

什么是反向海淘? 简单来说,反向海淘就是海外消费者通过国内的电商平台或独立站买入中国商品,然后通过跨境物流送到海外。以前是我们在国内买国外的东西,现在反过来,老外开始疯狂种草咱们的国货啦! 为什么反…

leetcode刷题记录(一百)——121. 买卖股票的最佳时机

(一)问题描述 121. 买卖股票的最佳时机 - 力扣(LeetCode)121. 买卖股票的最佳时机 - 给定一个数组 prices ,它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格。你只能选择 某一天 买入这只股票,并…

荔枝派LicheePi Zero V3S芯片图形系统开发详解[持续更新]

一、图形子系统 一般移动Linux设备实现图像显示的方案无非两种: 一种是使用OpenGL ES 另外一种就是使用FrameBuffer 使用OpenGL有个前提就是这个芯片是需要有GPU的,不然是没有意义的。 查看芯片系统框图,注意到V3S这款芯片是不支持GPU的…

AI刷题-最小化团建熟悉程度和

目录 问题描述 输入格式 输出格式 解题思路: 状态表示 状态转移 动态规划数组 预处理 实现: 1.初始化: 2.动态规划部分: (1)对于已分组状态的,跳过: (2&…

使用Python和Qt6创建GUI应用程序---GUI的一个非常简短的历史

GUI的一个非常简短的历史 图形用户界面有着悠久而可敬的历史,可以追溯到20世纪60年代。斯坦福大学的NLS(在线系统)引入了鼠标和Windows概念于1968年首次公开展示。接下来是施乐PARC的Smalltalk系统GUI 1973,这是最现代的基础通用g…

DroneXtract:一款针对无人机的网络安全数字取证工具

关于DroneXtract DroneXtract是一款使用 Golang 开发的适用于DJI无人机的综合数字取证套件,该工具可用于分析无人机传感器值和遥测数据、可视化无人机飞行地图、审计威胁活动以及提取多种文件格式中的相关数据。 功能介绍 DroneXtract 具有四个用于无人机取证和审…

day7手机拍照装备

对焦对不上:1、光太暗;2、离太近;3、颜色太单一没有区分点 滤镜可以后期P 渐变灰滤镜:均衡色彩,暗的地方亮一些,亮的地方暗一些 中灰滤镜:减少光差 手机支架:最基本70cm即可 手…

【从零到一,C++项目实战】CineShare++(基于C++的视频点播系统)

🌈个人主页: 南桥几晴秋 🌈C专栏: 南桥谈C 🌈C语言专栏: C语言学习系列 🌈Linux学习专栏: 南桥谈Linux 🌈数据结构学习专栏: 数据结构杂谈 🌈数据…

RabbitMQ 架构分析

文章目录 前言一、RabbitMQ架构分析1、Broker2、Vhost3、Producer4、Messages5、Connections6、Channel7、Exchange7、Queue8、Consumer 二、消息路由机制1、Direct Exchange2、Topic Exchange3、Fanout Exchange4、Headers Exchange5、notice5.1、备用交换机(Alter…

九、CSS工程化方案

一、PostCSS介绍 二、PostCSS插件的使用 项目安装 - npm install postcss-cli 全局安装 - npm install postcss-cli -g postcss-cli地址:GitHub - postcss/postcss-cli: CLI for postcss postcss地址:GitHub - postcss/postcss: Transforming styles…

SpringBoot开发(二)Spring Boot项目构建、Bootstrap基础知识

1. Spring Boot项目构建 1.1. 简介 基于官方网站https://start.spring.io进行项目的创建. 1.1.1. 简介 Spring Boot是基于Spring4框架开发的全新框架,设计目的是简化搭建及开发过程,并不是对Spring功能上的增强,而是提供了一种快速使用Spr…

GESP2024年3月认证C++六级( 第三部分编程题(2)好斗的牛)

参考程序&#xff08;暴力枚举&#xff09; #include <iostream> #include <vector> #include <algorithm> using namespace std; int N; vector<int> a, b; int ans 1e9; int main() {cin >> N;a.resize(N);b.resize(N);for (int i 0; i &l…

SpringBoot统一数据返回格式 统一异常处理

统一数据返回格式 & 统一异常处理 1. 统一数据返回格式1.1 快速入门1.2 存在问题1.3 案列代码修改1.4 优点 2. 统一异常处理 1. 统一数据返回格式 强制登录案例中,我们共做了两部分⼯作 通过Session来判断⽤⼾是否登录对后端返回数据进⾏封装,告知前端处理的结果 回顾 后…

Elasticsearch+kibana安装(简单易上手)

下载ES( Download Elasticsearch | Elastic ) 将ES安装包解压缩 解压后目录如下: 修改ES服务端口&#xff08;可以不修改&#xff09; 启动ES 记住这些内容 验证ES是否启动成功 下载kibana( Download Kibana Free | Get Started Now | Elastic ) 解压后的kibana目…

十年筑梦,再创鲸彩!庆祝和鲸科技十周年

2025 年 1 月 16 日&#xff0c;“十年筑梦&#xff0c;再创鲸彩” 2025 和鲸科技十周年庆暨 2024 年终表彰大会圆满落幕。 十年征程&#xff0c;和鲸科技遨游于科技蓝海&#xff0c;破浪前行&#xff0c;无惧风雨。期间所取得的每一项成就&#xff0c;都凝聚着全体成员的智慧结…

【Uniapp-Vue3】动态设置页面导航条的样式

1. 动态修改导航条标题 uni.setNavigationBarTitle({ title:"标题名称" }) 点击修改以后顶部导航栏的标题会从“主页”变为“动态标题” 2. 动态修改导航条颜色 uni.setNavigationBarColor({ backgroundColor:"颜色" }) 3. 动态添加导航加载动画 // 添加加…

openlayer getLayerById 根据id获取layer图层

背景&#xff1a; 在项目中使用getLayerById获取图层&#xff0c;这个getLayerById()方法不是openlayer官方文档自带的&#xff0c;而是自己封装的一个方法&#xff0c;这个封装的方法的思路是&#xff1a;遍历所有的layer&#xff0c;根据唯一标识【可能是id&#xff0c;也可能…

Unity入门2 背景叠层 瓦片规则

切割场景 瓦片调色盘 放在Assets里面新建瓦片地图,palettes tile 瓦片 palettes调色板 上下窗口是分开的 拖进这个格子窗 瓦片太碎&#xff0c;要封装 装好之后&#xff0c;只是把瓦片放上去了&#xff0c;但是还没有画布&#xff0c;显示是这样的 no valid target 新建“…