MultiHeadAttention在Tensorflow中的实现原理


前言

通过这篇文章,你可以学习到Tensorflow实现MultiHeadAttention的底层原理。


一、MultiHeadAttention的本质内涵

1.Self_Atention机制

MultiHeadAttention是Self_Atention的多头堆嵌,有必要对Self_Atention机制进行一次深入浅出的理解,这也是MultiHeadAttention的核心所在。

Self_Attention并不直接使用输入向量,而是先将其进行映射,使得输入向量的每个位置提供一个query和context,context充当字典。在context的每个位置都提供一个key和value向量。

query:尝试去获取某类信息的序列。

context:包含key序列和value序列,是query感兴趣的内容。

最终输出的形状将与query序列相同。

一个常见的类比是,这种操作就像字典查询。一个模糊的、可区分的、矢量的字典查询。

如下是一个普通的 python 字典类型数据,有 3 个键和 3 个值,并被传递给一个query——"What color is it ?"。这个query会与key="color"最契合,最终得到查询结果value="blue"

query是你要尝试去找的东西。key表示字典里有哪些信息,而value就是这些信息。当你在正则字典中查找一个query时,字典会找到匹配的key,并返回其相关的value。这个查询要么有一个匹配的键,要么没有。你可以想象一个模糊的字典,其中的键不一定要完全匹配。如果你在上面的字典中查找 query—"What species is it ?",也许你希望它返回 key="type",value="pickup",因为那是与query最匹配的key和value。

注意力层就像这样做了一个模糊查找,但它不仅仅是在寻找最好的key,而是根据query与每个key的匹配程度来组合这些value。

那是如何做到这一点的呢?在注意力层中,query、key和value都是向量。注意力层不是简单地做哈希查找,而是结合query和key向量来确定它们的匹配程度——计算query和key的向量点积,再将所有匹配程度经过Softmax映射完后,即得到 "注意力得分"。最终该层返回所有value的加权平均值,以 "注意力分数 "为权重。

对于一段具体的文本来说,每一个字都会引发一个疑问query,并提供一个关键值key和一个目标内容value。每个query都会去寻找感兴趣的key,并按注意力分数提取并组合value,

图中越粗的红线对应的attention权重更大,query与key的紧密程度也越近。attention权重如此分布也是很符合情理的,要想回答query =“他是谁?”,我们很大可能会在“是”后面寻找答案,因为“爱人”提供的信息最多,所以它俩的attention权重最大。

总的来说,Self_Attention模拟的是一个符合人脑思维逻辑的研究过程。每当遇到一些新的信息,我们总会产生一定量的疑问(query),为了解决疑问,我们需要在信息中捕捉关键字(key),进而凝练出该关键字中所蕴涵的答案(value)。特定的疑问(query)需要联系特定的关键字(key),进而得出最终答案,这个最终答案往往是折合了不同value而得来的。

2.MultiHead_Atention机制

在不同情景中,字引发的query是不同的,例如,

“他是男的,已婚。”

query可以是”他的性别是什么?”,或者”他结婚了吗?”。不同的query会产生不同的注意力分数。单一的Self_attention无法捕捉多层面query和key之间的依赖关系,因为它只进行一次attention的分配。意在解决此类局限性,MultiHead_Atention会计算多次Self_attention。

利用MultiHead_Atention机制,可以为每一个输入学习到一个信息量丰富的向量表示。

二、使MultiHeadAttention在TensorFlow中的代码实现

1.参数说明

TensorFlow中是用tf.keras.layers.MultiHeadAttention()实现的。它的参数分为两类,一种参数为初始化参数,存在于__init__方法中;另一种为调用参数,存在于call方法中。

主要的初始化参数:

num_heads:Self_Attention的层数

key_dim:输入分别经过query和key的多头映射层后,其各自输出在axis=-1上的长度。因为后续需要计算query和key的点积,所以需要保证query和key在最后一个轴上的长度相等。

value_dim:输入经过value的多头映射层后,其输出在axis=-1上的长度。如果不指定,则默认等于key_dim

output_shape:  输入经过整个MultiHeadAttention层后的输出shape,默认与进入query多头映射层的输入shape相同

主要的call方法参数:

'''  B即Batch_size,每一批中的样本数;

    T是query的个数,即一段序列产生的疑问个数;S是value和key的个数,即一段序列产生的关键字和关键信息的个数,序列产生的key和value是成对出现的,所以value映射层 和key映射层的输入张量在axis=1处的长度S相同。T和S是可以随意指定的,只需在样本集进入Embedding层之前,先通过一个dense层进行T和S的指定(T和S等于各自dense层中的神经元个数)。例如,文本集shape=(B, S),经过一个具有T个神经元的Dense层→shape=(B, T),再经Embedding层→shape=(B, T, dim),得到query映射层的输入张量。当然,如果不愿如此麻烦,可直接将经Embedding层得到shape=(B, S, dim)的张量作为query映射层的输入;

    dim通常是Embedding向量的长度(每个字对应一个Embedding向量)'''

