机器学习深度学习——NLP实战(情感分析模型——RNN实现)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——NLP实战(情感分析模型——数据集)
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

NLP实战(情感分析模型——RNN实现)

  • 引入
  • 使用循环神经网络表示单个文本
  • 加载预训练的词向量
  • 训练和评估模型
  • 小结

引入

与词相似度和类比任务一样,我们也可以将预先训练的词向量应用于情感分析。上节已经下载过了IMDb评论数据集了,这个数据集也不算很大(虽然下载了很久。。。),使用在大规模语料库上预训练的文本表示可以减少模型的过拟合。我们将使用GloVe模型来表示每个词元,并将这些词元表示送入多层双向循环神经网络以获得文本序列表示,该文本序列表示将被转换为情感分析输出。对于相同的下游应用,之后再讲不同的架构选择。
在这里插入图片描述
如上图所示,将GloVe送入基于循环神经网络的架构,用于情感分析。

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)

使用循环神经网络表示单个文本

在文本分类任务中,可变长度的文本序列将被转换为固定长度的类别。在下面的BiRNN类中,虽然文本序列的每个词元经过嵌入层self.embedding获得其单独的预训练GloVe表示,但是整个序列由双向循环神经网络self.encoder编码。更具体的说,双向长短期记忆网络在初始和最终时间步的隐状态(在最后一层)被连结起来作为文本序列的表示。然后,通过一个具有两个输出(“积极”和“消极”)的全连接层(self.decoder),将此单一文本表示转换为输出类别。

class BiRNN(nn.Module):def __init__(self, vocab_size, embed_size, num_hiddens,num_layers, **kwargs):super(BiRNN, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)# 将bidirectional设置为True以获取双向循环神经网络self.encoder = nn.LSTM(embed_size, num_hiddens, num_layers=num_layers,bidirectional=True)self.decoder = nn.Linear(4 * num_hiddens, 2)def forward(self, inputs):# inputs的形状是(批量大小,时间步数)# 因为长短期记忆网络要求其输入的第一个维度是时间维,# 所以在获得词元表示之前,输入会被转置。# 输出形状为(时间步数,批量大小,词向量维度)embeddings = self.embedding(inputs.T)self.encoder.flatten_parameters()# 返回上一个隐藏层在不同时间步的隐状态,# outputs的形状是(时间步数,批量大小,2*隐藏单元数)outputs, _ = self.encoder(embeddings)# 连结初始和最终时间步的隐状态,作为全连接层的输入,# 其形状为(批量大小,4*隐藏单元数)encoding = torch.cat((outputs[0], outputs[-1]), dim=1)outs = self.decoder(encoding)return outs

让我们构造一个具有两个隐藏层的双向循环神经网络来表示单个文本以进行情感分析。

embed_size, num_hiddens, num_layers = 100, 100, 2
devices = d2l.try_all_gpus()
net = BiRNN(len(vocab), embed_size, num_hiddens, num_layers)def init_weights(m):if type(m) == nn.Linear:nn.init.xavier_uniform_(m.weight)if type(m) == nn.LSTM:for param in m._flat_weights_names:if "weight" in param:nn.init.xavier_uniform_(m._parameters[param])
net.apply(init_weights)

加载预训练的词向量

下面,我们为词表中的单词加载预训练的100维(要与embed_size一致)的GloVe嵌入。我们使用这些预训练的词向量来表示评论中的词元,并且在训练期间不要更新这些向量。

glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.data.copy_(embeds)
net.embedding.weight.requires_grad = False

训练和评估模型

现在我们可以训练双向循环神经网络进行情感分析。使用Adam优化算法。

lr, num_epochs = 0.01, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)
d2l.plt.show()

运行结果:

loss 0.300, train acc 0.875, test acc 0.818
79.4 examples/sec on [device(type=‘cpu’)]

运行图片:
在这里插入图片描述
接着我们定义一下函数来使用训练好的模型net预测文本序列的情感。

