MHD、MQA、GQA注意力机制详解

MHD、MQA、GQA注意力机制详解

  • 注意力机制详解及代码
    • 前言:
    • MHA
    • MQA
    • GQA

注意力机制详解及代码

前言:

自回归解码器推理是 Transformer 模型的 一个严重瓶颈,因为在每个解码步骤中加 载解码器权重以及所有注意键和值会产生 内存带宽开销

下图为三种注意力机制的结构图和实验结果

在这里插入图片描述

在这里插入图片描述

MHA

多头注意力机制是Transformer模型中的核心组件。在其设计中,"多头"意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的“视角”捕获输入的不同信息。

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, hidden_size)self.v_linear = nn.Linear(hidden_size, hidden_size)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key)value = self.split_head(value)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)## 对注意力输出进行拼接output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x):batch_size = x.size()[0]return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)

MQA

多查询注意力(MQA)可能导致质量下降和训练不稳定,并且训练针对质量和推理优化的单独模型可能不可行。此外,虽然一些语言模型已经使用了多查询注意力,如PaLM但许多语言模型没有,包括公开可用的语言模型,如T5和LLaM.

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=1,v=1)。相当于多个query,即多查询。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
## 多查询注意力
import torch
from torch import nn
class MutiQueryAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.head_dim) ###self.v_linear = nn.Linear(hidden_size, self.head_dim) ##### 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, 1)value = self.split_head(value, 1)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, head_num=None):batch_size = x.size()[0]if head_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:return x.view(batch_size, -1, head_num, self.head_dim).transpose(1,2)

GQA

  • 使用 5% 的原始预训练 计算将现有的多头语言模型检查点训 练到具有 MQA 的模型中
  • 引入分组查询注意力 (GQA),这是多 头语言模型的泛化。查询注意力,它使用中间,多于一个,少于查询头数量的键值头。
  • 经过训练的GQA 实现了接近多头注意力 的质量,并且速度与 MQA 相当。
  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=group_num,v=group_num)。相当于把多头分组了,比如原先有10个头,那就是10个query,分成5组,每组2个query,1个value,1个key。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
