seq2seq编码器-解码器实现

        我们在之前的文章快速上手LSTM-CSDN博客中提及了RNN的几种不同的类型,其中有同步的 many to many 的根据视频的每一帧对视频分类任务,以及异步的 many to many 文本翻译。对于这种输入和输出不等长的序列,我们采用seq2seq(sequence to sequence)模型解决。

1. Seq2seq

        seq2seq 是由 encoder(编码器)和 decoder(解码器)构成,这个 encoder 和 decoder 都是由 RNN 组成的。其中 encoder 负责对输入句子的理解,转化为 context vector(语义向量),decoder 负责对理解后的句子的向量进行处理,解码,获得输出。这个过程就和我们人在看到一段话,理解段落大意之后按照自己的方式表达出来。

        那么在这需要注意一个问题,在 encoder 的过程中得到的 context vector 作为 decoder 的输入,那么这样一个输入如何得到多个输出呢?

        其实就是当前时间步输出,作为下一个单元的输入,然后得到下一个时间步的输出,依次循环直至遇到结束符 “<EOS>”(“<END>”)。当然,我们收集的数据都是没有这些特殊词元(“<UNK>”,“<PAD>”,“<SOS>”, “<EOS>”等)的,需要我们在数据集中自行添加。

2. encoder

        encoder 的目的就是对文本进行编码,这里首先要明白我们的输入是会先经过 embedding 的得到的embedded,我们在 encoder 和 decoder 中可以使用rnn,lstm或者是gru,两个编码器都得使用同样的。在encoder中每一个时间步的输入都会得到结果,比如上面这句“Are you free tomorrow?” 会被处理成 'Are' 'you' 'free' 'tomorrow' '?' <EOS>,所以我们输入的句子长度是比原始句子多一出一个 <EOS> 词元的,整个过程是一个和句子长度相关(包含<EOS>词元)的循环

注意:1. 我们一般使用 encoder 最后一个时间步的输出作为句子的编码结果。

           2. <EOS> 词元只会在 encoder 中被使用

"""
编码器
"""
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
import configclass Encoder(nn.Module):def __init__(self):super(Encoder, self).__init__()self.embedding = nn.Embedding(num_embeddings=len(config.num_sequence),embedding_dim=config.embedding_dim,padding_idx=config.num_sequence.PAD) # padding_idx 不会被更新,注意传入的是数值不是字符串self.gru = nn.GRU(input_size=config.embedding_dim,num_layers=config.num_layer,hidden_size=config.hidden_size,bidirectional=False, batch_first=True,dropout=0.5)def forward(self, input, input_length):""":param input: [batch_size, seq_len]:param input_length: 输入进编码器的句子的真实长度:return:"""input_embedded = self.embedding(input)  # input_embedded: [batch_size, seq_len, embedding_dim]# 打包,加速计算input_embedded = pack_padded_sequence(input_embedded, input_length, batch_first=True)out, hidden = self.gru(input_embedded)  # hidden: [num_layers*num_directions, batch_size, hidden_size]# out: [batch_size, seq_len, hidden_size]# 解包out, out_length = pad_packed_sequence(out, padding_value=config.num_sequence.PAD,batch_first=True)return out, out_length, hidden

我们可以看一下输出打印一下 encoder的结构 和 encoder的输出

if __name__ == '__main__':from dataset import train_data_loaderencoder = Encoder()print(encoder)for data, label, data_length, label_length in train_data_loader:out, out_length, hidden = encoder(data, data_length)print(out.shape)print(hidden.shape)print(out_length)break
Encoder((embedding): Embedding(14, 50, padding_idx=0)(gru): GRU(50, 32, num_layers=2, batch_first=True, dropout=0.5)
)
torch.Size([128, 8, 32])
torch.Size([2, 128, 32])
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7,7, 7, 7, 7, 7, 6, 6, 5])

补充:

        在上面使用了 torch.nn.utils.rnn 提供的一对函数 pack_padded_sequence() 用来打包数据和 pad_packed_sequence() 用来解包数据,这里有个比较坑的点,它输入的数据需要根据输入句子真实长度降序排序,且要求输入的数据是填充过的。它能够帮我们加速计算,因为我们的每一个 batch 中的句子长度都必须一致,但是输入的文本的句子大多数情况都是长短不一的,所以我们会对句子长度进行填充,截断,使句子保持一致,对于我们用于填充的特殊词元,我们并不希望它进入encoder,如果填充的值进入了 encoder 里会对句子语义的理解出现误差,因为填充的地方在原来的句子是没有词的,还有就是会加大计算量

