实战:循环神经网络与文本内容情感分类

在传统的神经网络模型中,是从输入层到隐含层再到输出层,层与层之间是全连接的,每层之间的节点是无连接的。但是这种普通的神经网络对于很多问题却无能为力。例如,你要预测句子的下一个单词是什么,一般需要用到前面的单词,因为一个句子中前后单词并不是独立的,即一个序列当前的输出与前面的输出也有关。

具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再是无连接的,而是有连接的,并且隐藏层的输入不仅包括输入层的输出,还包括上一时刻隐藏层的输出,这种传统的神经网络模型如图8-2所示。

8.2.1  基于循环神经网络的中文情感分类准备工作

在讲解循环神经网络的理论知识之前,最好的学习方式就是通过实例实现并运行对应的项目,本小节将带领读者完成一下循环神经网络的情感分类实战的准备工作。

1. 数据的准备

首先是数据集的准备工作。在本节中,我们需要完成的是中文数据集的情感分类,因此事先准备了一套已完成情感分类的数据集,读者可以参考本书配套代码中dataset目录下的chnSenticrop.txt文件确认一下。此时我们需要完成数据的读取和准备工作,其实现代码如下:

max_length = 80         #设置获取的文本长度为80
labels = []             #用以存放label
context = []            #用以存放汉字文本
vocab = set()           
with open("../dataset/cn/ChnSentiCorp.txt", mode="r", encoding="UTF-8") as emotion_file:for line in emotion_file.readlines():line = line.strip().split(",")# labels.append(int(line[0]))if int(line[0]) == 0:labels.append(0)    #由于在后面直接采用PyTorch自带的crossentroy函数,因此这里直接输入0,否则输入[1,0]else:labels.append(1)text = "".join(line[1:])context.append(text)for char in text: vocab.add(char)   #建立vocab和vocab编号voacb_list = list(sorted(vocab))
# print(len(voacb_list))
token_list = []
#下面是对context内容根据vocab进行token处理
for text in context:token = [voacb_list.index(char) for char in text]token = token[:max_length] + [0] * (max_length - len(token))token_list.append(token)

2. 模型的建立

接下来可以根据需求建立模型。在这里我们实现了一个带有单向GRU和一个双向GRU的循环神经网络,代码如下:

class RNNModel(torch.nn.Module):def __init__(self,vocab_size = 128):super().__init__()self.embedding_table = torch.nn.Embedding(vocab_size,embedding_dim=312)self.gru  =  torch.nn.GRU(312,256)  # 注意这里输出有两个:out与hidden,out是序列在模型运行后全部隐藏层的状态,而hidden是最后一个隐藏层的状态self.batch_norm = torch.nn.LayerNorm(256,256)self.gru2  =  torch.nn.GRU(256,128,bidirectional=True)  # 注意这里输出有两个:out与hidden,out是序列在模型运行后全部隐藏层的状态,而hidden是最后一个隐藏层的状态def forward(self,token):token_inputs = tokenembedding = self.embedding_table(token_inputs)gru_out,_ = self.gru(embedding)embedding = self.batch_norm(gru_out)out,hidden = self.gru2(embedding)return out

这里要注意的是,对于GRU进行神经网络训练,无论是单向还是双向GUR,其结果输出都是两个隐藏层状态,即out与hidden。这里的out是序列在模型运行后全部隐藏层的状态,而hidden是此序列最后一个隐藏层的状态。

在这里我们使用的是2层GRU,有读者会注意到,在我们对第二个GRU进行定义时,使用了一个额外的参数bidirectional,这个参数用来定义循环神经网络是单向计算还是双向计算的,其具体形式如图8-3所示。

从图8-3中可以很明显地看到,左右两个连续的模块并联构成了不同方向的循环神经网络单向计算层,而这两个方向同时作用后生成了最终的隐藏层。

8.2.2  基于循环神经网络的中文情感分类

上一小节完成了循环神经网络的数据准备以及模型的建立,下面我们可以对中文数据集进行情感分类,完整的代码如下:

