haiku实现门控多头注意力模块

在多头注意力机制中,通常输入的数据包括查询(Q)、键(K)和值(V)。这些数据的维度以及权重矩阵的维度在多头注意力机制中扮演关键角色。下面对数据及权重的维度进行解释:

  1. 输入数据(Queries, Keys, Values):

    • Queries (Q): 表示待查询的信息,通常对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, q_dim),其中 q_dim 是查询向量的维度。
    • Keys (K): 表示用于计算注意力分数的信息,也通常对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, key_dim),其中 key_dim 是键向量的维度。
    • Values (V): 表示待加权求和的信息,同样对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, value_dim),其中 value_dim 是值向量的维度。
  2. 权重矩阵:

    • 查询权重矩阵 (Q_weights): 用于对查询(Q)进行线性变换,将其映射到多个注意力头的维度。其维度通常为 (q_dim, num_heads, head_dim),其中 num_heads 是注意力头的数量,head_dim 是每个注意力头的维度。
    • 键权重矩阵 (K_weights): 用于对键(K)进行线性变换,同样映射到多个注意力头的维度。其维度通常为 (key_dim, num_heads, head_dim)。
    • 值权重矩阵 (V_weights): 用于对值(V)进行线性变换,映射到多个注意力头的维度。其维度通常为 (value_dim, num_heads, head_dim)。
def glorot_uniform():return hk.initializers.VarianceScaling(scale=1.0,mode='fan_avg',distribution='uniform')def stable_softmax(logits: jax.Array) -> jax.Array:"""Numerically stable softmax for (potential) bfloat 16."""if logits.dtype == jnp.float32:output = jax.nn.softmax(logits)elif logits.dtype == jnp.bfloat16:# Need to explicitly do softmax in float32 to avoid numerical issues# with large negatives. Large negatives can occur if trying to mask# by adding on large negative logits so that things softmax to zero.output = jax.nn.softmax(logits.astype(jnp.float32)).astype(jnp.bfloat16)else:raise ValueError(f'Unexpected input dtype {logits.dtype}')return outputclass Attention(hk.Module):"""Multihead attention."""def __init__(self, config, global_config, output_dim, name='attention'):super().__init__(name=name)self.config = configself.global_config = global_configself.output_dim = output_dimdef __call__(self, q_data, m_data, mask, nonbatched_bias=None):"""Builds Attention module.Arguments:q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].m_data: A tensor of memories from which the keys and values areprojected, shape [batch_size, N_keys, m_channels].mask: A mask for the attention, shape [batch_size, N_queries, N_keys].nonbatched_bias: Shared bias, shape [N_queries, N_keys].Returns:A float32 tensor of shape [batch_size, N_queries, output_dim]."""# Sensible default for when the config keys are missingkey_dim = self.config.get('key_dim', int(q_data.shape[-1]))value_dim = self.config.get('value_dim', int(m_data.shape[-1]))num_head = self.config.num_headassert key_dim % num_head == 0assert value_dim % num_head == 0key_dim = key_dim // num_headvalue_dim = value_dim // num_head# weights维度(数据最后一维的维度数,注意力头数量,每个注意力头映射的数据维度)q_weights = hk.get_parameter('query_w', shape=(q_data.shape[-1], num_head, key_dim),dtype=q_data.dtype,init=glorot_uniform())k_weights = hk.get_parameter('key_w', shape=(m_data.shape[-1], num_head, key_dim),dtype=q_data.dtype,init=glorot_uniform())v_weights = hk.get_parameter('value_w', shape=(m_data.shape[-1], num_head, value_dim),dtype=q_data.dtype,init=glorot_uniform())# bqa: 输入张量 q_data 的轴的标记。(batch_size, seq_length, q_dim)# 'b' :batch 维度,'q':查询序列维度,'a' 查询向量的维度。所以,'bqa' 表示 q_data 的三个轴。# ahc:查询权重矩阵的形状, a:查询向量的维度,h:注意力头的数量,c: 每个注意力头中查询的维度。# key_dim**(-0.5) 注意力缩放,避免注意力分数过大或过小# jnp.einsum:Einstein Summation Notation(爱因斯坦求和约定)。# 一种紧凑、灵活的方式来指定和计算张量的乘积、求和和转置等操作。q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)# 注意力分数,计算每个查询(q)和键(k)之间的点积,以获得注意力分数。# 结果维度为bhqk (batch_size, num_heads, num_q, num_k), # num_q/num_k为查询/键的数量,一般为 seq_length。logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)if nonbatched_bias is not None:logits += jnp.expand_dims(nonbatched_bias, axis=0)# 注意力分数中加入masklogits = jnp.where(mask, logits, _SOFTMAX_MASK)# 对注意力分数进行softmax操作,我们得到每个位置对输入序列的权重分配。weights = stable_softmax(logits)# 注意力分数对值进行加权求和,得到多头注意力机制的输出# 两个向量的点积可以用于度量它们之间的相似性。如果两个向量越相似,它们的点积就越大weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)if self.global_config.zero_init:init = hk.initializers.Constant(0.0)else:init = glorot_uniform()# 带有bias的门控注意力if self.config.gating:gating_weights = hk.get_parameter('gating_w',shape=(q_data.shape[-1], num_head, value_dim),dtype=q_data.dtype,init=hk.initializers.Constant(0.0))gating_bias = hk.get_parameter('gating_b',shape=(num_head, value_dim),dtype=q_data.dtype,init=hk.initializers.Constant(1.0))gate_values = jnp.einsum('bqc, chv->bqhv', q_data,gating_weights) + gating_biasgate_values = jax.nn.sigmoid(gate_values)# ⊙ 对应元素相乘weighted_avg *= gate_valueso_weights = hk.get_parameter('output_w', shape=(num_head, value_dim, self.output_dim),dtype=q_data.dtype,init=init)o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),dtype=q_data.dtype,init=hk.initializers.Constant(0.0))# 线性变换到输出维度大小output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_biasreturn output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/609812.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙原生应用/元服务开发-长时任务