例如:

有一个batch_size为2的数据sentence_1:[9, 4, 6, 6, 3, 7, 8, 1],sentence_1:[3, 5, 6, 0, 0, 0, 0, 0]真实的句子长度分别为8和3。填充词元 <PAD> 的值是0,先不考虑embedding,首先我们会根据句子长度降序排序得到:

batch:  [[9, 2, 4, 6, 3, 7, 8, 1],

             [3, 5, 6, 0, 0, 0, 0, 0]]

pack_padded_sequence() 会根据传入的真实句子长度,以及填充词元的值0,传入encoder的数据会被打包成

batch:  [[9, 2, 4, 6, 3, 7, 8, 1],

             [3, 5, 6]                    ]

 pad_packed_sequence() 会将压缩完数据填充回去。

3.context vector(语义向量C)

        对于 encoder 的输出结果我们使用它的 context vector ,参与 decoder 的计算,有两种方式,第一中方式是只参与 decoder 中第一个时间步,每个时间步的输出做为下一个时间步的输入。如下:

第二种方式是参与 decoder 中的每一个时间步,每个时间步的输出做为下一个时间步的输入。如下:

4. decoder

在解码器中,通过循环依次计算每个时间步。在这里我们将 context vector 作为初始的隐层状态,输入置为 [batch_size, 1] (特殊词元“<SOS>”值为1,代表句子开始,编码器开始工作),每一次时间步输出[batch_size, hidden_size] ,,hidden_size映射到vocab_size,当前这这个输出的词作为下一个时间步的输入再进行解码。解码器完成整个句子的解码之后会获得 seq_len 的输出拼接(concat)成 [batch_size, seq_len, vocab_size]

"""
解码器1.获取encoder的输出,作为decoder初始的hidden_state2.decoder的第一个时间步输入 <SOS>:[batch_size, 1]3.得到第一个时间步输出 hidden_state
"""
import torch.nn as nn
import torch
import torch.nn.functional as F
import configclass Decoder(nn.Module):def __init__(self):super(Decoder, self).__init__()self.embedding = nn.Embedding(num_embeddings=len(config.num_sequence),embedding_dim=config.embedding_dim,padding_idx=config.num_sequence.PAD)self.gru = nn.GRU(input_size=config.embedding_dim,hidden_size=config.hidden_size,num_layers=config.num_layer,bidirectional=False, batch_first=True,dropout=0.5)self.fc = nn.Linear(config.hidden_size, len(config.num_sequence))def forward(self, label, context_vector):""":param label::param context_vector: 语义向量 context vector:return:"""# 获取encoder的输出,作为decoder初始的hidden_statedecoder_hidden = context_vector# 得到 [batch_size, 1] 实现 <SOS># 作为decoder的第一个时间步的输入batch_size = label.size(0)decoder_input = torch.LongTensor(torch.ones([batch_size, 1], dtype=torch.int64)*config.num_sequence.SOS).to(config.device)# 保存预测的结果decoder_outputs = torch.zeros([batch_size, config.max_sentence_len+2, len(config.num_sequence)]).to(config.device)for i in range(config.max_sentence_len+2):decoder_output_t, decoder_hidden_t = self.forward_step(decoder_input, decoder_hidden)# 保存decoder_outputs[:, i, :] = decoder_output_tvalue, idx = torch.topk(decoder_output_t, 1)decoder_input = idx # 拿到这个词的序列return decoder_outputs, decoder_hidden_tdef forward_step(self, decoder_input, decoder_hidden):"""完成每一个时间步的计算:param decoder_input: [batch_size, 1]:param decoder_hidden: [1, batch_size, hidden_size]"""decoder_input_embedded = self.embedding(decoder_input) # [batch_size, 1, embedding_dim]out, decoder_hidden = self.gru(decoder_input_embedded, decoder_hidden) # out: [batch_size, 1, hidden_size]# decoder_hidden: [1, batch_size, hidden_size]# 完成到词表的映射hidden_size -> vocab_sizeout = out.squeeze(1) # out: [batch_size, 1, hidden_size] -> [batch_size, hidden_size]out = self.fc(out) # out: [batch_size, hidden_size] -> [batch_size, vocab_size]output = F.log_softmax(out, dim=-1)# print("output:", output.shape)return output, decoder_hidden

