程序员学长 | 当 LSTM 遇上 Attention

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:当 LSTM 遇上 Attention

今天我们一起来聊一下深度学习中的注意力(Attention)机制

注意力机制是深度学习中引入的一种技术,特别适用于序列到序列的任务(Sequence to Sequence,Seq2Seq)。通过引入注意力机制,Seq2Seq 模型能够在解码每个时间步时,动态地选择和关注输入序列中的不同部分,从而更好地捕捉输入序列的全局信息。

在讨论注意力机制之前,我们先来了解一下 Seq2Seq 模型。

Seq2Seq

序列到序列(Seq2Seq)模型是一种深度学习架构,广泛应用于将一个序列数据转换为另一个序列数据的任务中,例如机器翻译、自动问答、语音识别等。这种模型特别适用于输入序列和输出序列长度不固定的情况。

基本结构

序列到序列模型通常由两部分组成:编码器(Encoder)和解码器(Decoder)。

  1. 编码器

    编码器的作用是接受输入序列,并将其转换成一个固定大小的状态向量(通常称为上下文向量)。这个向量旨在捕捉输入序列的关键信息。

    在实现上,编码器通常是一个循环神经网络(RNN)或其变体,如长短期记忆网络(LSTM)或门控循环单元(GRU)。

    关于 RNN、LSTM 以及 GRU 可以参考如下文章:程序员学长 | 快速学会一个算法,RNN-CSDN博客和程序员学长 | 快速学会一个算法模型,LSTM-CSDN博客。

  1. 解码器

    解码器的任务是将编码器生成的状态向量转换为输出序列。它从编码器传递的上下文向量开始生成输出,并逐步生成输出序列中的每个元素。

    解码器通常也是基于RNN、LSTM或GRU构建的,它在生成每个输出元素时会参考前一个元素的输出,以及编码器的上下文向量。

工作流程

序列到序列模型的工作流程可以概括为以下几步:

  1. 输入处理

    将输入序列(如文本、语音等)转化为模型能够处理的格式,通常是一系列编码向量。

  2. 序列编码

    通过编码器处理输入向量,逐步更新内部状态,最终生成一个紧凑的上下文向量。

  3. 状态传递

    上下文向量被传递给解码器,作为其初始化状态。

  4. 序列解码

    解码器根据上下文向量逐步生成输出序列的每个元素。在生成每个新元素时,解码器会考虑已生成的序列和从编码器接收到的上下文。

  5. 输出生成

    解码器输出的序列经过后处理(如解码或转换)后形成最终的输出序列。

Seq2Seq 模型的缺点
  1. 固定大小的上下文向量

    在传统的 Seq2Seq 模型中,无论输入序列的长度如何,编码器都必须将所有的信息压缩到一个固定大小的上下文向量中。这可能导致信息丢失,特别是在处理长序列时。

  2. 长距离依赖问题

    尽管LSTM和GRU设计用来缓解梯度消失问题,并能在一定程度上处理长距离依赖,但在实际应用中,当序列非常长时,模型仍然难以捕捉序列中的远距离依赖关系。

在 Seq2Seq 中引入 Attention

如上图所示,在编码器和解码器中加入了注意力机制。

注意力权重的计算

案例说明

我们使用的示例是一个将句子从英语翻译成意大利语的网络。

该网络由两部分组成:

  1. 编码器,它对英语句子的含义进行编码;

  2. 解码器,它将编码的信息解码为句子到意大利语的翻译。

现在,当我们在编码器部分完成对句子信息的提取后,我们就可以开始解码信息并使用解码器将句子翻译成意大利语了。

解码器的第一个输入是一个起始标记,以及初始隐藏状态和上下文向量,它们构成了第一个隐藏状态。对于该隐藏状态,我们可以得到新句子的第一个输出。

我们将该输出用作下一步的输入,与前一个隐藏状态和上下文向量一起构建新的隐藏状态的输出。

该过程持续,直到我们获得停止标记作为输出。