import numpy as npmax_length = 80         #设置获取的文本长度为80
labels = []             #用以存放label
context = []            #用以存放汉字文本
vocab = set()           with open("../dataset/cn/ChnSentiCorp.txt", mode="r", encoding="UTF-8") as emotion_file:for line in emotion_file.readlines():line = line.strip().split(",")# labels.append(int(line[0]))if int(line[0]) == 0:labels.append(0)    #由于在后面直接采用PyTorch自带的crossentroy函数,因此这里直接输入0,否则输入[1,0]else:labels.append(1)text = "".join(line[1:])context.append(text)for char in text: vocab.add(char)   #建立vocab和vocab编号voacb_list = list(sorted(vocab))
# print(len(voacb_list))
token_list = []
#下面的内容是对context根据vocab进行token处理
for text in context:token = [voacb_list.index(char) for char in text]token = token[:max_length] + [0] * (max_length - len(token))token_list.append(token)seed = 17
np.random.seed(seed);np.random.shuffle(token_list)
np.random.seed(seed);np.random.shuffle(labels)dev_list = np.array(token_list[:170])
dev_labels = np.array(labels[:170])token_list = np.array(token_list[170:])
labels = np.array(labels[170:])import torch
class RNNModel(torch.nn.Module):def __init__(self,vocab_size = 128):super().__init__()self.embedding_table = torch.nn.Embedding(vocab_size,embedding_dim=312)self.gru  =  torch.nn.GRU(312,256)  # 注意这里输出有两个:out与hidden,out是序列在模型运行后全部隐藏层的状态,而hidden是最后一个隐藏层的状态self.batch_norm = torch.nn.LayerNorm(256,256)self.gru2  =  torch.nn.GRU(256,128,bidirectional=True)  # 注意这里输出有两个:out与hidden,out是序列在模型运行后全部隐藏层的状态,而hidden是最后一个隐藏层的状态def forward(self,token):token_inputs = tokenembedding = self.embedding_table(token_inputs)gru_out,_ = self.gru(embedding)embedding = self.batch_norm(gru_out)out,hidden = self.gru2(embedding)return out#这里使用顺序模型的方式建立了训练模型
def get_model(vocab_size = len(voacb_list),max_length = max_length):model = torch.nn.Sequential(RNNModel(vocab_size),torch.nn.Flatten(),torch.nn.Linear(2 * max_length * 128,2))return modeldevice = "cuda"
model = get_model().to(device)
model = torch.compile(model)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)loss_func = torch.nn.CrossEntropyLoss()batch_size = 128
train_length = len(labels)
for epoch in (range(21)):train_num = train_length // batch_sizetrain_loss, train_correct = 0, 0for i in (range(train_num)):start = i * batch_sizeend = (i + 1) * batch_sizebatch_input_ids = torch.tensor(token_list[start:end]).to(device)batch_labels = torch.tensor(labels[start:end]).to(device)pred = model(batch_input_ids)loss = loss_func(pred, batch_labels.type(torch.uint8))optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_correct += ((torch.argmax(pred, dim=-1) == (batch_labels)).type(torch.float).sum().item() / len(batch_labels))train_loss /= train_numtrain_correct /= train_numprint("train_loss:", train_loss, "train_correct:", train_correct)test_pred = model(torch.tensor(dev_list).to(device))correct = (torch.argmax(test_pred, dim=-1) == (torch.tensor(dev_labels).to(device))).type(torch.float).sum().item() / len(test_pred)print("test_acc:",correct)print("-------------------")

在上面代码中,我们顺序建立循环神经网络模型,在使用GUR对数据进行计算后,又使用Flatten对序列embedding进行平整化处理;而最后的Linear是分类器,作用是对结果进行分类。具体结果请读者自行测试查看。

本文节选自《PyTorch语音识别实战》,获出版社和作者授权发布。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/728724.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

已解决CircuitBreakerOpenException断路器打开异常的正确解决方法,亲测有效!!!

已解决CircuitBreakerOpenException断路器打开异常的正确解决方法,亲测有效!!! 目录 问题分析 报错原因 解决思路 解决方法 总结 问题分析 在微服务架构中,为了提高系统的可用性和稳定性,通常会使用…

【前端】尚硅谷Promise

文章目录 【前端目录贴】

Python3统计json格式文件中各个key对应值出现的频次

