Mindspore框架循环神经网络RNN模型实现情感分类|(二)RNN模型构建

Mindspore框架循环神经网络RNN模型实现情感分类

Mindspore框架循环神经网络RNN模型实现情感分类|(一)IMDB影评数据集准备
Mindspore框架循环神经网络RNN模型实现情感分类|(二)RNN模型构建
Mindspore框架循环神经网络RNN模型实现情感分类|(三)损失函数与优化器
Mindspore框架循环神经网络RNN模型实现情感分类|(四)模型训练与推理

tips:安装依赖库

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install tqdm requests

一、RNN模型构建

数据集准备完成了输入文本通过查字典(序列化)的向量化。并使用nn.Embedding层加载了Glove词向量。下一步将使用RNN循环神经网络做特征提取,最后将RNN连接至全连接网络nn.Dednse,将特征转化为分类。

nn.Embedding -> nn.RNN -> nn.Dense

本项目,采用规避RNN梯度消的变种LSTM(Long short-term memory)代替RNN做特征提取层。

1.1 关于RNN

循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的神经网络。下图为RNN的一般结构:

RNN-0

图示左侧为一个RNN Cell循环,右侧为RNN的链式连接平铺。实际上不管是单个RNN Cell还是一个RNN网络,都只有一个Cell的参数,在不断进行循环计算中更新。

由于RNN的循环特性,和自然语言文本的序列特性(句子是由单词组成的序列)十分匹配,因此被大量应用于自然语言处理研究中。下图为RNN的结构拆解:

RNN

1.2 关于LSTM(Long short-term memory)

RNN单个Cell的结构简单,因此也造成了梯度消失(Gradient Vanishing)问题,具体表现为RNN网络在序列较长时,在序列尾部已经基本丢失了序列首部的信息。为了克服这一问题,LSTM(Long short-term memory)被提出,通过门控机制(Gating Mechanism)来控制信息流在每个循环步中的留存和丢弃。下图为LSTM的结构拆解:

LSTM

本项目选择LSTM变种而不是经典的RNN做特征提取,可规避梯度消失问题,并获得更好的模型效果。
在MindSpore中nn.LSTM对应的公式:

h 0 : t , ( h t , c t ) = LSTM ( x 0 : t , ( h 0 , c 0 ) ) h_{0:t}, (h_t, c_t) = \text{LSTM}(x_{0:t}, (h_0, c_0)) h0:t,(ht,ct)=LSTM(x0:t,(h0,c0))