概述 功能介绍 应用退至后台后,对于在后台需要长时间运行用户可感知的任务,例如播放音乐、导航等。为防止应用进程被挂起,导致对应功能异常,可以申请长时任务,使应用在后台长时间运行。申请长时任务后,系统…

Linux调试------gdb的使用

目录 前言 一、gdb打开可执行程序 二、查看代码与操作断点 1.l 查看代码 2.b 打断点 3.info b 查看断点信息 4.d 删除断点 5.disable 和 enable 断点的禁用与启用 三、调试 1.r 启动调试 2. n 逐过程 3. s 逐语句 4.display显示变量 5.undisplay 取消显…

e2studio开发磁力计LIS2MDL(1)----轮询获取磁力计数据

e2studio开发磁力计LIS2MDL.1--轮询获取磁力计数据 概述视频教学样品申请源码下载速率新建工程工程模板保存工程路径芯片配置工程模板选择时钟设置UART配置UART属性配置设置e2studio堆栈e2studio的重定向printf设置R_SCI_UART_Open()函数原型回调函数user_uart_callback ()prin…

C语言菱形图案生成程序详细解析

1. 介绍 在这篇博客中,我们将详细解析一个用C语言编写的程序,该程序旨在生成具有对称美感的菱形图案。用户可以通过输入指定的长度,程序将输出相应大小的菱形。 2. 代码 注:代码中的注释已经非常详细,也有举例&#…

根据MySql的表名,自动生成实体类,模仿ORM框架

ORM框架可以根据数据库的表自动生成实体类,以及相应CRUD操作 本文是一个自动生成实体类的工具,用于生成Mysql表对应的实体类。 新建Winform窗体应用程序AutoGenerateForm,框架(.net framework 4.5), 添加对System.Configuration的…

高效构建Java应用:Maven入门和进阶(三)

高效构建Java应用:Maven入门和进阶(三) 三. Maven的核心功能和构建管理3.1 依赖管理和配置3.2 依赖传递和冲突3.3 依赖导入失败场景和解决方案3.4 扩展构建管理和插件配置 三. Maven的核心功能和构建管理 3.1 依赖管理和配置 Maven 依赖管理…

7-Docker私有仓库registry搭建

1.Docker私有仓库registry搭建 1.下载registry镜像 命令: docker pull registry [root@centos79 ~]# docker search registry NAME DESCRIPTION STARS OFFICIAL AUTOMATED registry …

CentOS安装Docker(超详细)

CentOS安装Docker 1 知识小课堂2 CentOS安装Docker2.1 1.1.卸载(可选)1.2.安装docker1.3.启动docker1.4.配置镜像加速 3 CentOS安装DockerCompose3.1.下载3.2.修改文件权限3.3.Base自动补全命令: 4 Docker镜像仓库4.1.简化版镜像仓库4.2.带有…

