程序员学长 | 当 LSTM 遇上 Attention

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:当 LSTM 遇上 Attention

今天我们一起来聊一下深度学习中的注意力(Attention)机制

注意力机制是深度学习中引入的一种技术,特别适用于序列到序列的任务(Sequence to Sequence,Seq2Seq)。通过引入注意力机制,Seq2Seq 模型能够在解码每个时间步时,动态地选择和关注输入序列中的不同部分,从而更好地捕捉输入序列的全局信息。

在讨论注意力机制之前,我们先来了解一下 Seq2Seq 模型。

Seq2Seq

序列到序列(Seq2Seq)模型是一种深度学习架构,广泛应用于将一个序列数据转换为另一个序列数据的任务中,例如机器翻译、自动问答、语音识别等。这种模型特别适用于输入序列和输出序列长度不固定的情况。

基本结构

序列到序列模型通常由两部分组成:编码器(Encoder)和解码器(Decoder)。

  1. 编码器

    编码器的作用是接受输入序列,并将其转换成一个固定大小的状态向量(通常称为上下文向量)。这个向量旨在捕捉输入序列的关键信息。

    在实现上,编码器通常是一个循环神经网络(RNN)或其变体,如长短期记忆网络(LSTM)或门控循环单元(GRU)。

    关于 RNN、LSTM 以及 GRU 可以参考如下文章:程序员学长 | 快速学会一个算法,RNN-CSDN博客和程序员学长 | 快速学会一个算法模型,LSTM-CSDN博客。

  1. 解码器

    解码器的任务是将编码器生成的状态向量转换为输出序列。它从编码器传递的上下文向量开始生成输出,并逐步生成输出序列中的每个元素。

    解码器通常也是基于RNN、LSTM或GRU构建的,它在生成每个输出元素时会参考前一个元素的输出,以及编码器的上下文向量。

工作流程

序列到序列模型的工作流程可以概括为以下几步:

  1. 输入处理

    将输入序列(如文本、语音等)转化为模型能够处理的格式,通常是一系列编码向量。

  2. 序列编码

    通过编码器处理输入向量,逐步更新内部状态,最终生成一个紧凑的上下文向量。

  3. 状态传递

    上下文向量被传递给解码器,作为其初始化状态。

  4. 序列解码

    解码器根据上下文向量逐步生成输出序列的每个元素。在生成每个新元素时,解码器会考虑已生成的序列和从编码器接收到的上下文。

  5. 输出生成

    解码器输出的序列经过后处理(如解码或转换)后形成最终的输出序列。

Seq2Seq 模型的缺点
  1. 固定大小的上下文向量

    在传统的 Seq2Seq 模型中,无论输入序列的长度如何,编码器都必须将所有的信息压缩到一个固定大小的上下文向量中。这可能导致信息丢失,特别是在处理长序列时。

  2. 长距离依赖问题

    尽管LSTM和GRU设计用来缓解梯度消失问题,并能在一定程度上处理长距离依赖,但在实际应用中,当序列非常长时,模型仍然难以捕捉序列中的远距离依赖关系。

在 Seq2Seq 中引入 Attention

如上图所示,在编码器和解码器中加入了注意力机制。

注意力权重的计算

案例说明

我们使用的示例是一个将句子从英语翻译成意大利语的网络。

该网络由两部分组成:

  1. 编码器,它对英语句子的含义进行编码;

  2. 解码器,它将编码的信息解码为句子到意大利语的翻译。

现在,当我们在编码器部分完成对句子信息的提取后,我们就可以开始解码信息并使用解码器将句子翻译成意大利语了。

解码器的第一个输入是一个起始标记,以及初始隐藏状态和上下文向量,它们构成了第一个隐藏状态。对于该隐藏状态,我们可以得到新句子的第一个输出。

我们将该输出用作下一步的输入,与前一个隐藏状态和上下文向量一起构建新的隐藏状态的输出。

该过程持续,直到我们获得停止标记作为输出。

