python-pytorch 实现seq2seq+luong general concat attention 完整代码

接上一篇https://blog.csdn.net/m0_60688978/article/details/139046644


# def getAQ():
#     ask=[]
#     answer=[]
#     with open("./data/flink.txt","r",encoding="utf-8") as f:
#         lines=f.readlines()
#         for line in lines:
#             ask.append(line.split("----")[0])
#             answer.append(line.split("----")[1].replace("\n",""))
#     return answer,ask# seq_answer,seq_example=getAQ()import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import os
from tqdm import tqdmseq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "谁是张学友"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "她旁边那个就是"]# 所有词
example_cut = []
answer_cut = []
word_all = []
# 分词
for i in seq_example:example_cut.append(list(jieba.cut(i)))
for i in seq_answer:answer_cut.append(list(jieba.cut(i)))
#   所有词
for i in example_cut + answer_cut:for word in i:if word not in word_all:word_all.append(word)
# 词语索引表
word2index = {w: i+3 for i, w in enumerate(word_all)}
# 补全
word2index['PAD'] = 0
# 句子开始
word2index['SOS'] = 1
# 句子结束
word2index['EOS'] = 2
index2word = {value: key for key, value in word2index.items()}
# 一些参数
vocab_size = len(word2index)
seq_length = max([len(i) for i in example_cut + answer_cut]) + 1
print("vocab_size is",vocab_size,", seq_length is ",seq_length)
embedding_size = 128
num_classes = vocab_size
hidden_size = 256
batch_size=6
seq_len=7# 将句子用索引表示
def make_data(seq_list):result = []for word in seq_list:seq_index = [word2index[i] for i in word]if len(seq_index) < seq_length:seq_index += [0] * (seq_length - len(seq_index))result.append(seq_index)return result
encoder_input = make_data(example_cut)
decoder_input = make_data([['SOS'] + i for i in answer_cut])
decoder_target = make_data([i + ['EOS'] for i in answer_cut])# 训练数据
encoder_input, decoder_input, decoder_target = torch.tensor(encoder_input,dtype=torch.long),torch.tensor(decoder_input,dtype=torch.long), torch.tensor(decoder_target,dtype=torch.long)class encoder(nn.Module):def __init__(self):super(encoder, self).__init__()self.embedding=nn.Embedding(vocab_size,embedding_size)self.lstm=nn.LSTM(embedding_size,hidden_size)def forward(self, inputx):embeded=self.embedding(inputx)output,(encoder_h_n, encoder_c_n)=self.lstm(embeded.permute(1,0,2))return output,(encoder_h_n,encoder_c_n)cc1=encoder()
out,(hn,cn)=cc1(encoder_input)
out.size(),hn.size(),out,hn#out [seq_len_size,batch_size,hidden_size]  hn [layer_num,batch_size,hidden_size]# concat
# class Attention(nn.Module):
#     def __init__(self):
#         super(Attention, self).__init__()
#         self.wa=nn.Linear(hidden_size*2,hidden_size*2,bias=False)#         self.wa1=nn.Linear(hidden_size*2,hidden_size,bias=False)#     def forward(self, hidden, encoder_outputs):
#         """
#         hidden:[layer_num,batch_size,hidden_size]
#         encoder_outputs:[seq_len,batch_size,hidden_size]
#         """#         hiddenchange=hidden.repeat(encoder_outputs.size(0),1,1)#[seq_len,batch_size,hidden_size]#         concated=torch.cat([hiddenchange.permute(1,0,2),encoder_outputs.permute(1,0,2)],dim=-1)# [batch_size,seq_len,hidden_size*2]
#         waed=self.wa(concated)# [batch_size,seq_len,hidden_size*2]
#         tanhed=torch.tanh(waed)# [batch_size,seq_len,hidden_size*2]#         self.va=nn.Parameter(torch.FloatTensor(encoder_outputs.size(1),hidden_size*2))#[batch_size,hidden_size*2]
# #         print("tanhed size",tanhed.size(),self.va.unsqueeze(2).size())#         attr=tanhed.bmm(self.va.unsqueeze(2))# [batch_size,seq_len,1]#         context=attr.permute(0,2,1).bmm(encoder_outputs.permute(1,0,2))# [batch_size,1,seq_len]#         return context,attr# general
class Attention(nn.Module):def __init__(self):super(Attention, self).__init__()self.va=nn.Linear(hidden_size,hidden_size,bias=False)def forward(self, hidden, encoder_outputs):"""hidden:[layer_num,batch_size,hidden_size]encoder_outputs:[seq_len,batch_size,hidden_size]"""score=encoder_outputs.permute(1,0,2).bmm(self.va(hidden).permute(1,2,0))# [batch_size,seq_len,layer_num]attr=nn.functional.softmax(score,dim=1)# [batch_size,seq_len,layer_num]context=attr.permute(0,2,1).bmm(encoder_outputs.permute(1,0,2)) #[batch_size,layer_num,hidden_size]return context,attr# # dot
# class Attention(nn.Module):
#     def __init__(self):
#         super(Attention, self).__init__()#     def forward(self, hidden, encoder_outputs):
#         """
#         hidden:[layer_num,batch_size,hidden_size]
#         encoder_outputs:[seq_len,batch_size,hidden_size]
#         """
#         score=encoder_outputs.permute(1,0,2).bmm(hidden.permute(1,2,0))# [batch_size,seq_len,layer_num]
#         attr=nn.functional.softmax(score,dim=-1)# [batch_size,seq_len,layer_num]#         context=attr.permute(0,2,1).bmm(encoder_outputs.permute(1,0,2)) #[batch_size,layer_num,hidden_size]#         return context,attrclass decoder(nn.Module):def __init__(self):super(decoder,self).__init__()self.embeded=nn.Embedding(vocab_size,embedding_size)self.lstm=nn.LSTM(embedding_size,hidden_size)self.attr=Attention()self.fc=nn.Linear(hidden_size*2,vocab_size,bias=False)self.tan=nn.Linear(hidden_size*2,hidden_size*2,bias=False)def forward(self,input,encoder_outputs,hn,cn):decoder_input_embedded=self.embeded(input)#[batch_size,seq_len,embedding_size]lstm_output,(lstm_hn,lstm_cn)=self.lstm(decoder_input_embedded.permute(1,0,2),(hn,cn)) #lstm_output  [seq_len,batch_size,hidden_size]  lstm_hn [layer_num,batch_size,hidden_size]context,attr=self.attr(hn,encoder_outputs) #[batch_size,layer_num,hidden_size] [batch_size,seq_len,layer_num]#         这里是解决预测时长度为1,而训练是seq长度为7的关键context=context.repeat(1,lstm_output.size(0),1) #[batch_size,seq_len,hidden_size]#         print("lstm_output size is",lstm_output.size(),"context size is",context.size())concated=torch.cat([lstm_output.permute(1,0,2),context],dim=-1) #[batch_size,seq_len,hidden_size*2]concat_output=torch.tanh(self.tan(concated))# [batch_size,seq_len,hidden_size*2]
#         concat_output=nn.functional.log_softmax(self.fc(concat_output),dim=-1) # [batch_size,seq_len,vocab_size]concat_output=self.fc(concat_output) # [batch_size,seq_len,vocab_size]#         这里从softmax变为log_softmax,开始损失从3.xx变为了2.xreturn concat_output,lstm_hn,lstm_cnclass seq2seq(nn.Module):def __init__(self):super(seq2seq,self).__init__()self.word_vec=nn.Linear(vocab_size,embedding_size)self.encoder=encoder()self.decoder=decoder()def forward(self,encoder_input, decoder_input,istrain=1):encoder_outputs,(hn,cn)=self.encoder(encoder_input)if istrain:#outputs [batch_size,seq_len,vocab_size]  hn [layer_num,batch_size,hidden_size]outputs,hn,cn=self.decoder(decoder_input,encoder_outputs,hn,cn)return outputs,hn,cnelse:finaloutputs=[]for i in range(seq_length):#outputs [batch_size,seq_len,vocab_size]  hn [layer_num,batch_size,hidden_size]outputs,hn,cn=self.decoder(decoder_input,encoder_outputs,hn,cn)gailv=nn.functional.softmax(outputs[0],dim=1)out=torch.topk(gailv,1)[1].item()if out in [0,2]:return finaloutputsfinaloutputs.append(out)decoder_input=torch.tensor([[out]])return finaloutputsmodel = seq2seq()
print(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05)model.train()
for epoch in tqdm(range(3000)):pred,lstm_hn,lstm_cn = model(encoder_input, decoder_input,1)
#     print("pred",pred.size())
#     print("decoder_target",decoder_target.view(-1).size())loss = criterion(pred.reshape(-1, vocab_size), decoder_target.view(-1))optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print("Epoch: %d,  loss: %.5f " % (epoch + 1, loss))
# 保存模型
torch.save(model.state_dict(), './seq2seqModel.pkl')model.eval()
question_text = '谁是张学友'
question_cut = list(jieba.cut(question_text))
encoder_x = make_data([question_cut])
decoder_x = [[word2index['SOS']]]
encoder_x,  decoder_x = torch.LongTensor(encoder_x), torch.LongTensor(decoder_x)
out= model(encoder_x,decoder_x,0)
answer = ''
for i in out:answer += index2word[i]
print('问题:', question_text)
print('回答:', answer)  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/16950.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MyBatis多数据源配置与使用,基于ThreadLocal+AOP

