图解Attention学习笔记

教程是来自https://github.com/datawhalechina/learn-nlp-with-transformers/blob/main/docs/

图解Attention

Attention出现的原因是:基于循环神经网络(RNN)一类的seq2seq模型,在处理长文本时遇到了挑战,而对长文本中不同位置的信息进行attention有助于提升RNN的模型效果。

seq2seq框架

seq2seq(sequence to sequence)是一种常见的NLP模型结构,也就是从一个文本序列得到一个新的文本序列,其典型的任务是机器翻译任务、文本摘要任务。
seq2seq模型的输入可以是一个单词、字母或者图像特征序列,输出是另外一个单词、字母或图像特征序列,以机器翻译任务为例,序列指的是一连串的单词。
seq2seq模型由编码器(Encoder)和解码器(Decoder)组成,绿色的编码器会处理输入序列中的每个元素并获得输入信息,这些信息会被转换成为一个黄色的向量(称为context向量),当我们处理完整个输入序列后,编码器把context向量发送给紫色的解码器,解码器通过context向量中的信息,逐个元素输出新的序列。
在这里插入图片描述
以机器翻译任务为例,seq2seq在Transformer还没出现的时候,编码器和解码器一般采用的是循环神经网络RNN,编码器将输入的法语单词序列编码成context向量,然后解码器根据context向量解码出英语单词序列。
在这里插入图片描述
context向量本质是一组浮点数,而context的数组长度是基于编码器RNN的隐藏层神经元数量,上图是长度为4的context向量,但在实际应用中,context向量的长度是自定义的,比如可能是256,512或1024。
RNN具体处理输入序列的过程如下:

  1. 假设输入是一个句子,这个句子可以由 n n n个词表示: s e n t e n c e = w 1 , w 2 , . . . , w n sentence={w_1,w_2,...,w_n} sentence=w1,w2,...,wn
  2. RNN首先将句子中的每一个词映射成为一个向量,得到一个向量序列: X = x 1 , x 2 , . . . , x n X={x_1,x_2,...,x_n} X=x1,x2,...,xn,每个单词映射得到的向量通常又叫做word embedding
  3. 然后在处理第 t ∈ [ 1 , n ] t\in [1,n] t[1,n]个时间步的序列输入 x t x_t xt时,RNN网络的输入和输出可以表示为: h t = R N N ( x t , h t − 1 ) h_t=RNN(x_t,h_{t-1}) ht=RNN(xt,ht1)
  • 输入:RNN在时间步 t t t的输入之一为单词 w t w_t wt经过映射得到的向量 x t x_t xt,另一个输入是上一个时间步 t − 1 t-1 t1得到的hidden state向量 h t − 1 h_{t-1} ht1,同样是一个向量。
  • 输出:RNN在时间步 t t t的输出为 h t h_t hthidden state向量。
    下图是word embedding的例子,为了简单起见,每个单词被映射成一个长度为4的向量。
    在这里插入图片描述
    在这里插入图片描述
    看上面的动态图,相当于第一步 t 1 t_1 t1先把一个单词输入(转化为嵌入向量 x 1 x_1 x1),得到其对应的hidden state #1,然后根据hidden state #1 x 2 x_2 x2得到hidden state #2,以此类推,最终编码器输出的是最后一个hidden state,将其作为输入传给解码器,和编码器相同,解码器也是在每个时间步得到隐藏层状态,并传递到下一个时间步,一步步输出得到序列。

Attention