#@save
def predict_sentiment(net, vocab, sequence):"""预测文本序列的情感"""sequence = torch.tensor(vocab[sequence.split()], device=d2l.try_gpu())label = torch.argmax(net(sequence.reshape(1, -1)), dim=1)return 'positive' if label == 1 else 'negative'

最后,让我们使用训练好的模型对两个简单的句子进行情感预测。

predict_sentiment(net, vocab, 'this movie is so great')

运行结果:

‘positive’

predict_sentiment(net, vocab, 'this movie is so bad')

运行结果:

‘negative’

小结

1、预训练的词向量可以表示文本序列中的各个词元。
2、双向循环神经网络可以表示文本序列。例如通过连结初始和最终时间步的隐状态,可以使用全连接的层将该单个文本表示转换为类别。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/45494.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

re学习(34)攻防世界-csaw2013reversing2(修改汇编顺序)

参考文章: re学习笔记(27)攻防世界-re-csaw2013reversing2_Forgo7ten的博客-CSDN博客攻防世界逆向入门题之csaw2013reversing2_沐一 林的博客-CSDN博客 三种做法 1、ida静态分析修改指令 main函数反编译的代码 由于运行之后的是乱码&…

[oneAPI] 手写数字识别-BiLSTM

[oneAPI] 手写数字识别-BiLSTM 手写数字识别参数与包加载数据模型训练过程结果 oneAPI 比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517 Intel DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToo…

python rtsp 硬件解码 二

上次使用了python的opencv模块 述说了使用PyNvCodec 模块,这个模块本身并没有rtsp的读写,那么读写rtsp是可以使用很多方法的,我们为了输出到pytorch直接使用AI程序,简化rtsp 输入,可以直接使用ffmpeg的子进程 方法一 …

STM8遇坑[EEPROM读取debug不正常release正常][ STVP下载成功单运行不成功][定时器消抖莫名其妙的跑不通流程]

EEPROM读取debug不正常release正常 这个超级无语,研究和半天,突然发现调到release就正常了,表现为写入看起来正常读取不正常,这个无语了,不想研究了 STVP下载不能够成功运行 本文摘录于:https://blog.csdn.net/qlexcel/article/details/71270780只是做学习备份之…

GEEMAP 中如何拉伸图像

图像拉伸是最基础的图像增强显示处理方法,主要用来改善图像显示的对比度,地物提取流程中往往首先要对图像进行拉伸处理。图像拉伸主要有三种方式:线性拉伸、直方图均衡化拉伸和直方图归一化拉伸。 GEE 中使用 .sldStyle() 的方法来进行图像的…

8.5.tensorRT高级(3)封装系列-基于生产者消费者实现的yolov5封装

目录 前言1. yolov5封装总结 前言 杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。 本次课程学习 tensorRT 高级-基于生产者消费者实现的yolov5封装…

postgresql中基础sql查询