这个过程对于短句很有效,但当句子变长时可能会失败。原因是解码器在所有步骤中使用上下文向量,并且需要它包含有关原始句子的所有信息。对于长句,将全部信息保存在一个固定大小的向量中可能非常困难。

在 seq2seq 模型添加 Attention

我们将保留以前的编码器-解码器架构,但这次我们在网络中添加了另一种机制,为解码器的每个步骤构建一个新的上下文向量

我们的编码器仍然像以前一样遍历输入序列并创建隐藏状态,最后为解码器创建初始隐藏状态。

现在,我们不再使用编码器的最终隐藏状态来制作上下文向量,而是使用解码器的初始隐藏状态和所有其他隐藏状态来构建它。

为此,我们将实现一个对齐函数,它是对编码器的隐藏状态和解码器的隐藏状态进行操作。此函数计算编码器每个隐藏状态的对齐分数(标量)。

这些分数表明,在给定解码器当前隐藏状态的情况下,我们应该在多大程度上关注编码器的每个隐藏状态。

这些概率是标准化的对齐分数,它们将用作编码器隐藏状态的注意力权重。新的上下文向量将是编码器隐藏状态乘以注意力权重的加权和。

下面是如何在 PyTorch 中实现 LSTM 注意力机制的基本示例。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass EncoderLSTM(nn.Module):def __init__(self, input_dim, emb_dim, hidden_dim, n_layers):super(EncoderLSTM, self).__init__()self.embedding = nn.Embedding(input_dim, emb_dim)self.rnn = nn.LSTM(emb_dim, hidden_dim, n_layers, batch_first=True)def forward(self, src):embedded = self.embedding(src)outputs, (hidden, cell) = self.rnn(embedded)return outputs, hidden, cellclass Attention(nn.Module):def __init__(self):super(Attention, self).__init__()def forward(self, encoder_outputs, decoder_hidden):# encoder_outputs: (batch_size, seq_len, hidden_dim)# decoder_hidden: (batch_size, hidden_dim)# Calculate the attention scores.scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(2)).squeeze(2)  # (batch_size, seq_len)attn_weights = F.softmax(scores, dim=1)  # (batch_size, seq_len)context_vector = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # (batch_size, hidden_dim)return context_vector, attn_weightsclass DecoderLSTMWithAttention(nn.Module):def __init__(self, output_dim, emb_dim, hidden_dim, n_layers):super(DecoderLSTMWithAttention, self).__init__()self.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.LSTM(emb_dim + hidden_dim, hidden_dim, n_layers, batch_first=True)self.out = nn.Linear(hidden_dim, output_dim)self.attention = Attention()def forward(self, input, encoder_outputs, hidden, cell):input = input.unsqueeze(1)  # (batch_size, 1)embedded = self.embedding(input)  # (batch_size, 1, emb_dim)context_vector, attn_weights = self.attention(encoder_outputs, hidden[-1])  # using the last layer's hidden staternn_input = torch.cat([embedded, context_vector.unsqueeze(1)], dim=2)  # (batch_size, 1, emb_dim + hidden_dim)output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))prediction = self.out(output.squeeze(1))return prediction, hidden, cell# Example usage
INPUT_DIM = 1000  # e.g., size of the source language vocabulary
OUTPUT_DIM = 1000  # e.g., size of the target language vocabulary
EMB_DIM = 256
HIDDEN_DIM = 512
N_LAYERS = 2encoder = EncoderLSTM(INPUT_DIM, EMB_DIM, HIDDEN_DIM, N_LAYERS)
decoder = DecoderLSTMWithAttention(OUTPUT_DIM, EMB_DIM, HIDDEN_DIM, N_LAYERS)src_seq = torch.randint(0, INPUT_DIM, (32, 10))  # batch of 32, sequence length 10
encoder_outputs, hidden, cell = encoder(src_seq)input = torch.randint(0, OUTPUT_DIM, (32,))  # batch of 32, single time step
output, hidden, cell = decoder(input, encoder_outputs, hidden, cell)

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/39967.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