这个过程对于短句很有效,但当句子变长时可能会失败。原因是解码器在所有步骤中使用上下文向量,并且需要它包含有关原始句子的所有信息。对于长句,将全部信息保存在一个固定大小的向量中可能非常困难。

在 seq2seq 模型添加 Attention

我们将保留以前的编码器-解码器架构,但这次我们在网络中添加了另一种机制,为解码器的每个步骤构建一个新的上下文向量

我们的编码器仍然像以前一样遍历输入序列并创建隐藏状态,最后为解码器创建初始隐藏状态。

现在,我们不再使用编码器的最终隐藏状态来制作上下文向量,而是使用解码器的初始隐藏状态和所有其他隐藏状态来构建它。

为此,我们将实现一个对齐函数,它是对编码器的隐藏状态和解码器的隐藏状态进行操作。此函数计算编码器每个隐藏状态的对齐分数(标量)。

这些分数表明,在给定解码器当前隐藏状态的情况下,我们应该在多大程度上关注编码器的每个隐藏状态。

这些概率是标准化的对齐分数,它们将用作编码器隐藏状态的注意力权重。新的上下文向量将是编码器隐藏状态乘以注意力权重的加权和。

下面是如何在 PyTorch 中实现 LSTM 注意力机制的基本示例。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass EncoderLSTM(nn.Module):def __init__(self, input_dim, emb_dim, hidden_dim, n_layers):super(EncoderLSTM, self).__init__()self.embedding = nn.Embedding(input_dim, emb_dim)self.rnn = nn.LSTM(emb_dim, hidden_dim, n_layers, batch_first=True)def forward(self, src):embedded = self.embedding(src)outputs, (hidden, cell) = self.rnn(embedded)return outputs, hidden, cellclass Attention(nn.Module):def __init__(self):super(Attention, self).__init__()def forward(self, encoder_outputs, decoder_hidden):# encoder_outputs: (batch_size, seq_len, hidden_dim)# decoder_hidden: (batch_size, hidden_dim)# Calculate the attention scores.scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(2)).squeeze(2)  # (batch_size, seq_len)attn_weights = F.softmax(scores, dim=1)  # (batch_size, seq_len)context_vector = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # (batch_size, hidden_dim)return context_vector, attn_weightsclass DecoderLSTMWithAttention(nn.Module):def __init__(self, output_dim, emb_dim, hidden_dim, n_layers):super(DecoderLSTMWithAttention, self).__init__()self.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.LSTM(emb_dim + hidden_dim, hidden_dim, n_layers, batch_first=True)self.out = nn.Linear(hidden_dim, output_dim)self.attention = Attention()def forward(self, input, encoder_outputs, hidden, cell):input = input.unsqueeze(1)  # (batch_size, 1)embedded = self.embedding(input)  # (batch_size, 1, emb_dim)context_vector, attn_weights = self.attention(encoder_outputs, hidden[-1])  # using the last layer's hidden staternn_input = torch.cat([embedded, context_vector.unsqueeze(1)], dim=2)  # (batch_size, 1, emb_dim + hidden_dim)output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))prediction = self.out(output.squeeze(1))return prediction, hidden, cell# Example usage
INPUT_DIM = 1000  # e.g., size of the source language vocabulary
OUTPUT_DIM = 1000  # e.g., size of the target language vocabulary
EMB_DIM = 256
HIDDEN_DIM = 512
N_LAYERS = 2encoder = EncoderLSTM(INPUT_DIM, EMB_DIM, HIDDEN_DIM, N_LAYERS)
decoder = DecoderLSTMWithAttention(OUTPUT_DIM, EMB_DIM, HIDDEN_DIM, N_LAYERS)src_seq = torch.randint(0, INPUT_DIM, (32, 10))  # batch of 32, sequence length 10
encoder_outputs, hidden, cell = encoder(src_seq)input = torch.randint(0, OUTPUT_DIM, (32,))  # batch of 32, single time step
output, hidden, cell = decoder(input, encoder_outputs, hidden, cell)

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/39967.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

