python-pytorch 下批量seq2seq+Bahdanau Attention实现问答1.0.000

python-pytorch 下批量seq2seq+Bahdanau Attention实现简单问答1.0.000

    • 前言
    • 原理看图
    • 数据准备
    • 分词、index2word、word2index、vocab_size
    • 输入模型的数据构造
    • 注意力模型
    • decoder的编写
    • 关于损失函数和优化器
    • 在预测时
    • 完整代码
    • 参考

前言

前面实现了 luong的dot 、general、concat注意力实现简单问答,这里参考官方文档,实现了python-pytorch 下批量seq2seq+Bahdanau Attention实现问答

原理看图

在这里插入图片描述
这里模型选择和官方不一样,官方选择的是GRU,我更喜欢使用LSTM,解码器和编码器都是如此。
意思大致思路是:

  1. 计算encoder的encoder_outputs、encoder_hn、encoder_cn
  2. 使用encoder_outputs、encoder_hn计算新的向量和注意力
  3. 在deconder中,以SOS单字开始,循环句子最大长度,在循环中,使用新的向量和单字SOS做cat计算得到decoder的LSTM输入数据,将该LSTM存起来,最后做cat计算得到decoder的输出

数据准备

结果类似还是采用前面的结构和数据

seq_example = [“你认识我吗”, “你住在哪里”, “你知道我的名字吗”, “你是谁”, “你会唱歌吗”, “谁是张学友”]
seq_answer = [“当然认识”, “我住在成都”, “我不知道”, “我是机器人”, “我不会”, “她旁边那个就是”]

分词、index2word、word2index、vocab_size

分词然后做基础准备,包括数据:index2word、word2index、vocab_size、最长的句子长度seq_length,和一些超参数的设置

输入模型的数据构造

  1. 长度要统一
  2. 问答的句子以EOS结尾,不足补0,如

tensor([[ 3, 4, 5, 6, 2, 0, 0],
[ 3, 7, 8, 9, 2, 0, 0],
[ 3, 10, 5, 11, 12, 6, 2],
[ 3, 13, 14, 2, 0, 0, 0],
[ 3, 15, 16, 6, 2, 0, 0],
[14, 13, 17, 2, 0, 0, 0]])

注意力模型

可以复用,用官方的即可

# Bahdanau
# query=hidden [layer_num,batch_size,hidden_size] keys=encoder_outputs  [seq_len,batch_size,hidden_size]
class Attention(nn.Module):def __init__(self):super(Attention, self).__init__()self.Wa = nn.Linear(hidden_size, hidden_size)self.Ua = nn.Linear(hidden_size, hidden_size)self.Va = nn.Linear(hidden_size, 1)def forward(self, query, keys):scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys))) #[seq_len,batch_size,1]scores = scores.permute(1,0,2).squeeze(2).unsqueeze(1)#[batch_size,1,seq_len]weights = nn.functional.softmax(scores, dim=-1)#[batch_size,1,seq_len]context = torch.bmm(weights, keys.permute(1,0,2))#[batch_size,1,hidden_size]return context, weights

decoder的编写

思路是,获得encoder的输出和hn后,计算得到向量,然后使用向量和目标的每一个字做cat计算,输入decoder的模型中,然后得出一个字的预测,循环完了以后,就会得到最大句子长度,最后做cat和softmax计算得到输出。另外,这里要区分训练和测试,训练的时候有target,测试的没有target数据。

关于损失函数和优化器

NLLLoss+Adam的组合优于CrossEntropyLoss+SGD的组合

在预测时

获取到模型输出,size是[batch_size,seq_len,vocab_size]后,对结果做topk计算,会得到每一字在vocab_size的概率,连接起来就是一句话

完整代码

# def getAQ():
#     ask=[]
#     answer=[]
#     with open("./data/flink.txt","r",encoding="utf-8") as f:
#         lines=f.readlines()
#         for line in lines:
#             ask.append(line.split("----")[0])
#             answer.append(line.split("----")[1].replace("\n",""))
#     return answer,ask# seq_answer,seq_example=getAQ()import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import os
from tqdm import tqdmseq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "谁是张学友"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "她旁

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/15571.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【话题】我眼神的IT行业现状与未来趋势