「前端」快速排序算法演示

快速排序算法演示。 布局描述 一个简单的HTML页面,用户可以在其中输入一系列用逗号分隔的数字。 一个CSS样式表,提供了一个美观大方的布局和样式。 一个JavaScript脚本,实现了快速排序算法,并在用户点击按钮时对输入的数字进行排序,并显示结果。 效果演示 核心代码 <…

Mysql-基础-DDL操作

1、数据库操作 查询 查询所有数据库 show databases; 创建 创建数据库 create database [if not exists] 数据库名 使用及查询 use 数据库名 select database() 查询当前所处数据库 删除 drop database [if not exists] 数据库名 2、表操作 查询当前库中的所…

【ArcGIS Pro 加载项】修复图层名为要素类别名

ArcPro从目录添加要素类至内容列表&#xff0c;图层名称默认为要素类别名。 但是一番操作之后&#xff0c;这个图层名称可能会被你改了&#xff0c;想复原的话就要手动去图层属性里面复制要素类名称或者别名来重命名了&#xff0c;多少有点不方便。 所以小编通过SDK制作了这个…

python3.8安装详细教程

python3.8下载及安装详细教程 Python 3.8 是一个重要的Python版本&#xff0c;它引入了一系列新功能和改进。以下是对Python 3.8的详细概述&#xff0c;包括其关键特性、安装方法以及版本状态等信息。 Python 3.8的关键特性 海象运算符&#xff08;Walrus Operator&#xff09…

工程文件参考——CubeMX+LL库+SPI主机 阻塞式通用库

文章目录 前言CubeMX配置SPI驱动实现spi_driver.hspi_driver.c 额外的接口补充 前言 SPI&#xff0c;想了很久没想明白其DMA或者IT比较好用的方法&#xff0c;可能之后也会写一个 我个人使用场景大数据流不多&#xff0c;如果是大批量数据交互自然是DMA更好用&#xff0c;但考…

reggie外卖优化

文章目录 一、redis缓存1.1 缓存验证码1.2 缓存菜品数据 二、spring-cache 一、redis缓存 1.1 缓存验证码 不用sesiion&#xff0c;而使用redis来存放验证码。 首先在用户请求验证码&#xff0c;将验证码保存在sesion中&#xff0c;当登录成功之后&#xff0c;将redis中的验证…

Tekla Structures钢结构详图设计软件下载;Tekla Structures高效、准确的合作平台

Tekla Structures&#xff0c;它不仅集成了先进的三维建模技术&#xff0c;还融入了丰富的工程实践经验&#xff0c;为设计师、工程师和建筑商提供了一个高效、准确的合作平台。 在建筑项目的整个生命周期中&#xff0c;Tekla Structures都发挥着举足轻重的作用。从规划阶段开始…

录音转文字软件免费版哪个好?6个转文字工具让你轻松记录

随着小暑的到来&#xff0c;炎热的天气容易让人心浮气躁&#xff0c;影响工作效率。 在这个季节里&#xff0c;掌握一些办公技巧尤为关键。尤其是当我们需要整理会议记录或讲座内容时&#xff0c;如果能有一种方法&#xff0c;可以迅速将那些冗长的录音转换成清晰的文字&#…

使用Retrofit2+OkHttp监听上传或者下载进度会执行两次的问题

使用Retrofit2OkHttp监听上传或者下载进度RequestBody#writeTo/ResponseBody#source 会执行两次的问题 example&#xff1a; 问题原因&#xff1a; 使用了HttpLoggingInterceptor拦截器&#xff0c;并且日志等级为HttpLoggingInterceptor.Level.BODY 问题解决&#xff1a;

一个opencv实现检测程序

引言 图像处理是计算机视觉中的一个重要领域&#xff0c;它在许多应用中扮演着关键角色&#xff0c;如自动驾驶、医疗图像分析和人脸识别等。边缘检测是图像处理中的基本任务之一&#xff0c;它用于识别图像中的显著边界。本文将通过一个基于 Python 和 OpenCV 的示例程序&…