数据分析求职-知识脑图

今天和大家聊聊数据分析求职常见面试题,这是这个系列的第一篇文章,但是我不想开始就直接罗列题目,因为这样的文章实在太多了,同学们的兴趣程度肯定一般。所以,我想先和大家聊聊在准备面试题时候通常遇到的困扰&#xf…

GFP-CERTIFIED®FLUOFORTE®钙离子检测试剂盒

Enzo Life Sciences的GFP-CERTIFIED FLUOFORTE Calcium assay kit提供了一种荧光分析方法,用于检测广泛生物靶标的细胞内钙动员情况。相对于其他商业化的染料,GFP-CERTIFIED FLUOFORTE染料是最亮和最灵敏的荧光钙指示剂。该试剂盒为贴壁和非贴壁细胞系提…

生物信息学中的可重复性研究

科学就其本质而言,是累积渐进的。无论你是使用基于网络的还是基于命令行的工具,在进行研究时都应保证该研究可被其他研究人员重复。这有利于你的工作的累积与进展。在生物信息学领域,这意味着如下内容。 工作流应该有据可查。这可能包括在电脑…

linux 压力测试 AB ApacheBench

ab的简介 ab是apachebench命令的缩写。 ab是apache自带的压力测试工具。ab非常实用,它不仅可以对apache服务器进行网站访问压力测试,也可以对或其它类型的服务器进行压力测试。比如nginx、tomcat、IIS等 ab的原理 ab的原理:ab命令会创建多…

Pendulum详解3——Pendulum在数据分析中的应用 - 时光的数据探险

写在开头: 在数据分析的世界里,时光是一项无可避免的数据元素。处理时间数据,尤其是在大规模数据集和复杂时区的情境下,是数据科学家和分析师们经常面对的挑战。为了迎接这一挑战,我们介绍一款强大的时间处理库 - Pendulum,并深入探讨其在数据分析领域的优越性。 1. 时…

ros架构

ROS(Robot Operating System)是一个灵活的、分布式的系统架构,用于构建机器人软件。它由一系列工具、库和约定组成,提供了一套通用的功能和通信机制,以支持机器人系统的开发、部署和运行。 ROS架构主要包括以下几个核心…

C++面试宝典第17题:找规律填数

题目 仔细观察下面的数字序列,找到规律,并填写空白处的数字。 (1)1, 2, 4, 7, 11, 16, __ (2)-1, 2, 7, 28, __, 126 (3)6, 10, 18, 32, 57, __ (4)19, 6, 1, 2, 11, __ (5)2, 3, 5, 7, 11, __ (6)1, 8, 9, 4, __, 1/6 (7)1, 2, 3, 7, 16, __, 321 (8)1, 2, …

linux异常情况,排查处理中

登录客户环境后,发现一个奇怪情况如下图,之前也遇到过,直接fuser -ck /backup操作的话,主机将会重启,因数据库运行中,等待停机维护时间,同时也在想办法不重启的情况下解决该问题 [rootdb ~]# f…

linux发送http请求命令

一、http get请求 1、curl命令不带参 curl “http://www.baidu.com” 如果这里的URL指向的是一个文件或者一幅图都可以直接下载到本地 curl -i “http://www.baidu.com” 显示全部信息 curl -l “http://www.baidu.com” 只显示头部信息 curl -v “http://www.baidu.com”…

git撤销提交到本地的commit

有些时候,当我们提交代码到本地后,突然发现因为某些原因需要撤销提交本地的代码。 就比如我,因为代码写错了分支,已经提交到本地了,而我需要取消,并且还要把代码搞得另外的分支上。 提交前: …

SpiderFlow爬虫平台 前台RCE漏洞复现(CVE-2024-0195)

0x01 产品简介 SpiderFlow是新一代爬虫平台,以图形化方式定义爬虫流程,以流程图的方式定义爬虫,不写代码即可完成爬虫,是一个高度灵活可配置的爬虫平台。 0x02 漏洞概述 SpiderFlow爬虫平台src/main/java/org/spiderflow/controller/FunctionController.java文件的Functi…

【2023】ArrayList和LinkedList详解介绍对比

一、ArrayList 1、概述 ArrayList是实现了List接口的动态数组,所谓动态数组就是他的大小是可变的。实现了所有可选列表操作,并允许包括Null在内的所有元素。除了实现 List 接口外,此类还提供一些方法来操作内部用来存储列表的数组的大小。 …