MHD、MQA、GQA注意力机制详解

MHD、MQA、GQA注意力机制详解

  • 注意力机制详解及代码
    • 前言:
    • MHA
    • MQA
    • GQA

注意力机制详解及代码

前言:

自回归解码器推理是 Transformer 模型的 一个严重瓶颈,因为在每个解码步骤中加 载解码器权重以及所有注意键和值会产生 内存带宽开销

下图为三种注意力机制的结构图和实验结果

在这里插入图片描述

在这里插入图片描述

MHA

多头注意力机制是Transformer模型中的核心组件。在其设计中,"多头"意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的“视角”捕获输入的不同信息。

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, hidden_size)self.v_linear = nn.Linear(hidden_size, hidden_size)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key)value = self.split_head(value)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)## 对注意力输出进行拼接output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x):batch_size = x.size()[0]return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)

MQA

多查询注意力(MQA)可能导致质量下降和训练不稳定,并且训练针对质量和推理优化的单独模型可能不可行。此外,虽然一些语言模型已经使用了多查询注意力,如PaLM但许多语言模型没有,包括公开可用的语言模型,如T5和LLaM.

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=1,v=1)。相当于多个query,即多查询。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
## 多查询注意力
import torch
from torch import nn
class MutiQueryAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.head_dim) ###self.v_linear = nn.Linear(hidden_size, self.head_dim) ##### 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, 1)value = self.split_head(value, 1)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, head_num=None):batch_size = x.size()[0]if head_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:return x.view(batch_size, -1, head_num, self.head_dim).transpose(1,2)

GQA

  • 使用 5% 的原始预训练 计算将现有的多头语言模型检查点训 练到具有 MQA 的模型中
  • 引入分组查询注意力 (GQA),这是多 头语言模型的泛化。查询注意力,它使用中间,多于一个,少于查询头数量的键值头。
  • 经过训练的GQA 实现了接近多头注意力 的质量,并且速度与 MQA 相当。
  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=group_num,v=group_num)。相当于把多头分组了,比如原先有10个头,那就是10个query,分成5组,每组2个query,1个value,1个key。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