## 分组注意力查询
import torch
from torch import nn
class MutiGroupAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads, group_num):super(MutiGroupAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_headsself.group_num = group_num## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, self.group_num)value = self.split_head(value, self.group_num)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, group_num=None):batch_size,seq_len = x.size()[:2]if group_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1,2)x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/12730.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

巩固学习8

在 Pandas 中,sep参数用于指定数据中字段之间的分隔符。常见的参数包括: 逗号:,,常用于CSV文件。 制表符:\t,常用于TSV文件。 空格:’ ,用于空格分隔的数据。 分号:;&…

【合成孔径雷达】合成孔径雷达的多视角理解和时/频成像算法的统一解释

文章目录 一、什么是雷达成像(1)主要的遥感探测手段:光学、红外和雷达(2)从数学的角度:雷达成像主要研究什么?数据采集: y T x n yTxn yTxn信息提取: y − > x ? y…

编译错误:stray ‘\357’ in program的解决方法

目录 把报错文件更换编码格式,我试的utf-8 bom编码就可以了,可以多换几种试试。 网友的另一种案例: 编译错误:stray ‘\357’ in program的解决方法 把报错文件更换编码格式,我试的utf-8 bom编码就可以了&#xff0c…

LabVIEW做仪器测试不知道是否适用

LabVIEW(Laboratory Virtual Instrument Engineering Workbench)是一个用于系统工程和测量系统的图形编程平台,由National Instruments开发。它非常适用于仪器控制、数据采集、信号处理以及自动化测试与测量系统的开发。如果您的工作涉及到这…

如何同步管理1000个设备的VLAN数据?

什么是VLAN? VLAN,也就是虚拟局域网,是通过为子网提供数据链路连接来抽象出局域网的概念。在企业网中,一个企业级交换机一般是24口或者是48口,连接这些接口的终端在物理上形成一个广播域。广播域过大,就会导…

【AI智能体】零代码构建AI应用,全网都在喊话歌手谁能应战,一键AI制作歌手信息查询应用

欢迎来到《小5讲堂》 这是《文心智能体平台》系列文章,每篇文章将以博主理解的角度展开讲解。 温馨提示:博主能力有限,理解水平有限,若有不对之处望指正! 目录 文心智能体大赛背景创建应用平台地址快速构建【基础配置】…

前端无样式id或者class等来定位标签

目录: 1、使用背景2、代码处理 1、使用背景 客户使用我们产品组件,发现替换文件,每次替换都会新增如下的样式,造就样式错乱,是组件的文件,目前临时处理的话就是替换文件时删除新增的样式,但是发…

8评分卡建模整体流程梳理

评分卡建模整体流程梳理 学习目标 掌握评分卡建模流程使用Toad库构建评分卡1 加载数据 import pandas as pd from sklearn.metrics import roc_auc_score,roc_curve,auc from sklearn.model_selection import train_test_split from sklearn.linear_model import Logis…

云服务器上Redis数据库被攻击实录+总结

情景重现 Redis日志记录(异常部分): 36346:M 14 May 2024 15:46:12.505 # Possible SECURITY ATTACK detected. It looks like somebody is sending POST or Host: commands to Redis. This is likely due to an attacker attempting to us…

【JVM】阅读Class字节码:常量池

目录 基本结构解析 常量池 常量池简介 如何阅读Class文件中的常量池信息 基本结构解析 Magic(魔数) Magic的唯一作用是确定这个文件是否为一个能被虚拟机所接受的class 文件。魔数值固定为0xCAFEBABE,不会改变。 常量池 常量池简介 下图是反编译过后的字节码文…

Python可视化总结与案例解析

目录 第一章:Python可视化基础 1.1 环境搭建 1.2 数据可视化 1.3 统计图表 1.4 交互式可视化 1.5 实战案例:网站流量分析 1.6 总结 第二章:Python可视化高级应用 2.1 高级图表类型 2.2 动态可视化 2.3 数据可视化最佳实践 2.4 实战…

TensorFlow的学习

0.基础概念 术语表: https://developers.google.cn/machine-learning/glossary?hlzh-cn#logits 1.快速入门 https://tensorflow.google.cn/tutorials/quickstart/beginner?hlzh-cn 2.基于Keras进行图像分类 https://tensorflow.google.cn/tutorials/keras/cl…

gradle 共享存储挂载缓存目录的问题

2个任务同时构建的时候,报错如上。 原因:挂载目录的问题导致的,挂在最小粒度的目录下。 /home/app/.gradle/caches/modules-2/files-2.1 挂载到这个级别的目录下。

一文详解什么是手机在网时长API

手机在网时长API最近被讨论得越来越多,因为随着移动互联网的不断发展,越来越多的场景需要使用到用户的手机号,比如商品交易、客户服务、信息收发、网络即时通讯等。手机号码状态查询功能使用得越来越广泛,常见的有手机在网时长查询…

演员怎么上百度百科

百度百科是一个公正、开放、客观的平台,它为演员提供了一个展示自己过往经历和演艺生涯的平台。以下是百科优化网yajje总结的演员创建百度百科的一些步骤和注意事项: 创建演员百度百科的基本条件 人物影响力:演员创建百度百科需要满足官方的规…

振弦采集仪在岩土工程监测中的重要性及应用案例分享

振弦采集仪在岩土工程监测中的重要性及应用案例分享 岩土工程监测是为了确保土地和建筑物的稳定性以及确保施工安全而进行的一项重要工作。河北稳控科技振弦采集仪是岩土工程监测中一种常用的仪器设备,通过测量土体振动频率来评估土体的稳定性和强度变化&#xff0…

写个进度条

using UnityEngine; using UnityEngine.UI; [ExecuteAlways] public class 进度条控制器 : MonoBehaviour {public Image 母进度条; // 母进度条背景public Image 子进度条; // 子进度条[Range(0, 1)] public float 进度值; // 进度值public float 起点偏移量; // 从左侧开始的…

理解打包好的vue项目结构dist包

目录 linux查询dist目录整体解释子目录文件解释CSSFONTSJS linux查询dist目录 roothcss-ecs-7881:/www/java_project/dist# ls -l total 3004 drwxr-xr-x 2 root root 4096 Dec 31 10:15 css -rw-r--r-- 1 root root 4286 Dec 31 10:15 favicon.ico drwxr-xr-x 2 root r…

MySQL从主库恢复从库

主库备份数据,拷贝至从节点1.1 备份数据 sudo python /data/apps/xtrabackup/script/xtrabackup.py -m full 备份目录为: /data/mysql_bakcup/<port>/<date>/full_<date> 例:/data/mysql_backup/13306/20231124/full_164044/ 1.2 拷贝备份数据至从节点 sc…

霸道龙尊短视频:成都鼎茂宏升文化传媒公司

霸道龙尊短视频&#xff1a;龙族的传奇与现代的交融 在数字化时代的浪潮中&#xff0c;短视频以其短小精悍、内容丰富的特点&#xff0c;迅速占领了人们的碎片时间。成都鼎茂宏升文化传媒公司而在这些短视频中&#xff0c;一股独特的“霸道龙尊”风潮正在悄然兴起&#xff0c;…