机器翻译之创建Seq2Seq的编码器、解码器

1.创建编码器、解码器的基类

1.1创建编码器的基类

from torch import nn#构建编码器的基类
class Encoder(nn.Module):   #继承父类nn.Moduledef __init__(self, **kwargs):   #**kwargs:不定常的关键字参数super().__init__(**kwargs)def forward(self, X, *args):  #*args:不定常的位置参数#若继承了Encoder这个基类,就必须实现forward(),否则就会报下这个错raise  NotImplementedError          

1.2创建解码器的基类

#创建解码器的基类
#创建解码器的基类比创建编码器的基类多一个 state的初始化
class Decoder(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)#初始化statedef init_state(self, enc_outputs, *args):raise NotImplementedError#前向传播,解码器比编码器多传入一个statedef forward(self, X, state):raise NotImplementedError

 1.3合并编码器和解码器的基类

class EncoderDecoder(nn.Module):def __init__(self, encoder, decoder, **kwargs):super().__init__(**kwargs)self.encoder = encoderself.decoder = decoderdef forward(self, enc_X, dec_X, *args):"""enc_X:编码器需传入的数据dec_X:解码器需传入的数据"""enc_outputs = self.encoder(enc_X, *args)dec_state = self.decoder.init_state(enc_outputs, *args)return self.decoder(dec_X, dec_state)

 2.基于上述基类,正式创建Seq2Seq编码器与解码器的类

import collections
import math
import torch
import dltools

2.1创建Seq2Seq的编码器类 

class Seq2SeqEncoder(Encoder):  #继承父类Encoderdef __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)"""vocab_size:词汇表大小embed_size:嵌入层大小num_hiddens:隐藏层的神经元数量num_layers:隐藏层的层数dropout=0 : 默认所有的神经元参与计算"""#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)def forward(self, X, *args):#在进行embedding之前,X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X) #X经过embedding处理,X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #经过permute调换维度之后,X的shape=(num_steps, batch_size, embed_size)#此时, pytorch 会自动完成隐藏状态的初始化,即0, 不需要手动传入stateoutputs, state = self.rnn(X)#outputs的shape=(num_steps, batch_size, num_hiddens) ,最后一维是神经元的数量#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试代码
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
encoder.eval()
# batch_size=4, num_steps=7
X = torch.zeros((4, 7), dtype=torch.long)
outputs, state = encoder(X)print(outputs.shape, state.shape)
torch.Size([7, 4, 16]) torch.Size([2, 4, 16])

2.2 创建Seq2Seq的解码器类

