实战:循环神经网络与文本内容情感分类

在传统的神经网络模型中,是从输入层到隐含层再到输出层,层与层之间是全连接的,每层之间的节点是无连接的。但是这种普通的神经网络对于很多问题却无能为力。例如,你要预测句子的下一个单词是什么,一般需要用到前面的单词,因为一个句子中前后单词并不是独立的,即一个序列当前的输出与前面的输出也有关。

具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再是无连接的,而是有连接的,并且隐藏层的输入不仅包括输入层的输出,还包括上一时刻隐藏层的输出,这种传统的神经网络模型如图8-2所示。

8.2.1  基于循环神经网络的中文情感分类准备工作

在讲解循环神经网络的理论知识之前,最好的学习方式就是通过实例实现并运行对应的项目,本小节将带领读者完成一下循环神经网络的情感分类实战的准备工作。

1. 数据的准备

首先是数据集的准备工作。在本节中,我们需要完成的是中文数据集的情感分类,因此事先准备了一套已完成情感分类的数据集,读者可以参考本书配套代码中dataset目录下的chnSenticrop.txt文件确认一下。此时我们需要完成数据的读取和准备工作,其实现代码如下:

max_length = 80         #设置获取的文本长度为80
labels = []             #用以存放label
context = []            #用以存放汉字文本
vocab = set()           
with open("../dataset/cn/ChnSentiCorp.txt", mode="r", encoding="UTF-8") as emotion_file:for line in emotion_file.readlines():line = line.strip().split(",")# labels.append(int(line[0]))if int(line[0]) == 0:labels.append(0)    #由于在后面直接采用PyTorch自带的crossentroy函数,因此这里直接输入0,否则输入[1,0]else:labels.append(1)text = "".join(line[1:])context.append(text)for char in text: vocab.add(char)   #建立vocab和vocab编号voacb_list = list(sorted(vocab))
# print(len(voacb_list))
token_list = []
#下面是对context内容根据vocab进行token处理
for text in context:token = [voacb_list.index(char) for char in text]token = token[:max_length] + [0] * (max_length - len(token))token_list.append(token)

2. 模型的建立

接下来可以根据需求建立模型。在这里我们实现了一个带有单向GRU和一个双向GRU的循环神经网络,代码如下:

class RNNModel(torch.nn.Module):def __init__(self,vocab_size = 128):super().__init__()self.embedding_table = torch.nn.Embedding(vocab_size,embedding_dim=312)self.gru  =  torch.nn.GRU(312,256)  # 注意这里输出有两个:out与hidden,out是序列在模型运行后全部隐藏层的状态,而hidden是最后一个隐藏层的状态self.batch_norm = torch.nn.LayerNorm(256,256)self.gru2  =  torch.nn.GRU(256,128,bidirectional=True)  # 注意这里输出有两个:out与hidden,out是序列在模型运行后全部隐藏层的状态,而hidden是最后一个隐藏层的状态def forward(self,token):token_inputs = tokenembedding = self.embedding_table(token_inputs)gru_out,_ = self.gru(embedding)embedding = self.batch_norm(gru_out)out,hidden = self.gru2(embedding)return out

这里要注意的是,对于GRU进行神经网络训练,无论是单向还是双向GUR,其结果输出都是两个隐藏层状态,即out与hidden。这里的out是序列在模型运行后全部隐藏层的状态,而hidden是此序列最后一个隐藏层的状态。

在这里我们使用的是2层GRU,有读者会注意到,在我们对第二个GRU进行定义时,使用了一个额外的参数bidirectional,这个参数用来定义循环神经网络是单向计算还是双向计算的,其具体形式如图8-3所示。

从图8-3中可以很明显地看到,左右两个连续的模块并联构成了不同方向的循环神经网络单向计算层,而这两个方向同时作用后生成了最终的隐藏层。

8.2.2  基于循环神经网络的中文情感分类

