【自然语言处理】Encoder-Decoder模型中Attention机制的引入

Encoder-Decoder 模型中引入 Attention 机制,是为了改善基本Seq2Seq模型的性能,特别是当处理长序列时,传统的Encoder-Decoder模型容易面临信息压缩的困难。Attention机制可以帮助模型动态地选择源序列中相关的信息,从而提高翻译等任务的质量。

一、为什么需要Attention机制?

在基本的 Encoder-Decoder 模型中,Encoder将整个源句子的所有信息压缩成一个固定大小的向量(上下文向量),然后Decoder使用这个向量来生成目标序列。这个单一的上下文向量对于较短的句子可能足够,但对于较长的句子,模型可能无法有效捕捉到整个句子中所有重要的信息。这样容易导致信息丢失,尤其是当句子很长时,Decoder在生成目标词时可能无法获取到源句子的细节信息。

二、Attention机制的核心思想

Attention机制的核心思想是:在每个时间步生成目标单词时,Decoder不再依赖于固定的上下文向量,而是能够通过“注意力”权重,动态地从源句子的所有隐状态中选择最相关的部分。这样,Decoder每生成一个目标词时,能够更好地“关注”源句子中与当前生成词最相关的部分。

三、Attention机制的工作流程

在每一步解码时,Attention机制会根据Decoder的当前状态计算出一组权重,表示源句子中各个位置的隐状态对当前解码步骤的重要性。这些权重用于加权源句子的隐状态,以得到一个上下文向量,这个上下文向量会与当前Decoder的隐状态一起用于生成下一个目标词。由于它跨越两个序列:源语言序列(编码器输出)作为 Key 和 Value;目标语言序列(解码器的当前状态)作为 Query,因此也叫交叉注意力

Attention的具体步骤如下:

  1. 计算注意力权重

    • 对于Decoder的每一步(生成每个目标词时),通过Decoder的当前隐状态和源句子每个时间步的隐状态来计算注意力权重。
    • 这些权重表示源句子中每个位置的重要性,可以使用加性Attention点积Attention来计算。
  2. 计算上下文向量

    • 通过将注意力权重与源句子的隐状态进行加权平均,得到一个新的上下文向量。
    • 这个上下文向量包含了源句子中当前对Decoder最重要的信息。
  3. 解码下一步

    • 将新的上下文向量与当前Decoder的隐状态结合,用于生成当前的目标词。

四、Attention机制的公式

对于每个时间步 t:

  1. 计算注意力得分:通常使用Decoder当前的隐状态 ht 和源句子每个位置的隐状态 hs 计算注意力得分,可以通过以下公式计算:

在这里插入图片描述

常见的 score 函数有加性(Bahdanau Attention)和点积(Luong Attention):

  • 加性Attention:使用一个简单的前馈网络对 ht 和 hs 进行线性变换并加和。
  • 点积Attention:直接计算 ht 和 hs 的点积。
  1. 计算注意力权重:对得分 et,s​ 进行Softmax操作,得到权重:

在这里插入图片描述

这些权重 αt,s 表示源句子中各个位置对当前解码的影响力。

  1. 计算上下文向量:使用注意力权重对源句子的隐状态进行加权平均,得到上下文向量 ct:

在这里插入图片描述

  1. 生成下一个词:将上下文向量 ct 与Decoder的隐状态 ht 结合,生成下一个词。

五、引入Attention机制的Encoder-Decoder代码实现

以下是一个带有 Attention 机制的 Encoder-Decoder 模型的简化实现,使用 PyTorch 进行构建。