导读 MyBatis多数据源配置与使用其一其二1. 引依赖2. 配置文件3. 编写测试代码4. 自定义DynamicDataSource类5. DataSourceConfig配置类6. AOP与ThreadLocal结合7. 引入AOP依赖8. DataSourceContextHolder9. 自定义注解UseDB10. 创建切面类UseDBAspect11. 修改DynamicDataSourc…

jQuery里添加事件 (代码)

直接上代码 <!DOCTYPE html> <html><head></head><body><input type"text" placeholder"城市" id"city" /><input type"button" value"添加" id"btnAdd" /><ul id…

PTA 计算矩阵两个对角线之和

计算一个nn矩阵两个对角线之和。 输入格式: 第一行输入一个整数n(0<n≤10)&#xff0c;第二行至第n1行&#xff0c;每行输入n个整数&#xff0c;每行第一个数前没有空格&#xff0c;每行的每个数之间各有一个空格。 输出格式: 两条对角线元素和&#xff0c;输出格式见样例…

Android存储系统成长记

用心坚持输出易读、有趣、有深度、高质量、体系化的技术文章 本文概要 您一定使用过Context的getFileStreamPath方法或者Environment的getExternalStoragePublicDirectory方法&#xff0c;甚至还有别的方法把数据存储到文件中&#xff0c;这些都是存储系统提供的服务&#x…