需求: 现在有一个文件query_demo.json,里面的数据格式示例如下,query字段代表的是用户的提问,现在想统计所有不同query出现的频次。 [{"query": "会议开始提醒弹窗如何开启","intention": "…

Elemenu中el-table中使用el-popover选中关闭无效解决办法

主要是技术太菜,没找到原因,一点点才找到这个办法解决 因为在el-table-column里,因为是多行,使用trigger"manual" 时,用v-model"visible"来控制时,控件找不到这个值,才换成trigger"click" 先找到弹出关闭事件,再找元素的属性 右键>审核元素…

JS算法总结

1 找出字符串的可整除数组 一个整数可表示为 a10b; 常用取模公式:(a10b)%m(a%m10b)%m 由上式可得:当前字符串所代表的数字,能否被 m 整除。等价于前一个字符串与 m 的余数当前数字的和,这个和能否被 m 整除。 例如&…

关于比特币的AI对话

【ChatGPT】 比特币源码开源吗? 是的,比特币的源码是开源的。比特币项目是在MIT许可证下发布的,这意味着任何人都可以查看、修改、贡献和分发代码。比特币的源码托管在GitHub上,可以通过下面的链接进行访问: https://g…

【深度优先】【图论】【C++算法】2045. 到达目的地的第二短时间

作者推荐 视频算法专题 LeetCode2045. 到达目的地的第二短时间 城市用一个 双向连通 图表示,图中有 n 个节点,从 1 到 n 编号(包含 1 和 n)。图中的边用一个二维整数数组 edges 表示,其中每个 edges[i] [ui, vi] 表…

EVE-NG桥接虚拟网卡实现与虚拟机通讯

一、知识补充 1、VMware网络连接 在VM中,给我们提供了以下几种连接网络的模式 桥接模式:直接联机物理网络NAT模式:用于共享主机的IP地址仅主机模式:与主机共享的专用网络自定义:特定虚拟网络LAN区段 特别注意的是&am…

【计算机系统】2.进程管理

【计算机系统】2.进程管理 这个章节十分的重要,作业也要好好做,因为我学的是后端,学计算机进程的处理对于搞并发来说十分有用。 提出问题 6、试从动态性、并发性和独立性上比较进程和程序。19、为什么要在OS中引入线程?A.请用信号量解决以下…

Unity3D 实现大世界地图的技术原理详解

前言 Unity3D是一款非常强大的游戏引擎,可以用于创建各种类型的游戏,包括大世界地图。在这篇文章中,我们将详细介绍如何使用Unity3D实现大世界地图,并给出相应的技术原理和代码实现。 对惹,这里有一个游戏开发交流小…

代码随想录Day23 | Leetcode93 复原 IP 地址、Leetcode78 子集 | Leetcode90 子集II

上题 93. 复原 IP 地址 - 力扣(LeetCode) 78. 子集 - 力扣(LeetCode) 90. 子集 II - 力扣(LeetCode) 第一题 有效 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不…

Linux每日练习

第一部分 1.打开桌面的主文件夹,在图片文件夹下新建一个名为111的文件夹,在视频文件夹下创建一个名为222的文件夹 [rootxcz7 desk]# mkdir -p ./pic/111 [rootxcz7 desk]# mkdir -p ./video/2222.在桌面打开终端,先切换到根目录下&#xff…

NineData与OceanBase完成产品兼容认证,共筑企业级数据库新生态

近日,云原生智能数据管理平台 NineData 和北京奥星贝斯科技有限公司的 OceanBase 数据库完成产品兼容互认证。经过严格的联合测试,双方软件完全相互兼容、功能完善、整体运行稳定且性能表现优异。 此次 NineData 与 OceanBase 完成产品兼容认证&#xf…

【你也能从零基础学会网站开发】Web建站之HTML+CSS入门篇 传统布局和Web标准布局的区别

🚀 个人主页 极客小俊 ✍🏻 作者简介:web开发者、设计师、技术分享 🐋 希望大家多多支持, 我们一起学习和进步! 🏅 欢迎评论 ❤️点赞💬评论 📂收藏 📂加关注 传统布局与…

【机器学习】包裹式特征选择之基于遗传算法的特征选择

🎈个人主页:豌豆射手^ 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:机器学习 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进…

微信小程序开发系列(二十二)·wxml语法·双向数据绑定model:的用法

目录 1. 单向数据绑定 2. 双向数据绑定 3. 代码 在 WXML 中&#xff0c;普通属性的绑定是单向的&#xff0c;例如&#xff1a;<input value"((value))"/> 如果希望用户输入数据的同时改变 data 中的数据&#xff0c;可以借助简易双向绑定机制。在对应属性…

STM32day2

1.思维导图 个人暂时的学后感&#xff0c;不一定对&#xff0c;没什么东西&#xff0c;为做项目奔波中。。。1.使用ADC采样光敏电阻数值&#xff0c;如何根据这个数值调节LED灯亮度。 while (1){/* USER CODE END WHILE *//* USER CODE BEGIN 3 */adc_val HAL_ADC_GetValue(&a…

hive 数据库用户权限授权

CREATE ROLE cz20240304; GRANT cz20240304_role TO USER cz20240304; grant select on table secured_t to role cz20240304_role;hive用户角色授权官网超链接

开源分子对接程序rDock使用方法(1)-Docking in 3 steps

欢迎浏览我的CSND博客&#xff01; Blockbuater_drug …点击进入 文章目录 前言一、Docking in 3 steps 标准对接rDock 的基本对接步骤及注意事项 二、 三步对接案例Step 1. 结构文件准备Step 2. 产生对接位点Step 3. 运行分子对接3.1 检查输入文件3.2 测试-只进行打分3.3 运行…

【数据结构】二、线性表:6.顺序表和链表的对比不同(从数据结构三要素讨论:逻辑结构、物理结构(存储结构)、数据运算(基本操作))

文章目录 6.对比&#xff1a;顺序表&链表6.1逻辑结构6.2物理结构&#xff08;存储结构&#xff09;6.2.1顺序表6.2.2链表 6.3数据运算&#xff08;基本操作&#xff09;6.3.1初始化6.3.2销毁表6.3.3插入、删除6.3.4查找 6.对比&#xff1a;顺序表&链表 6.1逻辑结构 顺…