import torch
import torch.nn as nn# Encoder模型
class Encoder(nn.Module):def __init__(self, input_size, embedding_dim, hidden_size):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)def forward(self, src):embedded = self.embedding(src)  # [batch_size, src_len, embedding_dim]outputs, (hidden, cell) = self.lstm(embedded)  # [batch_size, src_len, hidden_size]return outputs, hidden, cell# Attention模型
class Attention(nn.Module):def __init__(self, hidden_size):super(Attention, self).__init__()self.attn = nn.Linear(hidden_size * 2, hidden_size)self.v = nn.Parameter(torch.rand(hidden_size))def forward(self, hidden, encoder_outputs):src_len = encoder_outputs.shape[1]hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))  # [batch_size, src_len, hidden_size]energy = torch.sum(self.v * energy, dim=2)  # [batch_size, src_len]return torch.softmax(energy, dim=1)  # [batch_size, src_len]# Decoder模型
class Decoder(nn.Module):def __init__(self, output_size, embedding_dim, hidden_size):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim + hidden_size, hidden_size, batch_first=True)self.fc_out = nn.Linear(hidden_size * 2, output_size)self.attention = Attention(hidden_size)def forward(self, input_token, hidden, cell, encoder_outputs):input_token = input_token.unsqueeze(1)  # [batch_size, 1]embedded = self.embedding(input_token)  # [batch_size, 1, embedding_dim]# 计算注意力权重attn_weights = self.attention(hidden[-1], encoder_outputs)  # [batch_size, src_len]# 使用注意力权重对encoder输出进行加权平均attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # [batch_size, 1, hidden_size]# 将注意力上下文向量和嵌入层输入拼接lstm_input = torch.cat((embedded, attn_applied), dim=2)  # [batch_size, 1, embedding_dim + hidden_size]# 通过LSTMoutput, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))  # [batch_size, 1, hidden_size]# 生成最终输出output = torch.cat((output.squeeze(1), attn_applied.squeeze(1)), dim=1)  # [batch_size, hidden_size * 2]prediction = self.fc_out(output)  # [batch_size, output_size]return prediction, hidden, cell# Seq2Seq模型
class Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, src, tgt, teacher_forcing_ratio=0.5):batch_size = tgt.shape[0]target_len = tgt.shape[1]target_vocab_size = self.decoder.fc_out.out_featuresoutputs = torch.zeros(batch_size, target_len, target_vocab_size).to(self.device)encoder_outputs, hidden, cell = self.encoder(src)input_token = tgt[:, 0]for t in range(1, target_len):output, hidden, cell = self.decoder(input_token, hidden, cell, encoder_outputs)outputs[:, t, :] = outputtop1 = output.argmax(1)input_token = tgt[:, t] if torch.rand(1).item() < teacher_forcing_ratio else top1return outputs
代码说明:
  1. Encoder

    • 编码源句子,生成隐状态和输出序列。
    • 输出序列会在注意力机制中使用。
  2. Attention

    • Attention 模型根据当前隐状态和Encoder输出计算注意力权重。
  3. Decoder

    • 使用Attention得到的注意力权重对Encoder输出进行加权平均,得到上下文向量。
    • Decoder在当前时间步会将 当前输入(上一个时间步生成的词)、上一个时间步的隐状态 和 注意力上下文向量 拼接起来,输入到LSTM或GRU中,更新隐状态并生成当前时间步的输出。
  4. Seq2Seq

    • 将Encoder和Decoder结合,逐步生成目标序列。
    • 使用了教师强制机制来控制训练时的输入。
Decoder代码详细解释:
  1. attn_weights = self.attention(hidden[-1], encoder_outputs):

    • hidden[-1] 是Decoder当前时间步的最后一层隐状态(对于多层LSTM来说)。encoder_outputs 是Encoder所有时间步的输出。
    • 调用 self.attention 计算当前时间步的注意力权重。
  2. attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs):

    • attn_weights 是注意力权重,形状为 [batch_size, src_len]
    • unsqueeze(1) 将其变为 [batch_size, 1, src_len],然后与 encoder_outputs(形状为 [batch_size, src_len, hidden_size])进行批量矩阵乘法(torch.bmm)。
    • 这样得到的结果 attn_applied 是加权后的上下文向量,形状为 [batch_size, 1, hidden_size],表示根据注意力权重加权后的源句子信息。
  3. torch.cat((embedded, attn_applied), dim=2):

    • 将Decoder的当前输入(嵌入表示)和上下文向量拼接在一起,输入到LSTM中。

六、总结:

Attention机制的引入,允许Decoder在生成每个目标词时,能够动态地根据源句子的不同部分调整注意力,使得模型能够处理更长的序列,并提高生成结果的准确性。Attention机制在机器翻译等任务中取得了显著的效果,并且为之后的Transformer等模型的出现奠定了基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/56452.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据库MySQL作业】

