GQA (group query attention)

什么是GQA?

多个head的Query共用一组K和V。llama模型就用到该技术。

需要明确几点:

1.group有几组

2.每个group对应几个head

3.q以head为单位 k,v以group为单位 每个head/group特征维度都是head_dim

代码实现

import torch.nn as nn
import torch
import math# 自注意力
class GroupQueryAttention(nn.Module):def __init__(self, d_model, n_heads, n_groups):super().__init__()self.d_model = d_modelself.n_heads = n_headsself.n_groups = n_groupsassert d_model % n_heads == 0self.head_dim = d_model // n_heads# 每个group对应几个headself.heads_per_group = n_heads // n_groups# 以head为单位self.w_q = nn.Linear(d_model, n_heads*self.head_dim)  # n_heads*head_dim=d_model# 以group为单位self.w_k = nn.Linear(d_model, n_groups*self.head_dim)self.w_v = nn.Linear(d_model, n_groups*self.head_dim)self.w_combine = nn.Linear(d_model, d_model)self.softmax = nn.Softmax(dim=-1)# 给k,v进行复制,假设每个组对应3个head,那就要把每个组的数据复制3遍def expand(self, data): # data:k/v [b, group, seq_len, head_dim]b,_,seq_len,_ = data.shapedata = data[:,:,None,:,:].expand(b, self.n_groups, self.heads_per_group, seq_len, self.head_dim)data = data.contiguous().view(b, -1, seq_len, self.head_dim)return data  # [b, group*heads_per_group, seq_len, head_dim]def forward(self, x, use_mask=False): # x: [b, seq_len, d_model]b, seq_len, _ = x.shapeq,k,v = self.w_q(x), self.w_k(x), self.w_v(x)q = q.view(b, seq_len, self.n_heads, self.head_dim).permute(0,2,1,3)  # 以head为单位k = k.view(b, seq_len, self.n_groups, self.head_dim).permute(0,2,1,3) # 以group为单位v = v.view(b, seq_len, self.n_groups, self.head_dim).permute(0,2,1,3)# 复制k,v = self.expand(k), self.expand(v)score = q @ k.transpose(-1,-2) / math.sqrt(self.head_dim)if use_mask:mask = torch.tril(torch.ones(seq_len, seq_len))score = score.masked_fill(mask==0, float('-inf'))score = self.softmax(score) @ vscore = score.permute(0,2,1,3).contiguous().view(b, seq_len, self.n_heads*self.head_dim)out = self.w_combine(score)return outx = torch.rand(2,100,384)
model = GroupQueryAttention(384, 4, 2)
out = model(x)
print(out.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/878961.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

串口通信协议(UART)

简介 uart通讯协议,是一种成本低、容易使用、通信线路简单,可实现两个设备的互相通信的协议;是一种全双工,设备点对点通信的协议。下面从硬件电路、电平标准和串口参数等方面来了解uart通信协议。 硬件电路 硬件电路非常简单&am…

深入Redis:强大的主从复制

如果某个服务器或者程序,只有一个节点(服务器),就会有很大的问题。比如可用不高,并发量也比较低。引入分布式系统,也主要是为了解决上述的单点问题。 Redis,主要部署在分布式系统上。在分布式系…

无人机之地面站篇

无人机的地面站,又称无人机控制站,是整个无人机系统的重要组成部分,扮演着作战指挥中心的角色。以下是对无人机地面站的详细阐述: 一、定义与功能 无人机地面站是指具有对无人机飞行平台和任务载荷进行监控和操纵能力的一组设备&…

Postgresql碎片整理

创建pgstattuple 扩展 CREATE EXTENSION pgstattuple 获取表的元组(行)信息,包括空闲空间的比例和行的平均宽度 SELECT * FROM pgstattuple(表名); 查看表和索引大小 SELECT pg_relation_size(表名), pg_relation_size(索引名称); 清理碎片方…

【魔法 / NOI】

题目 思路 动态规划: 状态定义: f [ k ] [ i ] [ j ] 对应使用了不超过 k 次魔法,从 i 到 j 的路径集合 f[k][i][j] 对应使用了不超过k次魔法,从i到j的路径集合 f[k][i][j]对应使用了不超过k次魔法,从i到j的路径集合 状…

vc-align源码分析 -- ant-design-vue系列

vc-align源码分析 源码地址:https://github.com/vueComponent/ant-design-vue/tree/main/components/vc-align 1 基础代码 1.1 名词约定 需要对齐的节点叫source,对齐的目标叫target。 1.2 props 提供了两个参数: align:对…

WPF-快速构建统计表、图表并认识相关框架

一、使用ScottPlot.Wpf 官网地址:https://scottplot.net/quickstart/wpf/ 1、添加NuGet包:ScottPlot.Wpf 2、XAML映射命名空间: xmlns:ScottPlot"clr-namespace:ScottPlot.WPF;assemblyScottPlot.WPF" 3、简单示例:…

2024年测评7款最佳AI论文修改润色平台

在2024年,AI论文修改润色平台的测评和推荐成为学术界和研究者们关注的热点。本文将详细评测并推荐7款最佳AI论文修改润色平台,包括千笔-AIPassPaper,并结合我搜索到的资料进行分析。 一、千笔-AIPassPaper 千笔-AIPassPaper是一款集论文大纲…

【Nginx系列】Nginx中rewrite模块

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

SpringBoot教程(安装篇) | Elasticsearch的安装

SpringBoot教程(安装篇) | Elasticsearch的安装 一、确定Elasticsearch版本二、下载elasticsearch(windows版本)官网下载如何解压配置 允许 别人跨域 访问自己启动运行 三、Es可视化工具安装(elasticsearch-head&#…

DDS基本原理--FPGA学习笔记

DDS信号发生器原理: timescale 1ns / 1ps // // Company: // Engineer: // // Create Date: 2024/09/04 15:20:30 // Design Name: hilary // Module Name: DDS_Module //module DDS_Module(Clk,Reset_n,Fword,Pword,Data);input Clk;input Reset_n;input [31:0]…

如何使div居中?CSS居中终极指南

前言 长期以来,如何在父元素中居中对齐一个元素,一直是一个让人头疼的问题,随着 CSS 的发展,越来越多的工具可以用来解决这个难题,五花八门的招式一大堆,这篇博客,旨在帮助你理解不同的居中方法…

自制游戏手柄--Android画面的input输入控制

在使用传感器获取到运动数据后,怎样转换为input事件传给手机呢,这里以Android为例, 我们可以考虑以下方式: 1. 物理方式,使用舵机连接触碰笔去实现, 2. 构造MotionEvent事件,注入input&#…

fastadmin 文件上传七牛云

1-安装七牛云官方SDK composer require qiniu/php-sdk 2-七牛云配置 <?phpnamespace app\common\controller;use Qiniu\Storage\BucketManager; use think\Config; use Qiniu\Auth; use Qiniu\Storage\UploadManager; use think\Controller; use think\Db;/*** 七牛基类*…

CTK框架(四): 插件编写

目录 1.生成插件 1.1.环境说明 1.2.服务类&#xff0c;纯虚类&#xff0c;提供接口 1.3.实现插件类&#xff0c;实现纯虚函数 1.4.激活插件&#xff0c;加入ctk框架的生命周期中 1.5.添加资源文件 1.6..pro文件 2.使用此插件 3.总结 1.生成插件 1.1.环境说明 编译ct…

如何将卷积神经网络(CNN)应用于医学图像分析:从分类到分割和检测的实用指南

引言 在现代医疗领域,医学图像已经成为疾病诊断和治疗规划的重要工具。医学图像的类型繁多,包括但不限于X射线、CT(计算机断层扫描)、MRI(磁共振成像)和超声图像。这些图像提供了对身体内部结构的详细视图,有助于医生在进行准确诊断和制定个性化治疗方案时获取关键的信…

[数据结构] 哈希结构的哈希冲突解决哈希冲突

标题&#xff1a;[C] 哈希结构的哈希冲突 && 解决哈希冲突 水墨不写bug 目录 一、引言 1.哈希 2.哈希冲突 3.哈希函数 二、解决哈希冲突 1.闭散列 I&#xff0c;线性探测 II&#xff0c;二次探测 2.开散列 正文开始&#xff1a; 一、引言 哈希表是一种非常实用而…

JS基础学习笔记

1.引入方式 内部脚本 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title> <…

Nginx跨域运行案例:云台控制http请求,通过 http server 代理转发功能,实现跨域运行。(基于大华摄像头WEB无插件开发包)

文章目录 引言I 跨域运行案例开发资源测试/生产环境,Nginx代理转发,实现跨域运行本机开发运行II nginx的location指令Nginx配置中, 获取自定义请求header头Nginx 配置中,获取URL参数引言 背景:全景监控 需求:感知站点由于云台相关操作为 http 请求,http 请求受浏览器…

抢鲜体验 PolarDB PG 15 开源版

unsetunsetPolarDB 商业版unsetunset 8 月&#xff0c;PolarDB PostgreSQL 版兼容 PostgreSQL 15 版本&#xff08;商业版&#xff09;正式发布上线。 当前版本主要增强优化了以下方面&#xff1a; 改进排序功能&#xff1a;改进内存和磁盘排序算法。 增强SQL功能&#xff1a;支…