上一小节完成了循环神经网络的数据准备以及模型的建立,下面我们可以对中文数据集进行情感分类,完整的代码如下:

import numpy as npmax_length = 80         #设置获取的文本长度为80
labels = []             #用以存放label
context = []            #用以存放汉字文本
vocab = set()           with open("../dataset/cn/ChnSentiCorp.txt", mode="r", encoding="UTF-8") as emotion_file:for line in emotion_file.readlines():line = line.strip().split(",")# labels.append(int(line[0]))if int(line[0]) == 0:labels.append(0)    #由于在后面直接采用PyTorch自带的crossentroy函数,因此这里直接输入0,否则输入[1,0]else:labels.append(1)text = "".join(line[1:])context.append(text)for char in text: vocab.add(char)   #建立vocab和vocab编号voacb_list = list(sorted(vocab))
# print(len(voacb_list))
token_list = []
#下面的内容是对context根据vocab进行token处理
for text in context:token = [voacb_list.index(char) for char in text]token = token[:max_length] + [0] * (max_length - len(token))token_list.append(token)seed = 17
np.random.seed(seed);np.random.shuffle(token_list)
np.random.seed(seed);np.random.shuffle(labels)dev_list = np.array(token_list[:170])
dev_labels = np.array(labels[:170])token_list = np.array(token_list[170:])
labels = np.array(labels[170:])import torch
class RNNModel(torch.nn.Module):def __init__(self,vocab_size = 128):super().__init__()self.embedding_table = torch.nn.Embedding(vocab_size,embedding_dim=312)self.gru  =  torch.nn.GRU(312,256)  # 注意这里输出有两个:out与hidden,out是序列在模型运行后全部隐藏层的状态,而hidden是最后一个隐藏层的状态self.batch_norm = torch.nn.LayerNorm(256,256)self.gru2  =  torch.nn.GRU(256,128,bidirectional=True)  # 注意这里输出有两个:out与hidden,out是序列在模型运行后全部隐藏层的状态,而hidden是最后一个隐藏层的状态def forward(self,token):token_inputs = tokenembedding = self.embedding_table(token_inputs)gru_out,_ = self.gru(embedding)embedding = self.batch_norm(gru_out)out,hidden = self.gru2(embedding)return out#这里使用顺序模型的方式建立了训练模型
def get_model(vocab_size = len(voacb_list),max_length = max_length):model = torch.nn.Sequential(RNNModel(vocab_size),torch.nn.Flatten(),torch.nn.Linear(2 * max_length * 128,2))return modeldevice = "cuda"
model = get_model().to(device)
model = torch.compile(model)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)loss_func = torch.nn.CrossEntropyLoss()batch_size = 128
train_length = len(labels)
for epoch in (range(21)):train_num = train_length // batch_sizetrain_loss, train_correct = 0, 0for i in (range(train_num)):start = i * batch_sizeend = (i + 1) * batch_sizebatch_input_ids = torch.tensor(token_list[start:end]).to(device)batch_labels = torch.tensor(labels[start:end]).to(device)pred = model(batch_input_ids)loss = loss_func(pred, batch_labels.type(torch.uint8))optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_correct += ((torch.argmax(pred, dim=-1) == (batch_labels)).type(torch.float).sum().item() / len(batch_labels))train_loss /= train_numtrain_correct /= train_numprint("train_loss:", train_loss, "train_correct:", train_correct)test_pred = model(torch.tensor(dev_list).to(device))correct = (torch.argmax(test_pred, dim=-1) == (torch.tensor(dev_labels).to(device))).type(torch.float).sum().item() / len(test_pred)print("test_acc:",correct)print("-------------------")

在上面代码中,我们顺序建立循环神经网络模型,在使用GUR对数据进行计算后,又使用Flatten对序列embedding进行平整化处理;而最后的Linear是分类器,作用是对结果进行分类。具体结果请读者自行测试查看。

本文节选自《PyTorch语音识别实战》,获出版社和作者授权发布。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/728724.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Elemenu中el-table中使用el-popover选中关闭无效解决办法