目录 一、挑战 教学资源的重新分配 教师角色的转变 学生学习方式的改变 教育评价体系的挑战 二、机遇 个性化学习 跨学科学习 国际合作与交流 创新教育模式 三、如何培养下一代IT专业人才 更新教育理念 加强基础设施建设 整合课程资源 加强实践教学 培养跨学科…

【Linux】TCP协议【中】{确认应答机制/超时重传机制/连接管理机制}

文章目录 1.确认应答机制2.超时重传机制:超时不一定是真超时了3.连接管理机制 1.确认应答机制 TCP协议中的确认应答机制是确保数据可靠传输的关键部分。以下是该机制的主要步骤和特点的详细解释: 数据分段与发送: 发送方将要发送的数据分成一…

vue深度选择器(:deep​)

处于 scoped 样式中的选择器如果想要做更“深度”的选择&#xff0c;也即&#xff1a;影响到子组件&#xff0c;可以使用 :deep() 这个伪类&#xff1a; <style lang"scss" scoped> .evaluation-situation-details :deep .cl-icon-arrow-right {display: none…

【Python 对接QQ的接口(二)】简单用接口查询【等级/昵称/头像/Q龄/当天在线时长/下一个等级升级需多少天】

文章日期&#xff1a;2024.05.25 使用工具&#xff1a;Python 类型&#xff1a;QQ接口 文章全程已做去敏处理&#xff01;&#xff01;&#xff01; 【需要做的可联系我】 AES解密处理&#xff08;直接解密即可&#xff09;&#xff08;crypto-js.js 标准算法&#xff09;&…

JS根据所选ID数组在源数据中取出对象

let selectIds [1, 3] // 选中id数组let allData [{ id: 1, name: 123 },{ id: 2, name: 234 },{ id: 3, name: 345 },{ id: 4, name: 456 },] // 源数据let newList [] // 最终数据selectIds.map((i) > {allData.filter((item) > {item.id i && newList.pus…

aws sqs基础概念和队列参数解析

分布式队列的组成部分 生产者&#xff0c;向队列发送消息的组件消费者&#xff0c;接受队列消息队列&#xff0c;多个sqs服务器存储冗余存储消息 sqs自动删除超过最大留存时间的消息&#xff08;默认4天&#xff09;&#xff0c;可以通过SetQueueAttributes调整为&#xff08…

【408真题】2009-13

“接”是针对题目进行必要的分析&#xff0c;比较简略&#xff1b; “化”是对题目中所涉及到的知识点进行详细解释&#xff1b; “发”是对此题型的解题套路总结&#xff0c;并结合历年真题或者典型例题进行运用。 涉及到的知识全部来源于王道各科教材&#xff08;2025版&…

VBA即用型代码手册:删除Excel中空白行Delete Blank Rows in Excel

我给VBA下的定义&#xff1a;VBA是个人小型自动化处理的有效工具。可以大大提高自己的劳动效率&#xff0c;而且可以提高数据的准确性。我这里专注VBA,将我多年的经验汇集在VBA系列九套教程中。 作为我的学员要利用我的积木编程思想&#xff0c;积木编程最重要的是积木如何搭建…

IDEA中好用的插件

IDEA中好用的插件 CodeGeeXMybatis Smart Code Help ProAlibaba Java Coding Guidelines​(XenoAmess TPM)​通义灵码常用操作 TranslationStatistic CodeGeeX 官网地址&#xff1a;https://codegeex.cn/ 使用手册&#xff1a;https://zhipu-ai.feishu.cn/wiki/CuvxwUDDqiErQU…

Android 自定义图片进度条

用系统的Progressbar&#xff0c;设置图片drawable作为进度条会出现图片长度不好控制&#xff0c;容易被截断&#xff0c;或者变形的问题。而我有个需求&#xff0c;使用图片背景&#xff0c;和图片进度&#xff0c;而且在进度条头部有个闪光点效果。 如下图&#xff1a; 找了…