## 分组注意力查询
import torch
from torch import nn
class MutiGroupAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads, group_num):super(MutiGroupAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_headsself.group_num = group_num## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, self.group_num)value = self.split_head(value, self.group_num)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, group_num=None):batch_size,seq_len = x.size()[:2]if group_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1,2)x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/12730.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【合成孔径雷达】合成孔径雷达的多视角理解和时/频成像算法的统一解释

文章目录 一、什么是雷达成像(1)主要的遥感探测手段:光学、红外和雷达(2)从数学的角度:雷达成像主要研究什么?数据采集: y T x n yTxn yTxn信息提取: y − > x ? y…

编译错误:stray ‘\357’ in program的解决方法

目录 把报错文件更换编码格式,我试的utf-8 bom编码就可以了,可以多换几种试试。 网友的另一种案例: 编译错误:stray ‘\357’ in program的解决方法 把报错文件更换编码格式,我试的utf-8 bom编码就可以了&#xff0c…

如何同步管理1000个设备的VLAN数据?

什么是VLAN? VLAN,也就是虚拟局域网,是通过为子网提供数据链路连接来抽象出局域网的概念。在企业网中,一个企业级交换机一般是24口或者是48口,连接这些接口的终端在物理上形成一个广播域。广播域过大,就会导…

【AI智能体】零代码构建AI应用,全网都在喊话歌手谁能应战,一键AI制作歌手信息查询应用

欢迎来到《小5讲堂》 这是《文心智能体平台》系列文章,每篇文章将以博主理解的角度展开讲解。 温馨提示:博主能力有限,理解水平有限,若有不对之处望指正! 目录 文心智能体大赛背景创建应用平台地址快速构建【基础配置】…

前端无样式id或者class等来定位标签

目录: 1、使用背景2、代码处理 1、使用背景 客户使用我们产品组件,发现替换文件,每次替换都会新增如下的样式,造就样式错乱,是组件的文件,目前临时处理的话就是替换文件时删除新增的样式,但是发…

【JVM】阅读Class字节码:常量池

目录 基本结构解析 常量池 常量池简介 如何阅读Class文件中的常量池信息 基本结构解析 Magic(魔数) Magic的唯一作用是确定这个文件是否为一个能被虚拟机所接受的class 文件。魔数值固定为0xCAFEBABE,不会改变。 常量池 常量池简介 下图是反编译过后的字节码文…

TensorFlow的学习

0.基础概念 术语表: https://developers.google.cn/machine-learning/glossary?hlzh-cn#logits 1.快速入门 https://tensorflow.google.cn/tutorials/quickstart/beginner?hlzh-cn 2.基于Keras进行图像分类 https://tensorflow.google.cn/tutorials/keras/cl…

gradle 共享存储挂载缓存目录的问题

2个任务同时构建的时候,报错如上。 原因:挂载目录的问题导致的,挂在最小粒度的目录下。 /home/app/.gradle/caches/modules-2/files-2.1 挂载到这个级别的目录下。

演员怎么上百度百科

百度百科是一个公正、开放、客观的平台,它为演员提供了一个展示自己过往经历和演艺生涯的平台。以下是百科优化网yajje总结的演员创建百度百科的一些步骤和注意事项: 创建演员百度百科的基本条件 人物影响力:演员创建百度百科需要满足官方的规…

振弦采集仪在岩土工程监测中的重要性及应用案例分享

振弦采集仪在岩土工程监测中的重要性及应用案例分享 岩土工程监测是为了确保土地和建筑物的稳定性以及确保施工安全而进行的一项重要工作。河北稳控科技振弦采集仪是岩土工程监测中一种常用的仪器设备,通过测量土体振动频率来评估土体的稳定性和强度变化&#xff0…

霸道龙尊短视频:成都鼎茂宏升文化传媒公司

霸道龙尊短视频:龙族的传奇与现代的交融 在数字化时代的浪潮中,短视频以其短小精悍、内容丰富的特点,迅速占领了人们的碎片时间。成都鼎茂宏升文化传媒公司而在这些短视频中,一股独特的“霸道龙尊”风潮正在悄然兴起,…

Nginx配置文件conf解释

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 Nginx(“engine x”…

基于springboot+vue+Mysql的在线答疑系统

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

服务攻防——应用协议软件,设备平台

向日葵利用 vnc利用5900端口 当为none就可以直接连接,而其他几种密码也能破解 可以使用hydna来尝试爆破 teamviewer(cve2020-13699) 让对方点击这个网站,就会 触发 zabbix 端口10051 cve2020 手工 点击这个 找到cookie 然后不需要密码就能进…

搭建Rust开发环境

Windows搭建 下载:https://www.rust-lang.org/zh-CN/tools/install Linux搭建 这里我更推荐基于Linux搭建。 curl --proto https --tlsv1.2 -sSf https://sh.rustup.rs | sh等一会儿以后,会让你输入命令,这里输入1: 之后就…

一表捋清网络安全等级保护测评要求

三级网络安全等级保护测评指标: 对于中小企事业单位来说,网络安全建设是一个复杂且投入较高的过程,因此他们更倾向于寻找一种“省心省力”的等保建设方案,以及一种能够持续有效且具有较高性价比的网络安全建设投入方式。 此时&…

【微积分】三角函数求导积分公式的巧妙记忆

三角函数积分求导公式的巧妙记忆 图像的整体记忆: 上面是sinx cosx 下面也是s开头,secx,cscx 中间是tanx cotx 解释说明: 1️⃣ 对角线互为倒数,即sinx对角线是cscx,这样我们可以更好记住这个六边形图像。…

Web课外练习7

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>照片墙</title><style>body {display: …

libcity笔记: HSTLSTMEncoder

1 __init__ 2 encode 得到的内容如下&#xff1a; data_feature的内容&#xff1a; 一共有多少个location1【包括pad的一个】最长的时间间隔&#xff08;秒&#xff09;最长的距离间隔&#xff08;千米&#xff09;多少个useer idpadding 的locationidpad_item的内容 location…

SpringBoot 3.2.5 + ElasticSearch 8.12.0 - SpringData 开发指南

目录 一、SpringData ElasticSearch 1.1、环境配置 1.2、创建实体类 1.3、ElasticSearchTemplate 的使用 1.3.1、创建索引库&#xff0c;设置映射 1.3.2、创建索引映射注意事项 1.3.3、简单的 CRUD 1.3.4、三种构建搜索条件的方式 1.3.5、NativeQuery 搜索实战 1.3.6…