query:输入query多头映射层且shape为(B, T, dim)的张量

value:输入value多头映射层且shape为(B, S, dim)的张量

key:输入key多头映射层且shape为(B, S, dim)的张量,如果未指定,则key=value

use_causal_mask:布尔值,是否开启causal_mask(因果掩码)机制

2.整体结构

tf.keras.layers.MultiHeadAttention类中call()方法的逻辑过程就是MultiHeadAttention的前向传播过程,我将其提炼成以下三部分,

        ''' 多头映射层 '''query = self._query_dense(query)key = self._key_dense(key)value = self._value_dense(value)''' 注意力层 '''attention_output, attention_scores = self._compute_attention(query, key, value, attention_mask, training)''' 输出层 '''attention_output = self._output_dense(attention_output)

3.多头映射层

由query多头映射层—query_dense,value多头映射层—value_dense,key多头映射层—key_dense组成。

每个映射层执行的张量运算是一样的,张量运算逻辑为,

                                                   ' abc , cde -> abde '               



该层的训练参数总数为,

4.注意力层

计算query与key之间的内积,张量运算逻辑为,

                                              ' aecd, abcd -> acbe '

内积能够反映向量之间的相关程度,内积结果越大则相关性越大,联系也越紧密。得到query和key的内积后,为了得到attention分数,需要将内积结果进行softmax映射。


最后利用attention分数对value进行加权叠加,张量运算逻辑为,

                                                        'acbe,aecd->abcd' 

  



注意:如果指定use_causal_mask=True引入Causal_Mask(因果掩码)机制,则在softmax映射时,会传入一个左下三角为True右上三角为False的,布尔类型的,且与attention_scores.shape相同掩码张量,此时掩码张量中为False的对应位置(对应attention内积张量)将会被softmax忽略。如此一来就会导致每个query只会与当前及其以前的key进行内积,并不会考虑未来的key。进而导致在每个query处产生的新value只会是当前value与过往value在sttention分数上的加权叠加。这样的结构是因果的,符合在预测中结果会对输入产生影响的因果逻辑。因果掩码会在Decoder中使用。


注意力层无可训练的参数。

5.输出映射层

属于MultiHeadAttention的最后一层,负责将注意力层得到的value在sttention分数上的加权叠加后的张量进行输出映射。张量运算逻辑为,

                                                       ' abcd, cde -> abe '



该层训练参数总共为,


验证