postgresql中基础sql查询 创建表插入数据创建索引删除表postgresql命令速查简单查询计算查询结果 利用查询条件过滤数据模糊查询 创建表 -- 部门信息表 CREATE TABLE departments( department_id INTEGER NOT NULL -- 部门编号,主键, department_name CHARACTE…

CentOS6.8图形界面安装Oracle11.2.0.1.0

Oracle11下载地址 https://edelivery.oracle.com/osdc/faces/SoftwareDelivery 一、环境 CentOS release 6.8 (Final),测试环境:内存2G,硬盘20G,SWAP空间4G Oracle版本:Release 11.2.0.1.0 安装包:V175…

Lookup Singularity

1. 引言 Lookup Singularity概念 由Barry WhiteHat在2022年11月在zkResearch论坛 Lookup Singularity中首次提出: 其主要目的是:让SNARK前端生成仅需做lookup的电路。Barry预测这样有很多好处,特别是对于可审计性 以及 形式化验证&#xff…

【学习FreeRTOS】第8章——FreeRTOS列表和列表项

1.列表和列表项的简介 列表是 FreeRTOS 中的一个数据结构,概念上和链表有点类似,列表被用来跟踪 FreeRTOS中的任务。列表项就是存放在列表中的项目。 列表相当于链表,列表项相当于节点,FreeRTOS 中的列表是一个双向环形链表列表的…

微软Win11 Dev预览版Build23526发布

近日,微软Win11 Dev预览版Build23526发布,修复了不少问题。牛比如斯Microsoft,也有这么多bug,所以你写再多bug也不作为奇啊。 主要更新问题 [开始菜单] 修复了在高对比度主题下,打开开始菜单中的“所有应…

Spring Boot通过企业邮箱发件被Gmail退回的解决方法

这两天给我们开发的Chrome插件:Youtube中文配音 增加了账户注册和登录功能,其中有一步是邮箱验证,所以这边会在Spring Boot后台给用户的邮箱发个验证信息。如何发邮件在之前的文章教程里就有,这里就不说了,着重说说这两…

通过 kk 创建 k8s 集群和 kubesphere

官方文档:多节点安装 确保从正确的区域下载 KubeKey export KKZONEcn下载 KubeKey curl -sfL https://get-kk.kubesphere.io | VERSIONv3.0.7 sh -为 kk 添加可执行权限: chmod x kk创建 config 文件 KubeSphere 版本:v3.3 支持的 Kuber…

Linux 安全技术和防火墙

目录 1 安全技术 2 防火墙 2.1 防火墙的分类 2.1.1 包过滤防火墙 2.1.2 应用层防火墙 3 Linux 防火墙的基本认识 3.1 iptables & netfilter 3.2 四表五链 4 iptables 4.2 数据包的常见控制类型 4.3 实际操作 4.3.1 加新的防火墙规则 4.3.2 查看规则表 4.3.…

企事业数字培训及知识库平台

前言 随着信息化的进一步推进,目前各行各业都在进行数字化转型,本人从事过医疗、政务等系统的研发,和客户深入交流过日常办公中“知识”的重要性,再加上现在倡导的互联互通、数据安全、无纸化办公等概念,所以无论是企业…

打家劫舍 II——力扣213

动规 int robrange(vector<int>& nums, int start, int end){int first=nums[start]

CountDownLatch和CyclicBarrie

前置提要 什么是闭锁对象 闭锁对象&#xff08;Latch Object&#xff09;是一种同步工具&#xff0c;用于控制线程的等待和执行顺序。闭锁对象可以让一个或多个线程等待&#xff0c;直到特定的条件满足后才能继续执行。 在Java中&#xff0c;CountDownLatch就是一种常见的闭锁对…

STC15单片机PM2.5空气质量检测仪

一、系统方案 本设计采用STC15单片机作为主控制器&#xff0c;PM2.5传感器、按键设置&#xff0c;液晶1602显示&#xff0c;蜂鸣器报警。 二、硬件设计 原理图如下&#xff1a; 三、单片机软件设计 1、首先是系统初始化&#xff1a; void lcd_init()//液晶初始化设置 { de…

SQLite数据库实现数据增删改查

当前文章介绍的设计的主要功能是利用 SQLite 数据库实现宠物投喂器上传数据的存储&#xff0c;并且支持数据的增删改查操作。其中&#xff0c;宠物投喂器上传的数据包括投喂间隔时间、水温、剩余重量等参数。 实现功能&#xff1a; 创建 SQLite 数据库表&#xff0c;用于存储宠…

第一讲:BeanFactory和ApplicationContext接口

BeanFactory和ApplicationContext接口 1. 什么是BeanFactory?2. BeanFactory能做什么&#xff1f;3.ApplicationContext对比BeanFactory的额外功能?3.1 MessageSource3.2 ResourcePatternResolver3.3 EnvironmentCapable3.4 ApplicationEventPublisher 4.总结 1. 什么是BeanF…