5. 结语

        总之,seq2seq模型通过encoder接收一个长度为N的序列,得到一个context vector ,然后由 decoder 把这一个 context vector 转化为长度为M的序列作为输出,从而实现了一个N to M的模型,用于处理输入序列和输出序列不同的任务,比如,文本翻译、文章摘要、问答等等。

        代码实现可参考seq2seq实现案例已上传仓库liaolaa / string-prediction · GitCode

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/665897.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一步步成为React全栈大师:从环境搭建到应用部署

文章目录 第一步&#xff1a;环境搭建第二步&#xff1a;了解React基础第三步&#xff1a;组件与路由第四步&#xff1a;状态管理第五步&#xff1a;接口与数据交互第六步&#xff1a;样式与布局第七步&#xff1a;测试第八步&#xff1a;构建与部署《深入浅出React开发指南》内…

【面试官问】Redis 持久化

目录 【面试官问】Redis 持久化 Redis 持久化的方式RDB(Redis DataBase)AOF(Append Only File)混合持久化:RDB + AOF 混合方式的持久化持久化最佳方式控制持久化开关主从部署使用混合持久化使用配置更高的机器参考文章所属专区

React 面试题

1、组件通信的方式 父组件传子组件&#xff1a;通过props 的方式 子组件传父组件&#xff1a;父组件将自身函数传入&#xff0c;子组件调用该函数&#xff0c;父组件在函数中拿到子组件传递的数据 兄弟组件通信&#xff1a;找到共同的父节点&#xff0c;用父节点转发进行通信 …

一键转换MOV至MP3:轻松删除原视频,释放存储空间!

你是否曾经有一个MOV格式的视频文件&#xff0c;想要提取其中的音频却苦于没有合适的工具&#xff1f;现在&#xff0c;有了我们的全新视频剪辑工具&#xff0c;这个烦恼全部消失&#xff01;我们为你提供一键式解决方案&#xff0c;将MOV视频文件快速转换为MP3音频格式。 首先…

基于单片机的造纸纸浆液位控制系统结构设计

摘要:为适应无人化与高效化制浆造纸生产体系&#xff0c;造纸企业趋于以嵌入式技术优化造纸过 程中的纸浆液位控制系统&#xff0c;以单片机与传感器相互耦合实现纸浆液位控制。本文基于单片机 设计了造纸纸浆液位控制系统&#xff0c;其结构由控制模块、信息采集模块、物联网模…

备战蓝桥杯---搜索(应用入门)

话不多说&#xff0c;直接看题&#xff1a; 显然&#xff0c;我们可以用BFS&#xff0c;其中&#xff0c;对于判重操作&#xff0c;我们可以把这矩阵化成字符串的形式再用map去存&#xff0c;用a数组去重现字符串&#xff08;相当于map映射的反向操作&#xff09;。移动空格先找…

JVM之Java内存区域

JVM-Java内存区域 Java内存区域是Java虚拟机&#xff08;JVM&#xff09;管理的内存资源的逻辑划分&#xff0c;用于存储程序运行时所需的数据。Java内存区域的合理划分和管理对于程序的性能和稳定性具有重要影响。本文将深入探讨Java内存区域的各个部分&#xff0c;包括方法区…

(delphi11最新学习资料) Object Pascal 学习笔记---第4章第2节( 参数和返回值)

4.2 参数和返回值 ​ 调用函数或过程时&#xff0c;需要传递正确数量的参数&#xff0c;并确保它们符合预期类型。否则&#xff0c;编译器会发出错误信息&#xff0c;就像给变量赋值时类型不匹配一样。前面的 DoubleIt 函数定义了一个 整数参数&#xff0c;如果调用&#xff1…

vit细粒度图像分类(九)RAMS-Trans学习笔记

1.摘要 在细粒度图像识别(FGIR)中&#xff0c;区域注意力的定位和放大是一个重要因素&#xff0c;基于卷积神经网络(cnn)的方法对此进行了大量探索。近年来发展起来的视觉变压器(ViT)在计算机视觉任务中取得了可喜的成果。与cnn相比&#xff0c;图像序列化是一种全新的方式。然…

npm ERR! code CERT_HAS_EXPIRED