主要是技术太菜,没找到原因,一点点才找到这个办法解决 因为在el-table-column里,因为是多行,使用trigger"manual" 时,用v-model"visible"来控制时,控件找不到这个值,才换成trigger"click" 先找到弹出关闭事件,再找元素的属性 右键>审核元素…

关于比特币的AI对话

【ChatGPT】 比特币源码开源吗? 是的,比特币的源码是开源的。比特币项目是在MIT许可证下发布的,这意味着任何人都可以查看、修改、贡献和分发代码。比特币的源码托管在GitHub上,可以通过下面的链接进行访问: https://g…

【深度优先】【图论】【C++算法】2045. 到达目的地的第二短时间

作者推荐 视频算法专题 LeetCode2045. 到达目的地的第二短时间 城市用一个 双向连通 图表示,图中有 n 个节点,从 1 到 n 编号(包含 1 和 n)。图中的边用一个二维整数数组 edges 表示,其中每个 edges[i] [ui, vi] 表…

EVE-NG桥接虚拟网卡实现与虚拟机通讯

一、知识补充 1、VMware网络连接 在VM中,给我们提供了以下几种连接网络的模式 桥接模式:直接联机物理网络NAT模式:用于共享主机的IP地址仅主机模式:与主机共享的专用网络自定义:特定虚拟网络LAN区段 特别注意的是&am…

【计算机系统】2.进程管理

【计算机系统】2.进程管理 这个章节十分的重要,作业也要好好做,因为我学的是后端,学计算机进程的处理对于搞并发来说十分有用。 提出问题 6、试从动态性、并发性和独立性上比较进程和程序。19、为什么要在OS中引入线程?A.请用信号量解决以下…

NineData与OceanBase完成产品兼容认证,共筑企业级数据库新生态

近日,云原生智能数据管理平台 NineData 和北京奥星贝斯科技有限公司的 OceanBase 数据库完成产品兼容互认证。经过严格的联合测试,双方软件完全相互兼容、功能完善、整体运行稳定且性能表现优异。 此次 NineData 与 OceanBase 完成产品兼容认证&#xf…

【你也能从零基础学会网站开发】Web建站之HTML+CSS入门篇 传统布局和Web标准布局的区别

🚀 个人主页 极客小俊 ✍🏻 作者简介:web开发者、设计师、技术分享 🐋 希望大家多多支持, 我们一起学习和进步! 🏅 欢迎评论 ❤️点赞💬评论 📂收藏 📂加关注 传统布局与…

【机器学习】包裹式特征选择之基于遗传算法的特征选择

🎈个人主页:豌豆射手^ 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:机器学习 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进…

微信小程序开发系列(二十二)·wxml语法·双向数据绑定model:的用法

目录 1. 单向数据绑定 2. 双向数据绑定 3. 代码 在 WXML 中&#xff0c;普通属性的绑定是单向的&#xff0c;例如&#xff1a;<input value"((value))"/> 如果希望用户输入数据的同时改变 data 中的数据&#xff0c;可以借助简易双向绑定机制。在对应属性…

STM32day2

1.思维导图 个人暂时的学后感&#xff0c;不一定对&#xff0c;没什么东西&#xff0c;为做项目奔波中。。。1.使用ADC采样光敏电阻数值&#xff0c;如何根据这个数值调节LED灯亮度。 while (1){/* USER CODE END WHILE *//* USER CODE BEGIN 3 */adc_val HAL_ADC_GetValue(&a…

开源分子对接程序rDock使用方法(1)-Docking in 3 steps

欢迎浏览我的CSND博客&#xff01; Blockbuater_drug …点击进入 文章目录 前言一、Docking in 3 steps 标准对接rDock 的基本对接步骤及注意事项 二、 三步对接案例Step 1. 结构文件准备Step 2. 产生对接位点Step 3. 运行分子对接3.1 检查输入文件3.2 测试-只进行打分3.3 运行…

