GQA (group query attention)

什么是GQA?

多个head的Query共用一组K和V。llama模型就用到该技术。

需要明确几点:

1.group有几组

2.每个group对应几个head

3.q以head为单位 k,v以group为单位 每个head/group特征维度都是head_dim

代码实现

import torch.nn as nn
import torch
import math# 自注意力
class GroupQueryAttention(nn.Module):def __init__(self, d_model, n_heads, n_groups):super().__init__()self.d_model = d_modelself.n_heads = n_headsself.n_groups = n_groupsassert d_model % n_heads == 0self.head_dim = d_model // n_heads# 每个group对应几个headself.heads_per_group = n_heads // n_groups# 以head为单位self.w_q = nn.Linear(d_model, n_heads*self.head_dim)  # n_heads*head_dim=d_model# 以group为单位self.w_k = nn.Linear(d_model, n_groups*self.head_dim)self.w_v = nn.Linear(d_model, n_groups*self.head_dim)self.w_combine = nn.Linear(d_model, d_model)self.softmax = nn.Softmax(dim=-1)# 给k,v进行复制,假设每个组对应3个head,那就要把每个组的数据复制3遍def expand(self, data): # data:k/v [b, group, seq_len, head_dim]b,_,seq_len,_ = data.shapedata = data[:,:,None,:,:].expand(b, self.n_groups, self.heads_per_group, seq_len, self.head_dim)data = data.contiguous().view(b, -1, seq_len, self.head_dim)return data  # [b, group*heads_per_group, seq_len, head_dim]def forward(self, x, use_mask=False): # x: [b, seq_len, d_model]b, seq_len, _ = x.shapeq,k,v = self.w_q(x), self.w_k(x), self.w_v(x)q = q.view(b, seq_len, self.n_heads, self.head_dim).permute(0,2,1,3)  # 以head为单位k = k.view(b, seq_len, self.n_groups, self.head_dim).permute(0,2,1,3) # 以group为单位v = v.view(b, seq_len, self.n_groups, self.head_dim).permute(0,2,1,3)# 复制k,v = self.expand(k), self.expand(v)score = q @ k.transpose(-1,-2) / math.sqrt(self.head_dim)if use_mask:mask = torch.tril(torch.ones(seq_len, seq_len))score = score.masked_fill(mask==0, float('-inf'))score = self.softmax(score) @ vscore = score.permute(0,2,1,3).contiguous().view(b, seq_len, self.n_heads*self.head_dim)out = self.w_combine(score)return outx = torch.rand(2,100,384)
model = GroupQueryAttention(384, 4, 2)
out = model(x)
print(out.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/878961.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

串口通信协议(UART)

简介 uart通讯协议,是一种成本低、容易使用、通信线路简单,可实现两个设备的互相通信的协议;是一种全双工,设备点对点通信的协议。下面从硬件电路、电平标准和串口参数等方面来了解uart通信协议。 硬件电路 硬件电路非常简单&am…

深入Redis:强大的主从复制

如果某个服务器或者程序,只有一个节点(服务器),就会有很大的问题。比如可用不高,并发量也比较低。引入分布式系统,也主要是为了解决上述的单点问题。 Redis,主要部署在分布式系统上。在分布式系…

Docker容器相关命令

Docker是一种容器化技术,可以帮助用户更轻松地创建、部署和管理容器。下面是一些常见的Docker容器管理任务: 创建容器:使用Docker镜像创建一个新的容器。 docker run image_name列出容器:查看当前运行的容器列表。 docker ps启动容…

无人机之地面站篇

无人机的地面站,又称无人机控制站,是整个无人机系统的重要组成部分,扮演着作战指挥中心的角色。以下是对无人机地面站的详细阐述: 一、定义与功能 无人机地面站是指具有对无人机飞行平台和任务载荷进行监控和操纵能力的一组设备&…

CCPC网络预选赛感想

背景 断更了几天的比赛题解,是因为去打ccpc预选赛去了) 第一天 第一天是热身赛,就我和t去了,l回去家里取东西。这也是我和t的第一次线下见面吧qwq,很强很帅的一个大一新生(分分钟薄纱我)。 …

MySQL 中的 `TRIM()` 函数:优雅去除字符串两侧的空格

在数据库管理中,数据的准确性和整洁性至关重要。有时,从外部源导入的数据或用户输入的数据可能包含不必要的空格,尤其是在字符串的开头或结尾。这些空格虽然看似微小,但在数据查询、比较或展示时可能会引发问题。幸运的是&#xf…

Postgresql碎片整理