Python学习速成必备知识,(20道练习题)!

基础题练习 1、打印出1-100之间的所有偶数&#xff1a; for num in range(1, 101):if num % 2 0:print(num) 2、打印出用户输入的字符串的长度&#xff1a; string input("请输入一个字符串&#xff1a;")print("字符串的长度为&#xff1a;", len(str…

excel表格如何换行,这几个操作方法要收藏好

Excel表格作为一款强大的数据处理工具&#xff0c;在日常工作和生活中被广泛应用。当需要在单元格内显示较长的文本内容或使数据更加清晰易读时&#xff0c;我们需要掌握一些换行技巧。下面将介绍几种常用的Excel换行方法&#xff1a; 一、使用快捷键换行 1、首先&#xff0c;…

暑假学习DevEco Studio第一天

学习目标&#xff1a; 掌握构建第一个ArkTS应用 学习内容&#xff1a; 容器的应用 创建流程 点击file&#xff0c;new-> create project 点击empty ->next 进入配置界面 点击finsh&#xff0c;生成下面图片 这里需要注意记住index.ets &#xff0c;这是显示页面 –…

五款免费可视化利器分享,助力打造数字孪生新体验!

在当今数据驱动的时代&#xff0c;可视化工具已成为各行各业不可或缺的助手。它们不仅能帮助我们更好地理解和分析数据&#xff0c;还能以直观、生动的方式呈现复杂信息&#xff0c;提升沟通和决策效率。本文将为大家介绍五款免费的可视化工具&#xff0c;总有一款适合你。 一…

如何做好企业品牌推广,看这篇文章就够了

在当今竞争激烈的市场环境中&#xff0c;品牌推广策略与方式成为企业成功的关键。因为相比起企业&#xff0c;消费者会更愿意为品牌买单。那么企业如何将品牌推广做好呢?今日投媒网与您分享。 1.明确品牌定位与目标受众 一切推广活动的起点在于清晰的品牌定位。首先&#xf…

SolrCloud Autoscaling 自动添加副本

SolrCloud Autoscaling 自动添加副本 前言 问题描述 起因是这样的&#xff0c;我在本地调试 Solr 源码&#xff08;版本 7.7.3&#xff09;&#xff0c;用 IDEA 以 solrcloud 方式启动了 2 个 Solr 服务&#xff0c;如下所示&#xff1a; 上图的启动参数 VM Options 如下&am…

RocketMQ实战:一键在docker中搭建rocketmq和doshboard环境

在本篇博客中&#xff0c;我们将详细介绍如何在 Docker 环境中一键部署 RocketMQ 和其 Dashboard。这个过程基于一个预配置的 Docker Compose 文件&#xff0c;使得部署变得简单高效。 项目介绍 该项目提供了一套 Docker Compose 配置&#xff0c;用于快速部署 RocketMQ 及其…

美国商超入驻细节全面曝光,电竞外设产品的国际化浪潮即将席卷全球

近年来&#xff0c;随着电子竞技(简称电竞)行业的蓬勃发展&#xff0c;电竞外设产品也逐渐成为消费者关注的热点。近期&#xff0c;一系列美国商超入驻细节的全面曝光&#xff0c;预示着电竞外设产品的出海风潮即将到来。 电竞行业迅速崛起&#xff0c;全球市场规模年均增长超1…

ResNet50V2

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 一、ResNetV1和ResNetV2的区别 ResNetV2 和 ResNetV1 都是深度残差网络&#xff08;ResNet&#xff09;的变体&#xff0c;它们的主要区别在于残差块的设计和…

美业系统实操:手机App如何查看员工业绩?美业门店管理系统Java源码分享

在当今竞争激烈的美业市场中&#xff0c;有效的管理对于提高效率、增强客户体验和推动业务增长至关重要。美业管理系统通过其各种功能和优势&#xff0c;成为现代美业企业不可或缺的利器。 ▶下面以博弈美业进行实操-手机App端如何查看员工业绩&#xff1f; 1.店主登录手机端…