基于pytorch使用LSTM实现文本匹配任务

本文学习纪录

PyTorch深度学习项目实战100例

使用LSTM来实现文本匹配任务

使用LSTM(Long Short-Term Memory)网络来实现文本匹配任务是自然语言处理(NLP)中的一个常见应用。文本匹配任务的目标是确定两个文本段落是否在某种程度上相似或相关,例如在问答系统、文档检索、相似问题匹配等场景中非常有用。

句⼦1:我不爱吃剁椒⻥头,但是我爱吃⻥头
句⼦2:我爱吃⼟⾖,但是不爱吃地⽠

模型构建

输入层:两个独立的输入,分别对应两个文本序列。
LSTM层:为每个输入文本设计一个LSTM层来捕获序列信息。可以使用双向LSTM(BiLSTM)来获取前后文信息。
相似度计算:使用余弦相似度、曼哈顿距离、欧式距离等方法计算两个LSTM层的输出向量之间的相似度。
输出层:根据相似度分数输出匹配程度,可以是二分类(匹配或不匹配)或者回归(相似度得分)。

在这里插入图片描述

定义网络

# 定义网络结构
class LSTM(nn.Module):def __init__(self, vocab_size, hidden_dim, num_layers, embedding_dim, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dim  # 隐层大小self.num_layers = num_layers  # LSTM层数# 嵌入层,会对所有词形成一个连续型嵌入向量,该向量的维度为embedding_dim# 然后利用这个向量来表示该字,而不是用索引继续表示self.embeddings_x = nn.Embedding(vocab_size + 1, embedding_dim)self.embeddings_y = nn.Embedding(vocab_size + 1, embedding_dim)# 定义LSTM层,第一个参数为每个时间步的特征大小,这里就是每个字的维度# 第二个参数为隐层大小# 第三个参数为lstm的层数self.lstm_x = nn.LSTM(embedding_dim, hidden_dim, num_layers)self.lstm_y = nn.LSTM(embedding_dim, hidden_dim, num_layers)self.cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6)# 利用全连接层将其映射为2维,即0和1的概率self.fc = nn.Linear(1, output_dim)def forward(self, x_input, y_input):# 1.首先形成嵌入向量embeds_x = self.embeddings_x(x_input)embeds_y = self.embeddings_y(x_input)# 2.将嵌入向量导入到lstm层output_x, _ = self.lstm_x(embeds_x)output_y, _ = self.lstm_x(embeds_y)timestep, batch_size, hidden_dim = output_x.shapeoutput_x = output_x.reshape(timestep, batch_size, -1)output_y = output_y.reshape(timestep, batch_size, -1)# 3.获取lstm最后一个隐层表示向量output_x = output_x[-1]output_y = output_y[-1]# 4.计算两个向量的余弦相似度sim = self.cos_sim(output_x, output_y)sim = sim.view(-1, 1)# 5.形成最终输出结果output = self.fc(sim)return output

模型训练

# 6.模型训练
model = LSTM(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_layers,embedding_dim=embedding_dim, output_dim=output_dim)Configimizer = optim.Adam(model.parameters(), lr=lr) # 优化器
criterion = nn.CrossEntropyLoss() # 多分类损失函数model.to(device)
loss_meter = meter.AverageValueMeter()best_acc = 0 # 保存最好准确率
best_model = None # 保存对应最好准确率的模型参数for epoch in range(epochs):model.train() # 开启训练模式epoch_acc = 0 # 每个epoch的准确率epoch_acc_count = 0 # 每个epoch训练的样本数train_count = 0 # 用于计算总的样本数,方便求准确率loss_meter.reset()train_bar = tqdm(train_loader)  # 形成进度条for data in train_bar:x_input, y_input, label = data  # 解包迭代器中的X和Yx_input = x_input.long().transpose(1, 0).contiguous()x_input = x_input.to(device)y_input = y_input.long().transpose(1, 0).contiguous()y_input = y_input.to(device)Configimizer.zero_grad()# 形成预测结果output_ = model(x_input, y_input)# 计算损失loss = criterion(output_, label.long().view(-1))loss.backward()Configimizer.step()loss_meter.add(loss.item())# 计算每个epoch正确的个数epoch_acc_count += (output_.argmax(axis=1) == label.view(-1)).sum()train_count += len(x_input)# 每个epoch对应的准确率epoch_acc = epoch_acc_count / train_count# 打印信息print("【EPOCH: 】%s" % str(epoch + 1))print("训练损失为%s" % (str(loss_meter.mean)))print("训练精度为%s" % (str(epoch_acc.item() * 100)[:5]) + '%')# 保存模型及相关信息if epoch_acc > best_acc:best_acc = epoch_accbest_model = model.state_dict()# 在训练结束保存最优的模型参数if epoch == epochs - 1:# 保存模型torch.save(best_model, './best_model.pkl')

测试语句

try:# 数据预处理input_shape = 20 # 序列长度,就是时间步大小,也就是这里的每句话中的词的个数# 用于测试的话sentence1 = "我不爱吃剁椒鱼头,但是我爱吃鱼头"sentence2 = "我爱吃土豆,但是不爱吃地瓜"# 将对应的字转化为相应的序号x_input = [[word2idx[word] for word in sentence1]]x_input = pad_sequences(maxlen=input_shape, sequences=x_input, padding='post', value=0)x_input = torch.from_numpy(x_input)y_input = [[word2idx[word] for word in sentence2]]y_input = pad_sequences(maxlen=input_shape, sequences=y_input, padding='post', value=0)y_input = torch.from_numpy(y_input)# 加载模型model_path = './best_model.pkl'model = LSTM(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_layers,embedding_dim=embedding_dim, output_dim=output_dim)model.load_state_dict(torch.load(model_path, 'cpu'))# 模型预测,注意输入的数据第一个input_shapey_pred = model(x_input.long().transpose(1, 0), y_input.long().transpose(1, 0))idx2label = {0:"匹配失败!", 1:"匹配成功!"}print('输入语句: %s \t %s' % (sentence1, sentence2))print('文本匹配结果: %s' % idx2label[y_pred.argmax().item()])except KeyError as err:print("您输入的句子有汉字不在词汇表中,请重新输入!")print("不在词汇表中的单词为:%s." % err)

数据集为QA_corpus,训练数据10w条,验证集和测试集均为1w条

其中对应模型文件夹下的args.py文件是超参数

QA_corpus
数据集展示在这里插入图片描述

结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/702521.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue + Echarts页面内存占用高问题解决

Vue Echarts页面内存占用高问题解决 1.问题描述 目前使用的是Vue2 Echarts4.x的组合,页面如下所示。 就是一个类似于神策的数据看板页面,左侧是一个导航栏,右侧看板页面中包含很多个报表图片,其中报表页面中对Echarts图表进…

STL常用容器(string容器)---C++

STL常用容器目录 1.string容器1.1 string基本概念1.2 string构造函数1.3 string赋值操作1.4 string字符串拼接1.5 string查找和替换1.6 string字符串比较1.7 string字符存取1.8 string插入和删除1.9 string子串 1.string容器 1.1 string基本概念 本质: string是C…

电子签证小程序系统源码后台功能列表

基于ThinkPhp8.0uniapp 开发的电子签证小程序管理系统。能够真正帮助企业基于微信公众号H5、小程序、wap、pc、APP等,实现会员管理、数据分析,精准营销的电子商务管理系统。可满足企业新零售、批发、分销、预约、O2O、多店等各种业务需求,快速积累客户、…

搜索专项---IDA*

文章目录 排书回转游戏 一、排书OJ链接 本题思路:先考虑每一步的决策数量:当抽取长度为 i 的一段时,有 n−i1 种抽法,对于每种抽法,有 n−i 种放法。另外,将某一段向前移动,等价于将跳过的那段向后移动&am…

C++之std::tuple(二) : 揭秘底层实现原理

相关系列文章 C之std::tuple(二) : 揭秘底层实现原理 C三剑客之std::any(一) : 使用 C之std::tuple(一) : 使用精讲(全) C三剑客之std::variant(一) : 使用 C三剑客之std::variant(二):深入剖析 深入理解可变参数(va_list、std::initializer_list和可变参数模版) st…

【JVM】线上一次fullGC排查思路

fullGC问题背景 监控告警发现,今天开始我们线上应用频繁出现fullGC,并且每次出现后磁盘都会被占满 查看监控 查看监控发现FULLGC的机器均为同一个机房的集器,并且该机房有线上error报错,数据库监控对应的时间点也有异常&#x…

数据结构知识点总结-绪论 数据结构基本术语 算法及评价

要求 (1)对数据结构这么课学了哪些知识有个清楚的认知; (2)掌握目录结构,能复述出来每个知识点下都有哪些内容。 如下图所示,可自行制作思维导图,针对自己薄弱的地方进行复习。 …

curl与HTTP状态码

目录 一、curl (一)curl简介 (二)curl命令的选项 二、HTTP状态码 (一)状态码的含义 (二)状态码分类 1.默认的状态码 2.自定义状态码 一、curl (一)c…

NGINX服务器配置实现加密的WebSocket连接WSS协议

一、背景 最近在做小程序开发,需要在nginx中配置websocket加密模式,即wss。初次配置wss时,踩了两个小时的坑,本文将踩坑过程分享给大家,有需要用到的伙伴可以直接copy即可实现,节省宝贵时间。 二、WebSo…

代码随想录第41天|● 01背包问题,你该了解这些! ● 01背包问题,你该了解这些! 滚动数组 ● 416. 分割等和子集

文章目录 背包问题背包题目解法一 ● 01背包问题-二维数组五部曲1.确定dp数组2、确定递推公式3、初始化dp数组4、循环代码: 解法二-01背包问题-滚动数组五部曲1:定义dp二、递推公式三、初始化四、循环顺序代码: 698. 划分为k个相等的子集题解…

FairyGUI × Cocos Creator 3.x 使用方式

前言 上一篇文章 FariyGUI Cocos Creator 入门 简单介绍了FairyGUI,并且按照官方demo成功在Cocos Creator2.4.0上运行起来了。 当我今天使用Creator 3.x 再引入2.x的Lib时,发现出现了报错。 这篇文章将介绍如何在Creator 3.x上使用fgui。 引入 首先&…

uniapp开发安卓app华为平板真机预览

首先使用数据线连接平板和电脑设备 一、前期准备 平板需要开启三个地方: 1、打开设置,在搜索框中输入版本号/或者直接点击最下方的【关于平板电脑】,点击版本号进入关于平板的界面,连续点击版本号7次,直到出现提醒“…

2.25基础会计学

资本公积是指由股东投入、但不能构成“股本”或“实收资本”的资金部分。 盈余公积是指公司按照规定从净利润中提取的各种积累资金。 所以区别在于盈余公积来自净利润。 借贷其实就是钱从哪来和到哪去的问题,来源是贷,流向是借。比如购入9w原材料&…

基于自适应波束成形算法的matlab性能仿真,对比SG和RLS两种方法

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于自适应波束成形算法的matlab性能仿真,对比SG和RLS两种方法. 2.测试软件版本以及运行结果展示 MATLAB2022a版本运行 3.核心程序 ........................…

facebook群控如何做?使用静态住宅ip代理有什么好处?

在进行Facebook群控时,ip地址的管理是非常重要的,因为Facebook通常会检测ip地址的使用情况,如果发现有异常的使用行为,比如从同一个ip地址频繁进行登录、发布内容或者在短时间内进行大量的活动等等,就会视为垃圾邮件或…

RK3568平台开发系列讲解(Linux系统篇)字符设备驱动:分配和注册字符设备

🚀返回专栏总目录 文章目录 一、分配和注册字符设备二、file_operations沉淀、分享、成长,让自己和他人都能有所收获!😄 一、分配和注册字符设备 字符设备在内核中表示为struct cdev的实例。在编写字符设备驱动程序时,目标是最终创建并注册与struct file_operations关联…

栈和队列笔试题

答案:(1)seqn[tail]data; tail(tail1)%SEQLEN; (2)data seqn[head]; head (head1)%SEQLEN; (3)head tail; (4)(tail1)%SEQLEN head; (5) while(head!tail) head (h…

JVM内存结构介绍

1.程序计数器(Program Counter Register) 程序计数器是一块较小的内存空间,它的作用可以看做是当前线程所执行的字节码的行号指示器。在虚拟机的概念模型里(仅是概念模型,各种虚拟机可能会通过一些更高效的方式去实现&…

电商评价分析:NLP信息抽取技术在用户评论中的应用与挖掘

一、引言 在2019年,电子商务的蓬勃发展不仅推动了消费市场的增长,也带来了海量的用户评价数据。这些数据,作为消费者对商品和服务直接反馈的载体,蕴含着巨大的价值。然而,由于其非结构化的特性,这些文本信息…

解决ssh:connect to host github.com port 22: Connection timed out与kex_exchange_identification

一、问题 无法进行clone项目和其他Git操作。执行检测连接命令 ssh -T gitgithub,com报错 ssh:connect to host github.com port 22: Connection timed out 即:连接22端口超时 涉及到的文件: C:\Users\JIACHENGER.ssh\config C:\Users\JIACHENGER.ssh\…