执行npm i报错&#xff1a; npm ERR! code ETIMEDOUT npm ERR! syscall connect npm ERR! errno ETIMEDOUT npm ERR! network request to https://registry.npmjs.org/react-redux failed, reason: connect ETIMEDOUT 104.16.2.35:443 npm ERR! network This is a problem rel…

机器学习算法之支持向量机(SVM)

支持向量机(Support Vector Machine,简称SVM)是一种广泛用于分类、回归和其他学习任务的强大的监督学习算法。SVM的目标是找到一个超平面,以最大化地分隔不同类别的数据点。在二维空间中,这个超平面可以被看作是一条直线,但在更高维度的空间中,它可能是一个平面或者更复…

Android PMS——网络下载应用安装(六)

我们接着上一篇文章继续分析,文章最后调用到了 PackageManagerService 中的 installStage() 方法,这里就是正式开始 APK 的安装过程。 一、安装流程 1、PackageManagerService 源码位置:/frameworks/base/services/core/java/com/android/server/pm/PackageManagerServic…

使用ESP32-S3对MQ-135空气质量传感器的使用记录(Arduino版)

一、硬件上&#xff1a; 1、使用esp32开发板的04引脚与AO连接&#xff0c;检测AO引脚的电平 二、软件上&#xff1a; 1、使用Arduino快速完成开发 2、源码&#xff1a; // Potentiometer is connected to GPIO 04 (Analog ADC1_CH3) const int adcPin 4;// variable for s…

十大排序算法之堆排序

堆排序 在简单选择排序文章中&#xff0c;简单选择排序这个“铁憨憨”只顾着自己做比较&#xff0c;并没有将对比较结果进行保存&#xff0c;因此只能一遍遍地重复相同的比较操作&#xff0c;降低了效率。针对这样的操作&#xff0c;Robertw.Floyd 在1964年提出了简单选择排序…

C#(C Sharp)学习笔记_数据类型与变量赋值【三】

前言 本期内容会介绍到C#的数据类型&#xff0c;变量和赋值基本操作。当然了&#xff0c;我会简略的讲解常用的数据类型的应用及变量和赋值。 1.数据类型 C#中的数据类型与其他编程语言如出一辙&#xff0c;一下为数据类型参考表。 类型描述范围默认值bool布尔值True 或 Fa…

再谈Redis三种集群模式:主从模式、哨兵模式和Cluster模式

总结经验 redis主从:可实现高并发(读),典型部署方案:一主二从 redis哨兵:可实现高可用,典型部署方案:一主二从三哨兵 redis集群:可同时支持高可用(读与写)、高并发,典型部署方案:三主三从 一、概述 Redis 支持三种集群模式,分别为主从模式、哨兵模式和Cluster模式。…

【学习笔记】Python 环境隔离

文章目录 前言venvvenv 环境管理venv 包管理 virtualenv 以及 virtualenvwrapper安装virtualenvwrapper 环境管理virtualenvwrapper 包管理 condaconda 环境管理conda 包管理 总结参考资料 Python 作为最常用的脚本语言&#xff0c;有着非常丰富的第三方库&#xff0c;但是这也…

YOLOv5改进 | 主干篇 | 反向残差块网络EMO一种轻量级的CNN架构(附完整代码 + 修改教程)

一、本文介绍 本文给大家带来的改进机制是反向残差块网络EMO,其的构成块iRMB在之前我已经发过了,同时进行了二次创新,本文的网络就是由iRMB组成的网络EMO,所以我们二次创新之后的iEMA也可以用于这个网络中,再次形成二次创新,同时本文的主干网络为一种轻量级的CNN架构,在…

redis的数据淘汰测略

Redis 提供了多种数据淘汰策略&#xff0c;可以根据实际需求选择适合的策略。以下是 Redis 中常见的数据淘汰策略&#xff1a; volatile-lru&#xff1a;从已设置过期时间的键中挑选最近最少使用的数据进行淘汰。 volatile-ttl&#xff1a;从已设置过期时间的键中挑选即将过期…

记录在树莓派中部署PI-Assistant开源项目(GPT语音对话)的BUG

核心 在部署PI-Assistant&#xff08;https://github.com/Lucky-183/PI-Assistant&#xff09;项目中&#xff0c;首先要进行环境安装&#xff0c;官网文档中提供的安装命令如下&#xff1a; pip install requests arcade RPi.GPIO pydub numpy wave sounddevice pymysql cn2…