python-pytorch seq2seq+attention笔记1.0.2

python-pytorch seq2seq+attention笔记1.0.0

    • 1. LSTM模型的数据size
    • 2. 关于LSTM的输入数据包含hn和cn时,hn和cn的size
    • 3. LSTM参数中默认batch_first
    • 4. Attention机制的三种算法
    • 5. 模型的编码器
    • 6. 模型的解码器
    • 7. 最终模型
    • 8. 数据的准备
    • 9. 遇到的问题
    • 10. 完整代码
    • 11. 参考链接

记录

  • 2024年5月14日09:27:39----0.5.10
  • 2024年5月14日11:32:47----0.5.12
  • 2024年5月14日11:51:03----1.0.0

1. LSTM模型的数据size

一定是按这个来:维度(batch_size, seq_length, embedding_dim) 是一个三维的tensor;其中,batch_size指每次输入的文本数量;seq_length指每个文本的词语数或者单字数;embedding_dim指每个词语或者每个字的向量长度。

2. 关于LSTM的输入数据包含hn和cn时,hn和cn的size

LSTM的输入数据是上个时间窗的hn和cn时,hn和cn的size要求一定是和LSTM模型参数吻合。公式是(laynum,batchsize,hidden or embeding size)或者(batchsize,hidden or embeding size)。

3. LSTM参数中默认batch_first

其实改变的是模型的hn和cn的size,不改变output的size。因为,cn和hn的size是和batch_size有关系的,是layernum、batch_size、hidden_size

4. Attention机制的三种算法

dot 、general、concat三种,常见使用general算法。

general大概思路是:
计算分数:decoder中LSTM的output和encoder的output做bmm计算
计算权重:将计算出来的分数做softmax,得到行上的概率分布或者权重
计算新向量:将权重和encoder的outpu再做bmm计算
拼接decoder的output和新向量
对新的拼接结果做tanh计算
最后全连接到vocab_size

concat算法思路是:
tang计算:encoder的hn和输出encoder_output相加
计算分数:对相加做tanh计算得到对其分数
计算权重:对其分数做行上做softmax计算得到权重
计算新向量:权重和encoder_output做bmm计算得到新向量
拼接decoder输入的ebeded和新向量作为decoder中LSTM或者GRU的输入
最后返回LSTM或者GRU等的out

5. 模型的编码器

思路很简单,就是将word2index后,通过embedding,将数据给LSTM模型就可以了,返回的是 LSTM的output、hn、cn。
当前你可以根据自己的习惯,在使用LSTM时候增加参数batch_fist或者bidirectional。

此时inputx的是word2index后的数据。

class encoder(nn.Module):def __init__(self):super(encoder, self).__init__()self.embedding=nn.Embedding(vocab_size,n_hidden)self.lstm=nn.LSTM(n_hidden,n_hidden*2,batch_first=False)def forward(self, inputx):embeded=self.embedding(inputx.long())output,(encoder_h_n, encoder_c_n)=self.lstm(embeded.permute(1,0,2))return output,(encoder_h_n,encoder_c_n)

6. 模型的解码器

将解码器的输入embedding后,加上编码器的outout、hn、cn,给LSTM模型输出ouput、hn、cn,做general的attention,最终返回新的LSTM的output、hn、cn。

class lstm_decoder(nn.Module):def __init__(self):super(lstm_decoder, self).__init__()self.embedding=nn.Embedding(vocab_size,embedding_size)self.decoder = nn.LSTM(embedding_size, n_hidden * 2, 1,batch_first=False)self.fc = nn.Linear(n_hidden * 2, num_classes)def forward(self, input_x, encoder_output, hn, cn):embeded=self.embedding(input_x)decoder_output, (decoder_h_n, decoder_c_n) = self.decoder(embeded.float().permute(1,0,2), (hn, cn))decoder_output = decoder_output.permute(1, 0, 2)encoder_output = encoder_output.permute(1, 0, 2)# 下面是实现attention编码# 计算分数scoredecoder_output_score = decoder_output.bmm(encoder_output.permute(0,2,1))# 计算权重atat = nn.functional.softmax(decoder_output_score, dim=2)# 计算新的context向量ctct = at.bmm(encoder_output)# 拼接ct和decoder_htht_joint = torch.cat((ct, decoder_output), dim=2)fc_joint = torch.tanh(self.att_joint(ht_joint))# 实现attention编码结束fc_out = self.fc(fc_joint)return fc_out, decoder_h_n, decoder_c_n