(简答题) 1、按照结构表提示在数据库中创建数据表&#xff0c;要求使用SQL 语句实现。 2、按照下列要求完成操作&#xff0c;要求使用SQL 语句实现。 (1) 为读者表 reader 指定检査约束&#xff0c;即指定性别 sex 字段的值只能是“男”或“女”。 (2) 修改读者表reader 的默认…

RAG中向量召回怎么做

目录 1. 文档嵌入 2. 查询处理 3. 向量搜索 4. 结果融合 5. 实现细节 使用 FAISS 进行向量召回 在检索增强生成&#xff08;Retrieval-Augmented Generation, RAG&#xff09;框架中&#xff0c;向量召回是一个关键步骤&#xff0c;它涉及到从一个大规模的文档库或知识库…

华为---Super VLAN简介及示例配置

目录 1. Super VLAN技术产生背景 2. Super VLAN概念 3. Super VLAN应用场景 4. Super VLAN工作原理 5. Super-VLAN主要配置命令 6. Super-VLAN主要配置步骤 7. 示例配置 7.1 示例场景 7.2 网络拓扑 7.3 配置代码 7.4 代码解析 7.5 测试验证 1. Super VLAN技术产生背…

创建包含可导入浏览器信任的SSL自签名证书

问题&#xff1a;现在的三大浏览器&#xff0c;chrome、edge、firefox 一般都默认启用https检查&#xff0c;这就要求我们自建的局域网内的网址和其他诸如nextcloud、photoprism、tiddlywiki等应用也必须要有证书。解决方法是使用openssl自己生成一个。由此则会再衍生出一个问题…

爬虫prc技术----小红书爬取解决xs

知识星球&#xff1a;知识星球 | 深度连接铁杆粉丝&#xff0c;运营高品质社群&#xff0c;知识变现的工具知识星球是创作者连接铁杆粉丝&#xff0c;实现知识变现的工具。任何从事创作或艺术的人&#xff0c;例如艺术家、工匠、教师、学术研究、科普等&#xff0c;只要能获得一…

vscode解决中文注释乱码,意料之外的原因

问题详情&#xff1a; c文件编码格式是&#xff1a;utf-8&#xff0c;vscode打开格式也是utf-8&#xff0c;但中文注释仍是乱码。可是用notepad打开中文显示却是正常 notepad显示编码如图&#xff1a; notepad打开文件&#xff1a; vscode显示编码如图&#xff1a; vscode打开…

Re75 读论文:Toolformer: Language Models Can Teach Themselves to Use Tools

诸神缄默不语-个人CSDN博文目录 诸神缄默不语的论文阅读笔记和分类 论文全名&#xff1a;Toolformer: Language Models Can Teach Themselves to Use Tools 论文下载地址&#xff1a;https://arxiv.org/abs/2302.04761 这篇文章是介绍tool learning的&#xff0c;大概来说就是…

闪电麦昆 语音控制齿轮行进轨迹,ESP32搭配语音控制板,串口通信,附视频演示地址

演示地址 https://www.bilibili.com/video/BV1cW421d79L/?vd_sourceb8515e53f6d4c564b541d98dcc9df990 语音控制板的配置 web展示页面 esp32 程序 #include <ESP8266WiFi.h> #include <ESP8266WebServer.h> #include <LittleFS.h> #include <WebSo…

Arthas常用的命令(三)--monitor、jad 、stack

monitor&#xff1a;监控方法的执行情况 监控指定类中方法的执行情况 用来监视一个时间段中指定方法的执行次数&#xff0c;成功次数&#xff0c;失败次数&#xff0c;耗时等这些信息 参数说明 方法拥有一个命名参数 [c:]&#xff0c;意思是统计周期&#xff08;cycle of ou…

安装TDengine数据库3.3版本和TDengine数据库可视化管理工具

安装TDengine数据库3.3版本和TDengine数据库可视化管理工具 一、下载安装包二、解压安装包三、部署四、启动服务五、进入数据库六、创建数据库、表和往表中插入数据七、测试 TDengine 性能八、使用数据库九、查询数据十、TDengine数据库可视化界面 一、下载安装包 TDengine-cl…