基于RNN的seq2seq模型编码器所有信息都编码到了一个context向量中,单个向量很难包含所有文本序列的信息,在处理长文本的时候,有长程依赖问题,因此Attention(注意力)机制被提出,这使得seq2seq模型可以有区分度、有重点的关注输入序列。
带有注意力的seq2seq模型结构有两点不同:

  • A. 编码器会把更多的数据传递给解码器。编码器把所有时间步的hidden state传递给解码器,而不是只传递最后一个hidden state,如图
    在这里插入图片描述
  • B. 注意力模型的解码器在输出之前,做了一个额外的attention处理。具体为:
  • a. 由于编码器中每个hidden state都对应到输入句子中的一个单词,那么解码器要查看所有接收到的编码器的hidden state
  • b. 给每个hidden state计算出一个分数。
  • c. 所有hidden state的分数经过softmax归一化。
  • d. 将每个hidden state乘以所对应的分数,从而能够让高分对应的hidden state会被放大,而低分对应的hidden state会被缩小。
  • e. 将所有hidden state根据对应分数进行加权求和,得到对应时间步的context向量。(下图的前三个时间步是编码器编码过程,第四个时间步是解码器开始解码)
    在这里插入图片描述
    以第四个时间步(解码器开始解码)为例:
  1. 注意力模型的解码器RNN的输入包括:一个word embedding向量,和一个初始化好的解码器hidden state,图中是 h i n i t h_{init} hinit
  2. RNN处理上述的两个输入,产生一个新的输出和一个新的hidden state(这里和之前一样),图中为 h 4 h_4 h4
  3. 注意力机制的步骤:使用编码器的所有hidden state向量和h4向量来计算这个时间步的context向量(C4).
  4. h4C4拼接起来,得到一个橙色向量。
  5. 把橙色向量输入一个前馈神经网络(这个网络是和整个模型一起训练的)。
  6. 根据前馈神经网络的输出向量得到输出单词:假设输出序列可能的单词有N个,那么这个前馈神经网络的输出向量通常是N维的,每个维度的下标对应一个输出单词,每个维度的数值对应的是该单词的输出概率。(编码器是把词进行词嵌入,将词汇表中的词映射为嵌入向量,解码器的输出是词汇表上每个词的概率,比如通过softmax层将该时间步的隐藏状态转换为词汇表上的概率,然后argmax取出概率最高的作为输出的单词)
  7. 在下一个时间步重复1-6步骤。
    在这里插入图片描述
    上图是解码器结合attention的全过程,最后是一段注意力机制的可视化,看看解码器在每个时间步关注了输入序列的哪些部分:
    在这里插入图片描述
    注意力模型不是无意识的将输出的第一个单词对应到输入的第一个单词,是在训练阶段学习到如何对两种语言的单词进行对应,例子中是法语和英语。
    在这里插入图片描述
    由上图可以看出,模型在输出"European Economic Area"时,注意力分布情况,法语和英语单词的顺序是颠倒的,注意力分布反映出了这一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/29895.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华北水利水电大学-C程序设计作业

目录 基础题 1-1 分析 代码实现 1-2 分析 代码实现 1-3 分析 代码实现 1-4 ​编辑 分析 代码实现 1-5 分析 代码实现 1-6 分析 代码实现 基础题 1-1 从键盘输入10个学生的有关数据,然后把它们转存到磁盘文件上去。其中学生信息包括学号、姓名…

Redis变慢了?

Redis变慢了? 什么是Redis?测定Redis变慢?最大响应延迟平均响应延迟设置Redis慢日志 分析Redis变慢bigkeysbigkey的危害bigkey优化 写在最后 什么是Redis? 作为一个技术人员来说,大家用的最多的可能就是Redis了&#…

EMQX集群搭建

1. 什么是 MQTT? MQTT(Message Queuing Telemetry Transport)是一种轻量级、基于发布-订阅模式的消息传输协议,适用于资源受限的设备和低带宽、高延迟或不稳定的网络环境。它在物联网应用中广受欢迎,能够实现传感器、…

防火墙中的NAT

防火墙的NAT NAT分类 源NAT 基于源IP地址进行转换。 我们之前接触过的静态NAT,动态NAT,NAPT都属于源NAT,都是针对源IP地址进行转换的。源NAT主要目的是为了保证内网用户可以访问公网。 先执行安全策略,后执行NAT 目标NAT 基于…

git的分支管理

✨前言✨ 📘 博客主页:to Keep博客主页 🙆欢迎关注,👍点赞,📝留言评论 ⏳首发时间:20246月19日 📨 博主码云地址:博主码云地址 📕参考书籍&#x…

【TB作品】MSP430G2553,单片机,口袋板, 单相交流电压、电流计设计

题5 单相交流电压、电流计设计 设计基于MSP430的单相工频交流电参数检测仪。交流有效值0-220V,电流有效值0-40A。电压、电流值经电压、电流传感器输出有效值为0-5V的交流信号,传感器输出的电压、电流信号与被测电压、电流同相位。 基本要求如下 &#xf…

05、部署 YUM 仓库及NFS 共享服务

目录 5.1 部署YUM软件仓库 5.1.1 准备网络安装源(服务器端) 1、准备软件仓库目录 2、安装并启用vsftpd服务 5.1.2 配置软件仓库位置(客户端) 5.2 使用yum工具管理软件包 5.2.1 查询软件包 1、yum list——查询软件包列表 …