7. 最终模型

在训练时候,encoder的output、hn、cn作为decoder的的输入一部分。最终模型的输出和target数据(view(-1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/12663.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

好文推荐:基于热红外的双源能量平衡(TSEB)模型--从植物到全球尺度的蒸散诊断简史

文献 近日,美国农业部农业研究服务局(USDA-ARS)的科学家们发表了一篇重要的研究论文——“Agricultural and Forest Meteorology” (https://www.sciencedirect.com/journal/agricultural-and-forest-meteorology)&…

智慧公厕系统:改变“上厕所”体验的科技革新

公共厕所是城市建设中不可或缺的基础设施,然而,由于较为落后的管理模式,会常常存在着管理不到位、脏乱差的问题。为了改善公厕的使用体验,智慧公厕系统应运而生,并逐渐成为智慧城市建设的重要组成部分。本文将以智慧公…

AI工具如何简化日常生活?从论文到PPT,AI助手大集合

AI助手大集合,猛戳进来! 在工作和生活中,我经常使用各种各样的人工智能工具,如AI写作软件、AI语音助手、AI绘图工具等。我发现,这些工具能够极大地提高工作效率并简化日常生活。作为一名AI工具的忠实爱好者&#xff0…

Python爬虫——如何使用urllib的HTTP基本库

怎样通过 urllib库 发送 HTTP 请求? urllib库主要由四个模块组成: urllib.request 打开和读取 URLurllib.error 包含 urllib.request 抛出的异常urllib.parse 用于解析 URLurllib.robotparser 用于解析 robots.txt 文件 1. 使用urllib.parse解析URL 使用urlparse(…

【3dmax笔记】022:文件合并、导入、导出

文章目录 一、合并二、导入三、导出四、注意事项一、合并 只能合并 max 文件(高版本能够合并低版本模型,低版本不能合并高版本的模型)。点击【文件】→【导入】→【合并】: 选择要合并的文件,后缀名为3dmax默认的格式,max文件。 二、导入 点击【文件】→【导入】→【导…

等保测评介绍

等保测评,全称为信息安全等级保护测评,是一种依据国家信息安全等级保护制度要求,对信息系统实施安全等级保护管理的过程。这一过程包括对信息系统的全面安全风险评估,目的是发现潜在的安全隐患,并提出相应的安全改进措…

【ZYNQ】Vivado 封装自定义 IP

在 FPGA 开发设计中,IP 核的使用通常是不可缺少的。FPGA IP 核是指一些已经过验证的、可重用的模块或者组件,可以帮助构建更加复杂的系统。本文主要介绍如何使用 Vivado 创建与封装用户自定义 IP 核,并使用创建的 IP 核进行串口回环测试。 目…

为什么只有const-static-枚举/整型才可以类内初始化

在C中,静态数据成员(static member)是类的所有对象共享的一个变量。由于它们不是与类的任何特定对象实例相关联的,因此不能在类的构造函数中初始化它们。静态数据成员的初始化必须在类定义之外进行,除非它们满足特定的…

谷歌I/O 2024大会全面硬刚OpenAI

🦉 AI新闻 🚀 谷歌发布升级版Gemini机器人 竞争OpenAI ChatGPT-4 摘要:谷歌展示了升级版的 Gemini 聊天机器人,其支持实时处理视频和语音输入,并准确回答问题。此次发布时机与 OpenAI 公布 ChatGPT-4o 新模型几乎同步…

pycharm导入项目,创建虚拟环境,下载依赖

1、安装conda,此处省略 2、管理员身份打开CMD命令行,创建虚拟环境 conda create --name env_name python3.7 -y 其中,env_name替换为自己想要的环境名字,python3.7表示指定python版本为3.7,-y意味着遇到询问直接回复…

Redis经典问题:BigKey问题

大家好,我是小米,今天来和大家聊聊Redis中的一个经典问题:BigKey问题。在互联网系统中,我们经常需要保存大量的用户数据,比如用户的个人信息、粉丝列表、发表的微博内容等等。这些数据往往会被存储在Redis这样的缓存系统中,以提高系统的性能和响应速度。但是,在处理这些…

什么样的开放式耳机好用舒服?五款高人气质量绝佳产品力荐!

​随着人们越来越注重个人的身体健康问题,掀起了一股运动浪潮,现在大家都会喜欢跑跑步,运动一下使自己的身体更好,那么在运动时候如果能有音乐听的话,人们的运动状态就能达到更好的水平。鉴于传统入耳式耳机给用户带来…

k8s源码编译失败:Makefile:1: *** 缺失分隔符。 停止。

目录 问题解决 更换Arch或系统 问题解决 编译k8s源码的kubelet时执行make失败:Makefile:1: *** 缺失分隔符。 停止。 首先,查看文件内容 # cat Makefile build/root/Makefile 修改Makefile,给第一行前增加include,如下&…

基于梯度流的扩散映射卡尔曼滤波算法的信号预处理matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 4.1 扩散映射(Diffusion Maps) 4.2 卡尔曼滤波 4.3 基于梯度流的扩散映射卡尔曼滤波(GFDMKF) 5.完整程序 1.程序功能描述 基于梯度流的扩散…

一文全解聚碳酸酯PC材料在汽车灯罩制造中的诸多显著优势!汽车车灯的灯罩如果破损破裂破洞了要怎么修复?

聚碳酸酯PC材料在汽车灯罩制造中具有诸多显著优势。除了优异的抗冲击性、透明性、耐热性和稳定性外,还有以下一些重要优势: 出色的光学性能:PC材料的光学性能优异,能够确保灯罩内的光源均匀分布,减少光斑和眩光&#…

现代R语言【Tidyverse、Tidymodel】的机器学习

机器学习已经成为继理论、实验和数值计算之后的科研“第四范式”,是发现新规律,总结和分析实验结果的利器。机器学习涉及的理论和方法繁多,编程相当复杂,一直是阻碍机器学习大范围应用的主要困难之一,由此诞生了Python…

从0到1,百亿级任务调度平台的架构与实现

尼恩:百亿级海量任务调度平台起源 在40岁老架构师 尼恩的读者交流群(50)中,经常性的指导小伙伴们改造简历。 经过尼恩的改造之后,很多小伙伴拿到了一线互联网企业如得物、阿里、滴滴、极兔、有赞、希音、百度、网易、美团的面试机会&#x…

软件库V1.5版本iApp源码V3

软件库V1.5版本iApp源码V3 配置教程在【mian.iyu】的【载入事件】 更新内容: 1、分类对接蓝奏(免费,付费,会员,广告),支持蓝奏文件描述设置为简介(改动:首页.iyu&#…

Pikachu 靶场 SQL 注入通关解析

前言 Pikachu靶场是一种常见的网络安全训练平台,用于模拟真实世界中的网络攻击和防御场景。它提供了一系列的实验室环境,供安全专业人士、学生和爱好者练习和测试他们的技能。 Pikachu靶场的目的是帮助用户了解和掌握网络攻击的原理和技术,…

【美团面试2024/05/14】前端面试题滑动窗口

一、题目描述 设有一字符串序列 s&#xff0c;确定该序列中最长的无重复字母的子序列&#xff0c;并返回其长度。 备注 0 < s.length < 5 * 104 s 由英文字母、数字、符号和空格组成 示例1 输入 s "abcabcbb" 输出 3 二、原题链接 这道题在LeetCode上的原题链…