「前端」快速排序算法演示

快速排序算法演示。 布局描述 一个简单的HTML页面,用户可以在其中输入一系列用逗号分隔的数字。 一个CSS样式表,提供了一个美观大方的布局和样式。 一个JavaScript脚本,实现了快速排序算法,并在用户点击按钮时对输入的数字进行排序,并显示结果。 效果演示 核心代码 <…

Mysql-基础-DDL操作

1、数据库操作 查询 查询所有数据库 show databases; 创建 创建数据库 create database [if not exists] 数据库名 使用及查询 use 数据库名 select database() 查询当前所处数据库 删除 drop database [if not exists] 数据库名 2、表操作 查询当前库中的所…

【ArcGIS Pro 加载项】修复图层名为要素类别名

ArcPro从目录添加要素类至内容列表&#xff0c;图层名称默认为要素类别名。 但是一番操作之后&#xff0c;这个图层名称可能会被你改了&#xff0c;想复原的话就要手动去图层属性里面复制要素类名称或者别名来重命名了&#xff0c;多少有点不方便。 所以小编通过SDK制作了这个…

firewalld防火墙(二)

一&#xff1a;firewalld高级配置 1&#xff1a;关于iptables的知识 iptables 是Linux系统中传统的命令行防火墙管理工具&#xff0c;它基于内核的netfilter框架工作&#xff0c;用于配置和管理网络规则集&#xff0c;比如过滤&#xff08;允许/拒绝&#xff09;进出的数据包…

python3.8安装详细教程

python3.8下载及安装详细教程 Python 3.8 是一个重要的Python版本&#xff0c;它引入了一系列新功能和改进。以下是对Python 3.8的详细概述&#xff0c;包括其关键特性、安装方法以及版本状态等信息。 Python 3.8的关键特性 海象运算符&#xff08;Walrus Operator&#xff09…

工程文件参考——CubeMX+LL库+SPI主机 阻塞式通用库

文章目录 前言CubeMX配置SPI驱动实现spi_driver.hspi_driver.c 额外的接口补充 前言 SPI&#xff0c;想了很久没想明白其DMA或者IT比较好用的方法&#xff0c;可能之后也会写一个 我个人使用场景大数据流不多&#xff0c;如果是大批量数据交互自然是DMA更好用&#xff0c;但考…

reggie外卖优化

文章目录 一、redis缓存1.1 缓存验证码1.2 缓存菜品数据 二、spring-cache 一、redis缓存 1.1 缓存验证码 不用sesiion&#xff0c;而使用redis来存放验证码。 首先在用户请求验证码&#xff0c;将验证码保存在sesion中&#xff0c;当登录成功之后&#xff0c;将redis中的验证…

Tekla Structures钢结构详图设计软件下载;Tekla Structures高效、准确的合作平台

Tekla Structures&#xff0c;它不仅集成了先进的三维建模技术&#xff0c;还融入了丰富的工程实践经验&#xff0c;为设计师、工程师和建筑商提供了一个高效、准确的合作平台。 在建筑项目的整个生命周期中&#xff0c;Tekla Structures都发挥着举足轻重的作用。从规划阶段开始…

Java Map转泛型对象