Kafka(十三)监控与告警

目录 Kafka监控与告警1 解决方案1.2 基础知识JMX监控指标代理查看KafkaJMX远程端口 1.3 真实案例Kafka Exporter:PromethusPromethus Alert ManagerGrafana 1.3 实际操作部署监控和告警系统1.2.1 部署Kafka Exporter1.2.2 部署Prometheus1.2.3 部署AlertManger1.2.4 添加告警规…

大疆上云API本地部署与飞机上云

文章目录 前言一、安装基础环境1. EMQX 安装(版本4.4.0)2. MySql 安装(版本8.0.26)3. Redis 安装 二、部署后端&#xff08;JDK必须11及以上&#xff09;三、部署前端四、成为大疆开发者五、飞机注册上云六、绑定飞机七、无人机状态查看八、直播流查看 前言 大疆上云API官方文…

HarmonyOS鸿蒙应用开发——ArkTS的“内置组件 + 样式 + 循环和条件渲染”

一、内置组件是咩&#xff1f; 学过前端的都知道&#xff0c;一个组件就是由多个组件组成的&#xff0c;一个组件也可以是多个小组件组成的&#xff0c;组件就是一些什么导航栏、底部、按钮......啥的&#xff0c;但是组件分为【自定义组件】跟【内置组件】 【自定义组件】就…

【狂神说Java】Redis笔记以及拓展

一、Redis 入门 Redis为什么单线程还这么快&#xff1f; 误区1&#xff1a;高性能的服务器一定是多线程的&#xff1f; 误区2&#xff1a;多线程&#xff08;CPU上下文会切换&#xff01;&#xff09;一定比单线程效率高&#xff01; 核心&#xff1a;Redis是将所有的数据放在内…

用于时间序列概率预测的蒙特卡洛模拟

大家好&#xff0c;蒙特卡洛模拟是一种广泛应用于各个领域的计算技术&#xff0c;它通过从概率分布中随机抽取大量样本&#xff0c;并对结果进行统计分析&#xff0c;从而模拟复杂系统的行为。这种技术具有很强的适用性&#xff0c;在金融建模、工程设计、物理模拟、运筹优化以…

【C语言】C语言-设备管理系统(源码+数据文件)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

AI大模型:大数据+大算力+强算法

前言&#xff1a;好久不见&#xff0c;甚是想念&#xff0c;我是辣条&#xff0c;我又回来啦&#xff0c;兄弟们&#xff0c;一别两年&#xff0c;还有多少老哥们在呢&#xff1f; 目录 一年半没更文我干啥去了&#xff1f; AI大模型火了 人工智能 大模型的理解 为什么学习…

ComfyUI完全入门:图生图局部重绘

大家好&#xff0c;我是每天分享AI应用的萤火君&#xff01; 这篇文章的主题和美女有关&#xff0c;不过并不是教大家生产美女视频&#xff0c;而是讲解 ComfyUI 的图生图局部重绘&#xff0c;其中将会以美女图片为例&#xff0c;来展示局部重绘的强大威力。 先看看效果&…

2024年5月26日 十二生肖 今日运势

小运播报&#xff1a;2024年5月26日&#xff0c;星期日&#xff0c;农历四月十九 &#xff08;甲辰年己巳月庚寅日&#xff09;&#xff0c;法定节假日。 红榜生肖&#xff1a;马、猪、狗 需要注意&#xff1a;牛、蛇、猴 喜神方位&#xff1a;西北方 财神方位&#xff1a;…

《python编程从入门到实践》day38

# 昨日知识点回顾 定义、迁移模型Entry # 今日知识点学习 18.2.7 Django shell 每次修改模型后&#xff0c;看到重启后的效果需要重启shell&#xff0c;退出shell会话Windows系统按ctrlZ或者输入exit() 18.3 创建页面&#xff1a;学习笔记主页 创建页面三阶段&#xf…