YOLO11改进 | 注意力机制 | 添加SE注意力机制

秋招面试专栏推荐 &#xff1a;深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转 &#x1f4a1;&#x1f4a1;&#x1f4a1;本专栏所有程序均经过测试&#xff0c;可成功执行&#x1f4a1;&#x1f4a1;&#x1f4a1; 本文介绍了YOLOv11添加SE注意力机制&…

C语言 | 第十五章 | 指针函数 函数指针 内存分配 结构体

P 141 返回指针的函数 2023/2/16 一、基本介绍 C语言 允许函数的返回值是一个指针&#xff08;地址&#xff09;&#xff0c;这样的函数称为指针函数。 二、入门案例 案例&#xff1a;请编写一个函数 strlong()&#xff0c;返回两个字符串中较长的一个。 #include<stdi…

区块链技术与农产品溯源:实现透明供应链的关键

引言 随着食品安全问题和消费者对产品质量要求的提升&#xff0c;农产品溯源变得越来越重要。消费者希望知道他们购买的农产品从何而来&#xff0c;经历了哪些过程以及是否符合安全标准。区块链技术因其去中心化、不可篡改和透明的特点&#xff0c;成为实现农产品溯源的理想选…

如何解决与kernel32.dll相关的常见错误:详细指南解析kernel32.dll文件缺失、损坏或错误加载问题

当你的电脑中出现错误kernel32.dll丢失的问题&#xff0c;会导致电脑不能出现正常运行&#xff0c;希望能够有效的帮助你有效的将丢失的kernel32.dll文件进行修复同时也给大家介绍一些关于kernel32.dll文件的相关介绍&#xff0c;希望能够有效的帮助你快速修复错误。 kernel32.…

Unity RPG梦幻场景素材(附下载链接)

Unity RPG梦幻场景素材 点击下载资源 效果图&#xff1a; 资源链接

OpeneBayes 教程上新 | 打败 GPT-4V?超强开源多模态大模型 LLaVA-OneVision 正式上线!

大语言模型&#xff08;Large Language Model&#xff0c;简称 LLM&#xff09;与多模态大模型&#xff08;Large Multimodal Model&#xff0c;简称 LMM&#xff09;是人工智能领域的两个核心发展方向。 LLM 主要致力于处理和生成文本数据&#xff0c;而 LMM 则更进一步&#…

CesiumLab介绍

软考鸭小程序 学软考,来软考鸭! 提供软考免费软考讲解视频、题库、软考试题、软考模考、软考查分、软考咨询等服务 CesiumLab是一个围绕Cesium平台设计的完整易用的数据预处理工具集&#xff0c;它旨在最大化提升三维数据可视化效率。本文将详细介绍CesiumLab的安装、主要功能…

【JavaSE】图书系统

目录 当我们学习完Java的语法后&#xff0c;可以写一个简单的项目进行总结梳理一下&#xff0c;这个项目也会用到我们所学过的Java所有的语法知识&#xff1a;目录是咱们用文件夹包装起来的类。 1.book 在面向对象体系中&#xff0c;提出了一个软件包的概念&#xff0c;即&am…

k8s微服务

一 、什么是微服务 用控制器来完成集群的工作负载&#xff0c;那么应用如何暴漏出去&#xff1f;需要通过微服务暴漏出去后才能被访问 Service是一组提供相同服务的Pod对外开放的接口。 借助Service&#xff0c;应用可以实现服务发现和负载均衡。 service默认只支持4层负载均…

斯坦福大学提出电影剧本可视化工具ScriptViz:能够根据剧本中的文本和对话自动检索相关的电影画面,帮助剧作家更好地构思和调整剧情

title:斯坦福大学提出电影剧本可视化工具ScriptViz&#xff1a;能够根据剧本中的文本和对话自动检索相关的电影画面&#xff0c;帮助剧作家更好地构思和调整剧情 斯坦福大学的研究者们开发了一个电影剧本可视化工具ScriptViz工具&#xff0c;ScriptViz的工作原理可以简单地理解…