import tensorflow as tflayer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.Input(shape=[9, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(query=target, value=source,return_attention_scores=True)''' 手动计算训练参数总数 '''
sum = 16*2*2*3+2*2*3+2*2*16+16
print(f'手动计算的训练参数总数为 : {sum}')
print(f'训练参数总共为 : {layer.count_params()}')
print(f'输出shape为 : {output_tensor.shape}')
print(f'注意力分数shape为 : {weights.shape}')手动计算的训练参数总数为 : 284
训练参数总共为 : 284
输出shape为 : (None, 9, 16)
注意力分数shape为 : (None, 2, 9, 4)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/826617.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux Makefile用法

1、什么是makefile? Makefile:将不同模块放在不同的目录中,定义一系列的规则进行 “自动化编译”2、Makefile写法 vim makefile 填写样例: app:sub.c add.c mult.c div.c main.cgcc sub.c add.c mult.c div.c main.c -o app3、工作…

刷代码随想录有感(39):每层最大值

题干: 代码: /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : val(0), left(nullptr), right(nullptr) {}* TreeNode(int x) : val(x), left(nullptr), …

OpenCV基本图像处理操作(十一)——图像特征Sift算法

图像尺度空间 在一定的范围内,无论物体是大还是小,人眼都可以分辨出来,然而计算机要有相同的能力却很难,所以要让机器能够对物体在不同尺度下有一个统一的认知,就需要考虑图像在不同的尺度下都存在的特点。 尺度空间的…

《6G数据面架构研究》

目录 一、数据服务的定义二、6G数据服务驱动力及面临的挑战6G数据服务的业务驱动6G数据服务的技术驱动6G数据服务的网络内在驱动6G数据面面临的挑战 三、6G数据服务典型场景自动化网络运维用户体验提升通信感知数据服务 四、6G数据面架构研究数据面架构视图功能定义说明&#x…

kafka部分partition的leader=-1修复方案整理

kafka部分partition的leader-1修复方案整理 1. 背景说明2. 修复测试2.1 创建正常的topic并验证生产和消费2.2 停止kafka模拟leader-12.3 修复parition2.4 修复完成验证生产消费是否恢复 3. 疑问和思考3.1 kafka在进行数据消费时,如果有partition的leader-1&#xff…

从迷宫问题理解dfs

文章目录 迷宫问题打印路径1思路定义一个结构体要保存所走的路径,就需要使用到栈遍历所有的可能性核心代码 部分函数递归图源代码 迷宫问题返回最短路径这里的思想同上面类似。源代码 迷宫问题打印路径1 定义一个二维数组 N*M ,如 5 5 数组下所示&…

十一、Yocto集成tcpdump等网络工具

文章目录 Yocto集成tcpdump等网络工具networking layer集成 Yocto集成tcpdump等网络工具 本篇文章为基于raspberrypi 4B单板的yocto实战系列的第十一篇文章: 一、yocto 编译raspberrypi 4B并启动 二、yocto 集成ros2(基于raspberrypi 4B) 三、Yocto创建自定义的lay…

ctf.show_web14

在switch中,case 里如果没有 break,则会继续向下执行 case。 过滤了information_schema.tables、information_schema.column、空格 information_schema.tables 或 .columns 用反引号 information_schema.tables 同时查3个字段 ?query-1/**/union/**/…

ssh免秘钥登录与时钟同步

ssh免秘钥登录及数据拷贝 ssh免秘钥登录及数据拷贝环境生成秘钥拷贝公钥到到远程服务器通过ssh-copy-id命令拷贝公钥到远程服务器通过手动拷贝公钥到远程服务器 非root用户远程拷贝公钥 设置编码方式临时设置编码永久设置方法一永久设置方法二 设置时钟同步使用 ntpdate 命令使…

血糖长期不降,乏力、视力模糊?可能治疗方向有误。

糖尿病能不能治愈,中医能不能治,这是很多人讨论的话题,现在全世界任何一家西医医院都会告诉你,糖尿病没办法治,给你的建议也是终身服药,或者打胰岛素治疗,但是我告诉你,其实这都是治…

Jmeter04:关联

1 Jmeter组件:关联 概括:2个请求之间不是独立的,一个请求响应的结果是作为另一个请求提交的数据,存在数据交互 1.1 是什么? 就是一个请求的结果是另一个请求提交的数据,二者不再是独立 1.2 为什么&#x…

深入理解Java IO流:字符流

深入理解Java IO流:字符流 引言 在Java中,IO(输入/输出)操作是程序与外部世界交互的重要方式。 其中,File类是进行文件操作的基础,而字节流和字符流则是数据传输的两种主要方式。 本文将深入探讨这些概念及…

【Redis 神秘大陆】004 高可用集群

四、Redis 高可用和集群 当你发现这些内容对你有帮助时,为了支持我的工作,不妨给一个免费的⭐Star,这将是对我最大的鼓励!感谢你的陪伴与支持!一起在技术的路上共同成长吧!点击链接:GitHub | G…

算法总结篇 —— dfs(搜索、递归、回溯)

有些事情本来很遥远,你争取,它就会离你越来越近 概念 名词解释如何理解递归 dfs 回溯类问题蓝桥杯——飞机降落 dfs 迷宫搜索类问题蓝桥杯——岛屿个数 概念 名词解释 递归是指在一个函数的定义中调用自身的过程,有些复杂问题可以划分为多…

c++的学习之路:26、AVL树

摘要 本章主要是说一下AVL树的实现,这里说的是插入的底层原理 目录 摘要 一、原理 二、四种旋转 1、左单旋 2、右单旋 3、左右双旋 4、右左双旋 三、代码实现 1、节点创建 2、插入 3、旋转 4、判断是否平衡 5、测试 四、代码 一、原理 前面说了搜索…

为啥转化为可编辑面片后有这么多点和线

可以删一下 按住alt按移除可以删掉 选择你要删的那些线 按住alt点移除

eBay、亚马逊自养号测评如何避免风控账号关联选择合适网络IP环境

在自养号下单中选择适合的网络环境至关重要。经过多次实践与测试,积累了大量的经验,希望能够与大家分享,帮助大家避开陷阱,顺利前行。 市面上的网络环境种类繁多,从纯IP类的Luminati、Rola,到纯环境类的VM…

编写Spark独立应用程序

执行本文之前,先搭建好spark的开发环境,我目前只搭建了standalone模式,参考链接 : Spark Standalone模式部署-CSDN博客 1. 安装sbt 1)下载sbt 网址:https://www.scala-sbt.org/download.html &#xff0c…

【云原生 • Docker】 ELK 8.4.3 docker 保姆级安装部署详细步骤

文章目录 ELK简介二、版本说明三、安装部署3.1 创建docker网络3.2 Elasticsearch拉取docker镜像,版本:8.4.3第一次执行docker脚本可以看到控制台的信息,找到这个信息并保存下来创建Elasticsearch挂载目录给创建的文件夹授权将容器内的文件复制到主机上删除容器修改docker脚本…

【GPTs分享】GPTs分享之Image Recreate | img2img​

简介 该GPT是一个专门用于图像编辑、重建和合并的工具。它通过详细的自动图像描述和生成,帮助用户从源图像中重现或修改图像。此工具设计用于为视障用户提供图像内容的详细描述,并生成全新的图像,以满足特定的视觉需求。 主要功能 \1. 图像…