这里nn.LSTM隐藏了整个循环神经网络在序列时间步(Time step)上的循环,送入输入序列、初始状态,即可获得每个时间步的隐状态(hidden state`)拼接而成的矩阵,以及最后一个时间步对应的隐状态。我们使用最后的一个时间步的隐状态作为输入句子的编码特征,送入下一层

Time step:在循环神经网络计算的每一次循环,成为一个Time step。在送入文本序列时,一个Time step对应一个单词。因此在本例中,LSTM的输出 h 0 : t h_{0:t} h0:t对应每个单词的隐状态集合, h t h_t ht c t c_t ct对应最后一个单词对应的隐状态。

下一层:全连接层,即nn.Dense,将特征维度变换为二分类所需的维度1,经过Dense层后的输出即为模型预测结果。

1.3 特征提取网络构建

RNN循环神经网络: nn.LSTM()
初始化参数:

 embeddings:输入向量,hidden_dim:隐藏层特征的维度, output_dim:输出维数, n_layers:RNN 层的数量,bidirectional:是否为双向 RNN, pad_idx:padding_idx参数用于标记输入中的填充值(padding value)。在自然语言处理任务中,文本序列的长度不一致是非常常见的。为了能够对不同长度的文本序列进行批处理,我们通常会使用填充值对较短的序列进行填补。

tips:使用nn.embeddings()创建嵌入层时,可以通过padding_idx参数指定一个特定的索引,用于表示填充值。
embedding_layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0),将padding_idx设置为0,表示使用索引为0的词汇作为填充值。在文本序列中,我们将使用0来填充较短的序列。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniformclass RNN(nn.Cell):def __init__(self, embeddings, hidden_dim, output_dim, n_layers,bidirectional, pad_idx):super().__init__()vocab_size, embedding_dim = embeddings.shapeself.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim,hidden_dim,num_layers=n_layers,bidirectional=bidirectional,batch_first=True)weight_init = HeUniform(math.sqrt(5))bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)def construct(self, inputs):embedded = self.embedding(inputs)_, (hidden, _) = self.rnn(embedded)hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)output = self.fc(hidden)return output

实例化模型,打印输出

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
print(model)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/48723.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java代码判断windows还是linux系统

public String getFilePath() {String osName System.getProperty("os.name").toLowerCase();log.info("此系统名称为:{}", osName);// 判断操作系统if (osName.contains("windows")) {return "windows";} else {//linux地址return &…

『大模型笔记』LLM秘密:温度、Top-K和Top-P抽样技术解析!

『大模型笔记』LLM秘密:温度、Top-K和Top-P抽样技术解析! 文章目录 一. LLM秘密:温度、Top-K和Top-P随机采样技术解析!1. 温度(Temperature)参数2. Top-K采样3. Top-P采样4. 总结补充:TopK采样解释:步骤1: 确定Top-K词步骤2: 归一化选择的Top-K词的概率步骤3: 从重新归…

实测!高性能PCIe 5.0 SSD为AI训练贡献了啥?

如今&#xff0c;AI 的火热程度已经不需要解释。算力、算法、数据&#xff0c;构成驱动 AI 技术快速发展的三驾马车&#xff1a; 算力越强&#xff0c;越能够处理更加复杂的训练模型&#xff0c;并加速训练进程&#xff1b;算法越先进&#xff0c;越能高效的从数据中学习并优化…

element-plus的el-table自定义表头筛选查询

文章目录 一、效果二、代码1.代码可直接复制使用 三、问题1.使用el-popover完成筛选框 一、效果 二、代码 1.代码可直接复制使用 <template><div class"page-view" click"handleClickOutside"><el-button click"resetFilters"&…

云服务器2核2G配置可以干嘛?2C2G用来做什么合适?

最便宜的服务器2核2G配置可以做什么&#xff1f;可用来搭建Web网站服务器、小程序服务器、APP后端服务器、图床、开发测试环境、小型电子商务网站及文件存储等使用场景。服务器百科网fwqbk.com举例说明&#xff1a; 搭建个人博客&#xff1a;使用开源WordPress等博客程序系统&…

git使用、git与idea结合、gitee、gitlab

本文章基于黑马程序javase模块中的"git"部分 先言:git在集成idea中,不同版本的idea中页面显示不同,操作时更注重基于选项的文字;git基于命令操作参考文档实现即可,idea工具继承使用重点掌握 1.git概述 git是目前世界上最先进的分布式文件版本控制系统 分布式:将…

照片怎么改大小kb?分享5个改小图片的工具

夏日炎炎&#xff0c;正是出游拍照的大好时节&#xff0c;然而随之而来的问题也让人头疼——手机里的美照越来越多&#xff0c;存储空间却越来越紧张。 不仅如此&#xff0c;上传至社交媒体时&#xff0c;大尺寸的照片常常让加载速度慢如蜗牛&#xff0c;影响了分享的乐趣。 …

java设计模式:02-01-单例模式

单例模式&#xff08;Singleton Pattern&#xff09; 单例模式&#xff08;Singleton Pattern&#xff09;确保一个类只有一个实例&#xff0c;并提供一个全局访问点来访问该实例。单例模式通常用于需要全局唯一实例的场景&#xff0c;例如&#xff1a; 日志记录&#xff08;…

[Spring Boot]Protobuf解析MQTT消息体

简述 本文主要针对在MQTT场景下&#xff0c;使用Protobuf协议解析MQTT的消息体 Protobuf下载 官方下载 https://github.com/protocolbuffers/protobuf/releases网盘下载 链接&#xff1a;https://pan.baidu.com/s/1Uz7CZuOSwa8VCDl-6r2xzw?pwdanan 提取码&#xff1a;an…

【毕业论文】| 基于Unity3D引擎的冒险游戏的设计与实现

&#x1f4e2;博客主页&#xff1a;肩匣与橘 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01; &#x1f4e2;本文由肩匣与橘编写&#xff0c;首发于CSDN&#x1f649; &#x1f4e2;生活依旧是美好而又温柔的&#xff0c;你也…

mysql + Oracle

eg627. 变更性别 Salary 表&#xff1a; ----------------------- | Column Name | Type | ----------------------- | id | int | | name | varchar | | sex | ENUM | | salary | int | ----------------------- id 是这个表…

SetPriorityClass 函数潜在得32MB工作集内存限制

MSDN&#xff1a; SetPriorityClass function (processthreadsapi.h) - Win32 apps | Microsoft Learn 当人们调用 SetPriorityClass 函数并且将 “dwPriorityClass” 参数设置为&#xff1a;PROCESS_MODE_BACKGROUND_BEGIN 时。 应用程序得工作集&#xff08;Working Set&…

小程序-4(自定义组件:数据、属性、数据监听器、生命周期函数、插槽、父子通信、behaviors)

目录 1.组件的创建和引用 局部引用组件 全局引用组件 组件和页面的区别 组件样式隔离 ​编辑 组件样式隔离的注意点 修改组件的样式隔离选项 data数据 methods方法 properties属性 data和properties属性的区别 使用setData修改properties的值 2.数据监听器 什么…

昇思25天学习打卡营第19天|MindNLP ChatGLM-6B StreamChat

文章目录 昇思MindSpore应用实践ChatGML-6B简介基于MindNLP的ChatGLM-6B StreamChat Reference 昇思MindSpore应用实践 本系列文章主要用于记录昇思25天学习打卡营的学习心得。 ChatGML-6B简介 ChatGLM-6B 是由清华大学和智谱AI联合研发的产品&#xff0c;是一个开源的、支持…

数据结构——线性表(循环链表)

一、循环链表定义 将单链表中终端结点的指针端由空指针改为指向头结点&#xff0c;就使整个单链表形成一 个环&#xff0c;这种头尾相接的单链表称为单循环链表&#xff0c;简称循环链表(circular linked list)。 循环链表解决了一个很麻烦的问题。如何从当中一 个结点出发&am…

如何看待“微软蓝屏”冲上热搜,Windows系统电脑死机?

微软蓝屏事件席卷全球,Windows系统遭遇重大挑战 近日,一则关于微软蓝屏的新闻在各大社交媒体平台上迅速传开,并登上热搜榜首。全球多地报告称Windows系统电脑出现死机、蓝屏现象,引发了广泛关注。 一、事件概述 据报道,此次微软蓝屏事件始于当地时间2024年7月19日,美国…

vs code中多个c文件的编译、调试,对应的task.json/launch.json文件的设置

目录 0. 容易犯的错误 1. 编译&#xff1a;需要修改原始task.json文件&#xff0c;修改处用加粗表示 2. 调试&#xff1a;需修改launch.json文件 3. 完整配置流程参考 0. 容易犯的错误 注意&#xff0c;不能在同一个文件夹下写两个都包含main()函数的.c文件&#xff0c;这会…

IO多路复用-poll的使用详解【C语言】

1.什么是poll poll 是一种用于监控多个文件描述符状态的系统调用&#xff0c;它可以等待多个文件描述符上的事件发生。它与 select 和 epoll 类似&#xff0c;但在某些场景下使用更为方便。 poll的机制与select类似&#xff0c;与select在本质上没有多大差别&#xff0c;使用…

oops使用笔记

oops-plugin-excel-to-json 使用 gitee上的文档图片不可见 参考 > https://forum.cocos.org/t/topic/156800 配置&#xff0c;打开“项目设置”,拖动到最下面&#xff0c;有一个"Excel to Json",前3项采用默认配置吧Excel, 默认的Excel目录是与assets平级的excel目…

如何确定企业信息系统的安全保护等级

确定企业信息系统的安全保护等级通常需要考虑以下几个方面&#xff1a; 1. 业务信息安全 - 系统所处理的业务数据的重要性、敏感性和保密性。例如&#xff0c;涉及国家机密、商业机密、个人隐私的数据通常具有较高的安全要求。 - 数据遭到破坏、泄露或篡改后对企业、社会或国家…