PTA 判断两个矩阵相等

Peter得到两个n行m列矩阵&#xff0c;她想知道两个矩阵是否相等&#xff0c;请你用“Yes”&#xff0c;“No”回答她&#xff08;两个矩阵相等指的是两个矩阵对应元素都相等&#xff09;。 输入格式: 第一行输入整数n和m&#xff0c;表示两个矩阵的行与列&#xff0c;用空格隔…

修改元组元素

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 场景模拟&#xff1a;伊米咖啡馆&#xff0c;由于麝香猫咖啡需求量较大&#xff0c;库存不足&#xff0c;店长想把它换成拿铁咖啡。 实例08 将麝香猫…

chrome浏览器驱动下载

跑自动化的时候&#xff0c;需要打开谷歌浏览器&#xff0c;这个时候提示浏览器驱动找不到咋办呢&#xff1f; 1、网上搜索找到了这篇文章&#xff1a;https://www.cnblogs.com/laoluoits/p/17710501.html&#xff1b;按照文章介绍&#xff0c; 首先找到&#xff1a;CNPM Bin…

D - Permutation Subsequence(AtCoder Beginner Contest 352)

题目链接: D - Permutation Subsequence (atcoder.jp) 题目大意&#xff1a; 分析&#xff1a; 相对于是记录一下每个数的位置 然后再长度为k的区间进行移动 然后看最大的pos和最小的pos的最小值是多少 有点类似于滑动窗口 用到了java里面的 TreeSet和Map TreeSet存的是数…

解决 Spring Boot 应用启动失败的问题:Unexpected end of file from server

解决 Spring Boot 应用启动失败的问题&#xff1a;Unexpected end of file from server 博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的…

Spring AOP失效的场景事务失效的场景