【数据结构】二、线性表:6.顺序表和链表的对比不同(从数据结构三要素讨论:逻辑结构、物理结构(存储结构)、数据运算(基本操作))

文章目录 6.对比&#xff1a;顺序表&链表6.1逻辑结构6.2物理结构&#xff08;存储结构&#xff09;6.2.1顺序表6.2.2链表 6.3数据运算&#xff08;基本操作&#xff09;6.3.1初始化6.3.2销毁表6.3.3插入、删除6.3.4查找 6.对比&#xff1a;顺序表&链表 6.1逻辑结构 顺…

【短时交通流量预测】基于小波神经网络WNN

课题名称&#xff1a;基于小波神经网络的短时交通流量预测 版本时间&#xff1a;2023-04-27 代码获取方式&#xff1a;QQ&#xff1a;491052175 或者 私聊博主获取 模型简介&#xff1a; 城市交通路网中交通路段上某时刻的交通流量与本路段前几个时段的交通流量有关&#x…

【嵌入式】嵌入式系统稳定性建设:静态代码扫描的稳定性提升术

1. 概述 在嵌入式系统开发过程中&#xff0c;代码的稳定性和可靠性至关重要。静态代码扫描工具作为一种自动化的代码质量检查手段&#xff0c;能够帮助开发者在编译前发现潜在的缺陷和错误&#xff0c;从而增强系统的稳定性。本文将介绍如何在嵌入式C/C开发中使用静态代码扫描…

排序算法——梳理总结

✨冒泡 ✨选择 ✨插入  ✨标准写法  &#x1f3ad;不同写法 ✨希尔排序——标准写法 ✨快排 ✨归并 ✨堆排 ✨冒泡 void Bubble(vector<int>& nums) {// 冒泡排序只能先确定最右边的结果&#xff0c;不能先确定最左边的结果for (int i 0; i < nums.size(); i){…

基于深度学习的交通标志检测识别系统(含UI界面、yolov8、Python代码、数据集)

项目介绍 项目中所用到的算法模型和数据集等信息如下&#xff1a; 算法模型&#xff1a;     yolov8 yolov8主要包含以下几种创新&#xff1a;         1. 添加注意力机制&#xff08;SE、CBAM等&#xff09;         2. 修改可变形卷积&#xff08;DySnake-主干c…

linux系统命令深入研究1——ls的参数

ls list命令有一些常用的参数&#xff0c;其中-a意为列出all全部文件&#xff08;包括隐藏文件&#xff09;&#xff0c;-l列出详细信息&#xff0c;-h以人类可阅读的方式列出文件大小 --full-time是列出详细时间信息&#xff0c;包括最后一次修改时间 -t是按时间排序&#xff…

Git 内幕探索:从底层文件系统到历史编辑的全面指南

微信搜索“好朋友乐平”关注公众号。 1. Git 底层文件对象 #mermaid-svg-uTkvyr26fNmajZ3n {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-uTkvyr26fNmajZ3n .error-icon{fill:#552222;}#mermaid-svg-uTkvyr26fNmaj…

Spark实战-基于Spark日志清洗与数据统计以及Zeppelin使用

Saprk-日志实战 一、用户行为日志 1.概念 用户每次访问网站时所有的行为日志(访问、浏览、搜索、点击)用户行为轨迹&#xff0c;流量日志2.原因 分析日志&#xff1a;网站页面访问量网站的粘性推荐3.生产渠道 (1)Nginx(2)Ajax4.日志内容 日志数据内容&#xff1a;1.访问的…

【动态规划】完全背包

欢迎来到Cefler的博客&#x1f601; &#x1f54c;博客主页&#xff1a;折纸花满衣 &#x1f3e0;个人专栏&#xff1a;题目解析 &#x1f30e;推荐文章&#xff1a;【LeetCode】winter vacation training 目录 &#x1f449;&#x1f3fb;完全背包 &#x1f449;&#x1f3fb;…