创建pgstattuple 扩展 CREATE EXTENSION pgstattuple 获取表的元组(行)信息,包括空闲空间的比例和行的平均宽度 SELECT * FROM pgstattuple(表名); 查看表和索引大小 SELECT pg_relation_size(表名), pg_relation_size(索引名称); 清理碎片方…

【魔法 / NOI】

题目 思路 动态规划: 状态定义: f [ k ] [ i ] [ j ] 对应使用了不超过 k 次魔法,从 i 到 j 的路径集合 f[k][i][j] 对应使用了不超过k次魔法,从i到j的路径集合 f[k][i][j]对应使用了不超过k次魔法,从i到j的路径集合 状…

vc-align源码分析 -- ant-design-vue系列

vc-align源码分析 源码地址:https://github.com/vueComponent/ant-design-vue/tree/main/components/vc-align 1 基础代码 1.1 名词约定 需要对齐的节点叫source,对齐的目标叫target。 1.2 props 提供了两个参数: align:对…

WPF-快速构建统计表、图表并认识相关框架

一、使用ScottPlot.Wpf 官网地址:https://scottplot.net/quickstart/wpf/ 1、添加NuGet包:ScottPlot.Wpf 2、XAML映射命名空间: xmlns:ScottPlot"clr-namespace:ScottPlot.WPF;assemblyScottPlot.WPF" 3、简单示例:…

2024年测评7款最佳AI论文修改润色平台

在2024年,AI论文修改润色平台的测评和推荐成为学术界和研究者们关注的热点。本文将详细评测并推荐7款最佳AI论文修改润色平台,包括千笔-AIPassPaper,并结合我搜索到的资料进行分析。 一、千笔-AIPassPaper 千笔-AIPassPaper是一款集论文大纲…

【Nginx系列】Nginx中rewrite模块

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

SpringBoot教程(安装篇) | Elasticsearch的安装

SpringBoot教程(安装篇) | Elasticsearch的安装 一、确定Elasticsearch版本二、下载elasticsearch(windows版本)官网下载如何解压配置 允许 别人跨域 访问自己启动运行 三、Es可视化工具安装(elasticsearch-head&#…

JVM 调优篇4 jvm的垃圾回收中垃圾日志的阅读查看2

一 jvm垃圾回收日志 1.1 日志参数 查看垃圾回收日志,可进行日志的设置参数如下: 参数解释-verbose:gc输出gc日志信息,默认输出到标准输出-XX:+PrintGC输出GC日志。类似:-verbose:gc-XX:+PrintGCDetails在发生来及回收时,打印内存回收详细日志,并在进程退出是,输出当前…

DDS基本原理--FPGA学习笔记

DDS信号发生器原理: timescale 1ns / 1ps // // Company: // Engineer: // // Create Date: 2024/09/04 15:20:30 // Design Name: hilary // Module Name: DDS_Module //module DDS_Module(Clk,Reset_n,Fword,Pword,Data);input Clk;input Reset_n;input [31:0]…

如何使div居中?CSS居中终极指南

前言 长期以来,如何在父元素中居中对齐一个元素,一直是一个让人头疼的问题,随着 CSS 的发展,越来越多的工具可以用来解决这个难题,五花八门的招式一大堆,这篇博客,旨在帮助你理解不同的居中方法…

自制游戏手柄--Android画面的input输入控制

在使用传感器获取到运动数据后,怎样转换为input事件传给手机呢,这里以Android为例, 我们可以考虑以下方式: 1. 物理方式,使用舵机连接触碰笔去实现, 2. 构造MotionEvent事件,注入input&#…

fastadmin 文件上传七牛云

1-安装七牛云官方SDK composer require qiniu/php-sdk 2-七牛云配置 <?phpnamespace app\common\controller;use Qiniu\Storage\BucketManager; use think\Config; use Qiniu\Auth; use Qiniu\Storage\UploadManager; use think\Controller; use think\Db;/*** 七牛基类*…

【leetcode刷题之路】面试经典hot100(2)——普通数组+矩阵+链表

文章目录 5 普通数组5.1 【动态规划】最大子数组和5.2 【排序】合并区间5.3 【数组】轮转数组5.4 【前缀和】除自身以外数组的乘积5.5 【哈希表】缺失的第一个正数 6 矩阵6.1 【哈希表】矩阵置零6.2 【模拟】螺旋矩阵6.3 【模拟】旋转图像6.4 【分治】搜索二维矩阵 II 7 链表7.…

Go语言结构体和元组全面解析

Go语言中的复合类型与其应用 在编程中&#xff0c;标准类型虽然方便&#xff0c;但无法满足所有需求。Go通过支持结构体和元组类型&#xff0c;为开发者提供了自定义数据类型的能力。本文将介绍如何定义结构体、如何使用指针操作结构体、如何通过元组返回多个值等内容&#xf…