场景一&#xff1a;使用this调用被增强的方法 下面是一个类里面的一个增强方法 Service public class MyService implements CommandLineRunner {private MyService myService;public void performTask(int x) {System.out.println("Executing performTask method&quo…

爬虫学习--15.进程与线程(2)

线程锁 当多个线程几乎同时修改某一个共享数据的时候&#xff0c;需要进行同步控制 某个线程要更改共享数据时&#xff0c;先将其锁定&#xff0c;此时资源的状态为"锁定",其他线程不能改变&#xff0c;只到该线程释放资源&#xff0c;将资源的状态变成"非锁定…

Linux如何设置共享文件夹

打开虚拟机->菜单->虚拟机设置->选项->共享文件夹->总是启用。点击添加按钮->弹出添加向导->点击浏览按钮&#xff0c;从windows中选择一个文件夹&#xff0c;确定即可。

[Windows] GIF动画、动图制作神器 ScreenToGif(免费)

ScreenToGif 是开源免费的 Gif 动画录制工具&#xff0c;小巧原生单文件&#xff0c;功能很实用。它有录制屏幕、录制摄像头、录制画板、图像编辑器等功能&#xff0c;可以将屏幕任何区域及操作过程录制成 GIF 格式的动态图像。保存前还可对 GIF 图像编辑优化&#xff0c;支持自…

末日设计1.00

故事背景: 在不远的未来&#xff0c;世界陷入了末日危机。资源枯竭、社会秩序崩溃&#xff0c;幸存者们为了生存&#xff0c;不得不拿起武器争夺每一寸土地和每一口食物。在这个混乱的世界中&#xff0c;你是一名传奇狙击手&#xff0c;凭借超凡的射击技巧和生存智慧&#xff0…

研二学妹面试字节,竟倒在了ThreadLocal上,这是不要应届生还是不要女生啊?

一、写在开头 今天和一个之前研二的学妹聊天&#xff0c;聊及她上周面试字节的情况&#xff0c;着实感受到了Java后端现在找工作的压力啊&#xff0c;记得在18&#xff0c;19年的时候&#xff0c;研究生计算机专业的学生&#xff0c;背背八股文找个Java开发工作毫无问题&#x…

本地图形客户端查看git提交历史 使用 TortoiseGit

要在本地查看提交记录和修改历史&#xff0c;可以使用 TortoiseGit 和 Git-SCM。这两个工具都提供了强大的功能来管理和查看 Git 仓库中的提交记录和历史修改。 使用 TortoiseGit 查看提交记录和修改历史 查看提交记录&#xff08;Log&#xff09;&#xff1a; 右键点击项目文…

抖音里卖什么最赚钱?4个冷门的高利润商品,还有谁不知道!

哈喽~我的电商月月 做抖音小店的新手朋友&#xff0c;一定很想知道&#xff0c;在抖音里卖什么最赚钱&#xff1f; 很多人都会推荐&#xff0c;日常百货&#xff0c;小风扇&#xff0c;女装&#xff0c;宠物用品等等&#xff0c;这些商品确实很好做&#xff0c;你们可以试试 …

Euraka详解:实现微服务架构的关键组件

在当今互联网时代&#xff0c;微服务架构已经成为许多企业构建和部署应用程序的首选方法之一。而要在微服务架构中实现高可用性和灵活性&#xff0c;服务发现和注册是至关重要的一环。Eureka作为Netflix开源的服务发现组件&#xff0c;为实现这一目标提供了高效可靠的解决方案。…

备忘录可以统计字数吗?备忘录里在哪查看字数?

在这个信息爆炸的时代&#xff0c;很多人喜欢使用备忘录app来记录生活中的点点滴滴。备忘录不仅可以帮助我们记事、安排日程&#xff0c;还能提醒我们完成各种任务&#xff0c;是我们日常生活中不可或缺的小助手。 然而&#xff0c;在使用备忘录时&#xff0c;有时我们会遇到需…

不用BookStack的企业都在用什么知识库软件

现如今&#xff0c;越来越多的企业使用知识库软件对企业内部知识进行管理。BookStack作为一款功能强大的开源知识库软件&#xff0c;成为很多企业的首选。但是还是有一部分人群认为BookStack不适合他们的企业那么他们都是在用什么别的知识库软件呢&#xff1f;LookLook同学今天…