Springboot Map转泛型对象 import org.springframework.beans.PropertyAccessorFactory;public abstract class AbstractGoodsProcessor<T>{/*** 封装对象**/Overridepublic T assembleOfBeforeCheck(Map<String, Object> map){return this.mapToParam(map,0);};O…

录音转文字软件免费版哪个好?6个转文字工具让你轻松记录

随着小暑的到来&#xff0c;炎热的天气容易让人心浮气躁&#xff0c;影响工作效率。 在这个季节里&#xff0c;掌握一些办公技巧尤为关键。尤其是当我们需要整理会议记录或讲座内容时&#xff0c;如果能有一种方法&#xff0c;可以迅速将那些冗长的录音转换成清晰的文字&#…

使用Retrofit2+OkHttp监听上传或者下载进度会执行两次的问题

使用Retrofit2OkHttp监听上传或者下载进度RequestBody#writeTo/ResponseBody#source 会执行两次的问题 example&#xff1a; 问题原因&#xff1a; 使用了HttpLoggingInterceptor拦截器&#xff0c;并且日志等级为HttpLoggingInterceptor.Level.BODY 问题解决&#xff1a;

day08. 02 Python中的位运算符案例与解析

理解并掌握Python中的位运算符&#xff1a;异或&#xff08;^&#xff09;、与&#xff08;&&#xff09;、或&#xff08;|&#xff09;、反&#xff08;~&#xff09;、右移&#xff08;>>&#xff09;、无符号右移&#xff08;>>>&#xff0c;注意Python…

艾滋病隐球菌病的病原学诊断方法包括?

艾滋病隐球菌病的病原学诊断方法包括()查看答案 A.培养B.隐球菌抗原C.墨汁染色D.PCR 在感染性疾病研究中&#xff0c;单细胞转录组学的应用包括哪些()? A.细胞异质性研究B.基因组突变检测C.感染过程单细胞分析D.代谢通路分析 开展病原微生物网络实验室体系建设&#xff0c;应通…

Linux--平台设备、平台驱动的注册源码分析

一、设备和驱动的注册 设备注册两种方式&#xff1a; 1、从设备树解析动态注册。设备树dts文件中定义了设备节点&#xff0c;描述了硬件信息&#xff0c;比如寄存器信息&#xff0c;引脚信息等&#xff0c;内核将从设备树中解析得到的platform_device注册到平台总线中。具体设…

一个opencv实现检测程序

引言 图像处理是计算机视觉中的一个重要领域&#xff0c;它在许多应用中扮演着关键角色&#xff0c;如自动驾驶、医疗图像分析和人脸识别等。边缘检测是图像处理中的基本任务之一&#xff0c;它用于识别图像中的显著边界。本文将通过一个基于 Python 和 OpenCV 的示例程序&…

UniApp 中 Web/H5 正确使用反向代理解决跨域问题

因为 Vue3 的构建工具是 Vite&#xff0c;所以配置 vue.config.js 是没用的&#xff08;Vue2 因为使用 webpack 所以才用这个文件&#xff09; 这里提供一份 vue.config.js 的示例&#xff1a; module.exports {devServer: {proxy: {/api: {target: http://example.com,chan…

Python学习速成必备知识,(20道练习题)!

基础题练习 1、打印出1-100之间的所有偶数&#xff1a; for num in range(1, 101):if num % 2 0:print(num) 2、打印出用户输入的字符串的长度&#xff1a; string input("请输入一个字符串&#xff1a;")print("字符串的长度为&#xff1a;", len(str…

使用Python进行文件合并和分割

哈喽,大家好,我是木头左! 引言 在数据处理过程中,经常需要将多个文件合并为一个文件,或者将一个大文件分割成多个小文件。Python作为一种功能强大的编程语言,提供了多种方法来实现这一目标。本文将介绍如何使用Python进行文件合并和分割。 文件合并 1. 逐行合并 最简…

Git不想跟踪某个文件

如果你不想跟踪某个文件&#xff0c;可以将该文件路径添加到 .gitignore 文件中。.gitignore 文件用于告诉 Git 哪些文件或目录应该被忽略&#xff0c;不进行版本控制。以下是具体步骤&#xff1a; 编辑 .gitignore 文件&#xff1a;在项目的根目录下找到或创建一个 .gitignore…

More Effective C++ 35个改善编程与设计的有效方法笔记与心得 5

五. 技术 条款25&#xff1a; 将 constructor 和 non-member functions虚化 请记住&#xff1a; 1. 利用重载技术&#xff08;overload&#xff09;避免隐式类型转换&#xff08;implicit type conversions&#xff09; ‌‌‌‌  重载技术是指在同一个作用域中声明多个同…