class Seq2SeqDecoder(Decoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化输出层self.dense = nn.Linear(num_hiddens, vocab_size)#定义函数:获取状态statedef init_state(self, enc_outputs, *args):#编码器输出的结果有两个,第二个为statereturn enc_outputs[1]#前向传播def forward(self, X, state):#X的原始shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)  #X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #调整数据维度, X的shape=(num_steps, batch_size, embed_size)# 把X和state拼接到一起. 方便计算. # X现在的形状(num_steps, batch_size, embed_size) , # state的形状(batch_size, num_hiddens)# 要把state的形状扩充成三维. 变成(num_steps, batch_size, num_hiddens)context = state[-1].repeat(X.shape[0], 1, 1)  #扩充X.shape[0]=num_steps次,1:所对应的维度不变X_and_context = torch.cat((X, context), 2) #按照索引为2的维度合并#此时,X_and_context的shape=(num_steps, batch_size, embed_size+num_hiddens)#神经网络层outputs, state = self.rnn(X_and_context, state)#输出层outputs = self.dense(outputs).permute(1, 0, 2) #将数据维度重新调换过来#outputs的shape=(batch_size, num_steps, vocab_size)#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
outputs, state = decoder(X, state)
outputs.shape, state.shape
(torch.Size([4, 7, 10]), torch.Size([2, 4, 32]))

3.编码器 、解码器理论图

 

4.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/54381.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Git入门学习(1)

Git 00.准备工作-gitee注册 今天Git的设置中需要用到gitee的注册信息,先自行完成注册工作,可以 参考笔记 或第二天视频(10.Git远程仓库-概念和gitee使用准备) 传送门: gitee(码云):https://gitee.com/ 注…

详解:冒泡排序

1.是什么 冒泡排序(Bubble Sort)是一种简单的排序算法。它重复地遍历要排序的数列,一次比较两个元素,如果它们的顺序错误就把它们交换过来。遍历数列的工作是重复地进行直到没有再需要交换,也就是说该数列已经排序完成…

2024华为杯研赛D题保姆级教程思路分析+教程

2024年中国研究生数学建模竞赛D题保姆级教程思路分析 D题:大数据驱动的地理综合问题(数学分析,统计学) 关键词:地理、气候、统计(细致到此题:统计指标、统计模型、统计结果解释) …

c++249多态

#include<iostream> using namespace std; class Parent { public:Parent(int a){this->a a;cout << " Parent" << a << endl;} public:virtual void print()//在子类里面可写可不写 {cout << "Parent" <<a<&l…

OpenCV 2

目录 图像平滑处理 高斯与中值滤波 图像阈值 ​编辑 Canny边缘检测 非极大值抑制 边缘检测效果 轮廓检测方法 ​编辑 ​编辑​编辑 轮廓检测结果 轮廓特征与近似 图像平滑处理 以上两种出来的图片效果 以上的效果&#xff0c;因为填的是normalize False&#xff0c;越界…

Vue接入高德地图并实现基本的路线规划功能

目录 一、申请密钥 二、安装依赖 三、代码实现 四、运行截图 五、官方文档 一、申请密钥 登录高德开放平台&#xff0c;点击我的应用&#xff0c;先添加新应用&#xff0c;然后再添加Key。 如图所示填写对应的信息&#xff0c;系统就会自动生成。 二、安装依赖 npm i am…

艾丽卡的区块链英语小课堂

系列文章目录 IT每日英语&#xff08;三&#xff09; 文章目录 系列文章目录前言1.principle2.efficient3.implement4.accumulated5,occupation6.phases7.validator8.nominated9.commissions10.significantly 前言 欢迎来到艾丽卡的区块链英语小课堂&#xff0c;在这里&…

vmware + ubuntu + 初始配置(超级用户权限、vim安装、ssh登陆、共享文件夹、git)

1 VMware Ubuntu下载与安装 下载与安装 2 使用超级用户权限 &#xff08;1&#xff09;执行命令&#xff1a;sudo passwd root 然后在弹出的密码中输入密码即可&#xff0c;具体如下&#xff1a; 第一个密码是当前用户密码 后面两个是root用户密码 //推荐使用一个密码 3 vi…

航空航司reese84逆向

reese84逆向 Reese84 是一种用于保护网站防止自动化爬虫抓取的防护机制&#xff0c;尤其是在航空公司网站等需要严格保护数据的平台上广泛使用。这种机制通过复杂的指纹识别和行为分析技术来检测和阻止非人类的互动。例如&#xff0c;Reese84 可以通过分析访问者的浏览器指纹、…

基于PHP的电脑线上销售系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于phpMySQL的电脑线上销售系…

免费音乐剪辑软件大揭秘:2024 大学生的音乐创作利器

对于音乐爱好者而言&#xff0c;如果你萌生了尝试音乐剪辑的念头&#xff0c;不妨先从探索一些免费工具开始。在此&#xff0c;我愿分享几款我个人体验过的、值得一试的音乐剪辑免费软件&#xff0c;希望能为你的音乐探索之旅增添乐趣与灵感。 1.福晰音频剪辑 链接直达>&g…

免费在线压缩pdf 压缩pdf在线免费 推荐简单好用

压缩pdf在线免费&#xff1f;在日常生活和工作学习中&#xff0c;处理PDF文件是常见任务。但有时PDF文件体积较大&#xff0c;给传输、存储和分享带来不便。因此&#xff0c;学习PDF文件压缩技巧十分必要。压缩PDF文件是指通过技术手段减小文件占用的存储空间&#xff0c;同时尽…

数据保护从现在开始:如何抵御 .[RestoreBackup@cock.li].SRC 勒索病毒

导言 勒索病毒是一种不断演变的网络威胁&#xff0c;.[RestoreBackupcock.li].SRC、[chewbaccacock.li].SRC勒索病毒便是其中一种新型的攻击手段。该病毒通过加密用户文件并要求支付赎金来恢复访问&#xff0c;给个人和企业带来了严重的安全风险和经济损失。本文91数据恢复将探…

Python 从入门到实战23(属性property)

我们的目标是&#xff1a;通过这一套资料学习下来&#xff0c;通过熟练掌握python基础&#xff0c;然后结合经典实例、实践相结合&#xff0c;使我们完全掌握python&#xff0c;并做到独立完成项目开发的能力。 上篇文章我们讨论了类的定义、使用方法的相关知识。今天我们将学…

信息安全数学基础(19)同余式的基本概念及一次同余式

一、同余式概念 同余式是数论中的一个基本概念&#xff0c;用于描述两个数在除以某个数时所得的余数相同的情况。具体地&#xff0c;设m是一个正整数&#xff0c;a和b是两个整数&#xff0c;如果a和b除以m的余数相同&#xff0c;则称a和b模m同余&#xff0c;记作a≡b(mod m)。反…

【有啥问啥】OpenAI o1的思考之前训练扩展定律、后训练扩展定律与推理扩展定律:原理与应用详解

OpenAI o1的思考之前训练扩展定律、后训练扩展定律与推理扩展定律&#xff1a;原理与应用详解 随着深度学习技术的不断发展&#xff0c;模型的规模和复杂度也迅速提升。研究人员发现了模型训练和推理过程中性能变化的规律&#xff0c;这些规律为我们提供了优化模型设计与训练的…

基于微信小程序的剧本杀游玩一体化平台

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 基于微信小程序JavaSpringBootVueMySQL的剧…

PyQt / PySide + Pywin32 + ctypes 自定义标题栏窗口 + 完全还原 Windows 原生窗口边框特效项目

项目地址&#xff1a; GitHub - github201014/PyQt-NativeWindow: A class of window include nativeEvent, use PySide or PyQt and Pywin32 and ctypesA class of window include nativeEvent, use PySide or PyQt and Pywin32 and ctypes - github201014/PyQt-NativeWindow…

MyBatis 分批次执行(新增,修改,删除)

import com.google.common.collect.Lists;import java.util.Iterator; import java.util.List; import java.util.function.Consumer;/*** Description mybatis分批插入数据使用* Author WangKun* Date 2024/9/19 11:20* Version*/ public class MyBatisSqlUtils {/*** param d…

4G 网络下资源加载失败?一次运营商封禁 IP 的案例分享

在工作中&#xff0c;网络问题是不可避免的挑战之一。最近&#xff0c;我们在项目中遇到了一起网络资源加载异常的问题&#xff1a;某同事在使用 4G 网络连接公司 VPN 时&#xff0c;云服务的前端资源居然无法加载&#xff01;通过一系列的排查和分析&#xff0c;我们发现问题的…