DGit的使用

将Remix连接到远程Git仓库 1.指定克隆的分支和深度 2.清理,如果您不在工作区上工作,请将其删除或推送至 GitHub 或 IPFS 以确保安全。 为了进行推送和拉取,你需要一个 PAT — 个人访问令牌 当使用 dGIT 插件在 GitHub 上推送、拉取、访问私…

网关助力边缘物联网

网关助力边缘物联网 在探讨网关如何助力边缘物联网(IoT)的议题时,我们不得不深入分析这一技术交汇点的复杂性与潜力。边缘计算与物联网的融合,通过将数据处理与分析能力推向网络边缘,即数据生成的地方,极大…

接口性能提升秘籍:本地缓存的总结与实践

🍅我是小宋, 一个只熬夜但不秃头的Java程序员。 🍅关注我,带你轻松过面试。提升简历亮点(14个demo) . . 🌏号:tutou123com。拉你进面试专属群。 优雅的接口调优之本地缓存优化 接口…

Spring中网络请求客户端WebClient的使用详解

Spring中网络请求客户端WebClient的使用详解_java_脚本之家 Spring5的WebClient使用详解-腾讯云开发者社区-腾讯云 在 Spring 5 之前,如果我们想要调用其他系统提供的 HTTP 服务,通常可以使用 Spring 提供的 RestTemplate 来访问,不过由于 …

初识es(elasticsearch)

初识elasticsearch 什么是elasticsearch?: 一个开源的分部署搜索引擎、可以用来实现搜索、日志统计、分析、系统监控等功能。 什么是文档和词条? 每一条数据就是一个文档对文档中的内容进行分词,得到的词语就是词条 什么是正向…

【elementui源码解析】如何实现自动渲染md文档-第四篇

目录 1.前言 2.md-loader - index.js 1)md.render() 2)定义变量 3)while stripTemplate stripScript genInlineComponentText 4)pageScript 5)return 6)demo-block 3.总结 所有章节&#x…

微纳米气泡发生器是微纳米气泡产生装置 未来市场需求将不断释放

微纳米气泡发生器是微纳米气泡产生装置 未来市场需求将不断释放 微纳米气泡发生器即微纳米气泡发生设备,是一种将水和气体混合并产生微纳米气泡的设备。微纳米气泡是指直径在100μm以下的气泡,分为纳米气泡和微米气泡。   微纳米气泡发生器主要由发生设…

录屏录音两不误!电脑录屏录音软件推荐(3款)

在数字化时代,电脑录屏录音软件已成为教学、演示、会议记录等领域不可或缺的工具。它们能够捕捉屏幕上的每一个动作,同时录制音频,为用户提供直观、生动的视听材料。本文将详细介绍三种常用的电脑录屏录音软件,帮助读者了解并掌握…

误删的文件不在回收站如何找回?6个恢复秘诀分享!

“我刚刚误删了一些文件,但是在回收站中没有看到这部分文件,这种情况下还有方法可以找回误删的文件吗?在线等一个答案!” 在数字化时代,文件的安全和完整性对于个人和企业都至关重要。然而,有时候由于疏忽或…

【Android】使用SeekBar控制数据的滚动

项目需求 有一个文本数据比较长,需要在文本右侧加一个SeekBar,然后根据SeekBar的上下滚动来控制文本的滚动。 项目实现 我们使用TextView来显示文本,但是文本比较长的话,需要在TextView外面套一个ScrollView,但是我…

利用K8S技术栈打造个人私有云

1.三个节点:master,slave,client 在Kubernetes集群中,三个节点的职责分别如下: Master节点: docker:用于运行Docker容器。 etcd:一个分布式键值存储系统,用于保存Kuberne…

42、基于神经网络的训练堆叠自编码器进行图像分类(matlab)

1、训练堆叠自编码器进行图像分类的原理及流程 基于神经网络的训练堆叠自编码器进行图像分类的原理和流程如下: 堆叠自编码器(Stacked Autoencoder)是一种无监督学习算法,由多个自编码器(Autoencoder)堆叠…

宝塔软件默认安装位置

自带的JDK /usr/local/btjdk/jdk8Tomcat 各个版本都在bttomcat这个文件夹下面,用版本区分。tomcat_bak8是备份文件 /usr/local/bttomcat/tomcat8nginx /www/server/nginxnginx配置文件存放目录 /www/server/panel/vhost/nginxredis /www/server/redismysql /…