RNN实战

本主要是利用RNN做多分类任务,在熟悉RNN训练的过程中,我们可以理解
1)超参数 batch_size和pad_size对训练过程的影响。
2)文本处理过程中是如何将文本的文字表示转化为向量表示
3)RNN梯度消失和序列长度的关系
4)利用pytorch如何训练一个网络模型以及保存和加载
5)理解多分类任务中的混淆矩阵

数据集HUCNews中抽取了20万条新闻标题,文本长度在20到30之间。一共10个类别,每类2万条。
类别:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐。

数据集划分

数据集数据量
训练集18万
验证集1万
测试集1万

重要参数如下

self.dropout = 0.3  # 随机失活
self.num_epochs = 7  # epoch数
self.batch_size = 256  # batch size
self.pad_size = 7  # 每句话处理成的长度(短填长切)
self.learning_rate = 1e-3  # 学习率
self.hidden_size = 128  # rnn隐藏层
self.num_layers = 2  # rnn层数,注意RNN中的层数必须大于1,dropout才会生效

RNN.py 模型文件,主要是配置文件和RNN网络模型定义。

# coding: UTF-8
import torch
import torch.nn as nn
import numpy as npclass Config(object):"""配置参数"""def __init__(self, dataset, embedding):self.model_name = 'RNN'self.train_path = dataset + '/data/train.txt'  # 训练集self.dev_path = dataset + '/data/dev.txt'  # 验证集self.test_path = dataset + '/data/test.txt'  # 测试集self.class_list = [x.strip() for x in open(dataset + '/data/class.txt', encoding='utf-8').readlines()]  # 类别名单self.vocab_path = dataset + '/data/vocab.pkl'  # 词表self.save_path = dataset + '/saved_dict/' + self.model_name + 'ckpt'  # 模型训练结果self.log_path = dataset + '/log/' + self.model_nameself.embedding_pretrained = torch.tensor(np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32')) \if embedding != 'random' else None  # 预训练词向量self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 设备self.dropout = 0.3  # 随机失活self.require_improvement = 10000  # 若超过10000batch效果还没提升,则提前结束训练self.num_classes = len(self.class_list)  # 类别数self.n_vocab = 0  # 词表大小,在运行时赋值self.num_epochs = 7  # epoch数self.batch_size = 256  # batch sizeself.pad_size = 7  # 每句话处理成的长度(短填长切)self.learning_rate = 1e-3  # 学习率self.embed = self.embedding_pretrained.size(1) \if self.embedding_pretrained is not None else 300  # 字向量维度, 若使用了预训练词向量,则维度统一self.hidden_size = 128  # rnn隐藏层self.num_layers = 2  # rnn层数,注意RNN中的层数必须大于1,dropout才会生效class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.rnn = nn.RNN(config.embed, config.hidden_size, config.num_layers,batch_first=True, dropout=config.dropout)self.fc = nn.Linear(config.hidden_size, config.num_classes)def forward(self, x):# 将原始数据转化成密集向量表示 [batch_size, seq_len, embedding]out = self.embedding(x[0])out, hidden_ = self.rnn(out)# out[:, -1, :] seq_len最后时刻的输出等价 hidden_out = self.fc(out[:, -1, :])return out

run_rnn.py文件,主程序入口,指定运行参数以及文本加载过程,最后调用train_eval.py的train函数进行模型训练。

import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from utils import build_dataset, build_iterator, get_time_difparser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', default='RNN', type=str, required=True)
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()if __name__ == '__main__':dataset = 'THUCNews'  # 数据集# 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:randomembedding = 'embedding_SougouNews.npz'if args.embedding == 'random':embedding = 'random'model_name = args.modelx = import_module('models.' + model_name)config = x.Config(dataset, embedding)np.random.seed(1)torch.manual_seed(1)torch.cuda.manual_seed_all(1)torch.backends.cudnn.deterministic = Truestart_time = time.time()print("Loading data...")# args.word 分词方式, True是词级别,默认是Falsevocab, train_data, dev_data, test_data = build_dataset(config, args.word)# build_iterator返回格式 [([词/字在词典中的位置] ,label, len(word)), ...]train_iter = build_iterator(train_data, config)dev_iter = build_iterator(dev_data, config)test_iter = build_iterator(test_data, config)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)# len(vocab)="<PAD>", len(vocab) -1 ="<UNK>"config.n_vocab = len(vocab)model = x.Model(config).to(config.device)init_network(model)print(model.parameters)train(config, model, train_iter, dev_iter, test_iter)

train_eval.py 文件,主要对模型参数进行初始化,函数train主要是从自定义迭代器中加载数据进行训练。test函数是在模型训练完后对测试数据集进行测试。evaluate函数主要是在训练过程中对验证集数据进行验证。

# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt# 权重初始化,默认xavier
def init_network(model, method='xavier', exclude='embedding', seed=123):for name, w in model.named_parameters():if exclude not in name:if 'weight' in name:if method == 'xavier':nn.init.xavier_normal_(w)elif method == 'kaiming':nn.init.kaiming_normal_(w)else:nn.init.normal_(w)elif 'bias' in name:nn.init.constant_(w, 0)else:passdef train(config, model, train_iter, dev_iter, test_iter):loss_list = []start_time = time.time()model.train()optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)total_batch = 0  # 记录进行到多少batchdev_best_loss = float('inf')last_improve = 0  # 记录上次验证集loss下降的batch数flag = False  # 记录是否很久没有效果提升writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))# dev_acc_list = []# dev_loss_list = []for epoch in range(config.num_epochs):print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))for i, (trains, labels) in enumerate(train_iter):outputs = model(trains)# 打印tensor的所有数据# torch.set_printoptions(threshold=float('inf'))model.zero_grad()loss = F.cross_entropy(outputs, labels)loss_list.append(loss.detach().numpy())loss.backward()optimizer.step()if total_batch % 100 == 0:true = labels.data.cpu()# 取出每一行最大的那个概率的索引值predic = torch.max(outputs.data, 1)[1].cpu()train_acc = metrics.accuracy_score(true, predic)dev_acc, dev_loss = evaluate(config, model, dev_iter)# dev_acc_list.append(dev_acc)# dev_loss_list.append(dev_loss)if dev_loss < dev_best_loss:dev_best_loss = dev_losstorch.save(model.state_dict(), config.save_path)improve = '*'last_improve = total_batchelse:improve = ''time_dif = get_time_dif(start_time)msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))writer.add_scalar("loss/train", loss.item(), total_batch)writer.add_scalar("loss/dev", dev_loss, total_batch)writer.add_scalar("acc/train", train_acc, total_batch)writer.add_scalar("acc/dev", dev_acc, total_batch)model.train()total_batch += 1if total_batch - last_improve > config.require_improvement:# 验证集loss超过10000batch没下降,结束训练print("No optimization for a long time, auto-stopping...")flag = Truebreakif flag:breakwriter.close()size = len(loss_list)x_axis = [i for i in range(0, size)]plt.plot(x_axis, loss_list, color='red')plt.show()test(config, model, test_iter)def test(config, model, test_iter):model.load_state_dict(torch.load(config.save_path))model.eval()start_time = time.time()test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'print(msg.format(test_loss, test_acc))print("Precision, Recall and F1-Score...")print(test_report)print("Confusion Matrix...")print(test_confusion)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)def evaluate(config, model, data_iter, test=False):model.eval()loss_total = 0predict_all = np.array([], dtype=int)labels_all = np.array([], dtype=int)# 模型评估的时候无梯度模式with torch.no_grad():for texts, labels in data_iter:outputs = model(texts)loss = F.cross_entropy(outputs, labels)loss_total += losslabels = labels.data.cpu().numpy()predict = torch.max(outputs.data, 1)[1].cpu().numpy()labels_all = np.append(labels_all, labels)predict_all = np.append(predict_all, predict)acc = metrics.accuracy_score(labels_all, predict_all)if test:report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)confusion = metrics.confusion_matrix(labels_all, predict_all)return acc, loss_total / len(data_iter), report, confusion# 用于训练过程中的验证return acc, loss_total / len(data_iter)

model_test.py 是单个文本的推理文件。

utils.py定义了加载数据集函数load_dataset,自定义迭代器将数据转化为tensor格式便于输入到模型。

完整代码github地址

项目结构清晰以后我们主要要记录一下,RNN训练过程中遇到的一些问题,尽管现在已经不怎么使用RNN网络模型了,不过这不影响RNN在时序网络中的地位(LSTM 长短时记忆网络、GRU门控循环单元都是RNN的优化)我们还是有必要好好认识一下RNN的训练过程,以及超参数对损失值的影响。

我们主要参数设置如下,我们只对batch_size和pad_size进行修改看一下模型的损失下降曲线。

self.dropout = 0.3  # 随机失活
self.require_improvement = 10000  # 若超过10000batch效果还没提升,则提前结束训练
self.num_classes = len(self.class_list)  # 类别数
self.n_vocab = 0  # 词表大小,在运行时赋值
self.num_epochs = 7  # epoch数
self.batch_size = 64  # batch size
self.pad_size = 32  # 每句话处理成的长度(短填长切)
self.learning_rate = 1e-3  # 学习率
self.embed = self.embedding_pretrained.size(1) \
if self.embedding_pretrained is not None else 300  # 字向量维度
self.hidden_size = 128  # rnn隐藏层
self.num_layers = 2  # rnn层数,注意RNN中的层数必须大于1,dropout才会生效

batch_size = 64 pad_size = 32 learning_rate = 1e-3

训练过程
在这里插入图片描述

损失函数结果图,可以看出根本就不收敛,pad_size值过大,可能出现出现梯度消失,导致模型参数根本就不更新。

在这里插入图片描述
batch_size = 64 pad_size = 16 learning_rate = 1e-3

训练过程
在这里插入图片描述
从这里足以感性的理解为什么很多人说RNN携带的时序信息走不远,当我们将时序长度pad_size设置16时(其他参数不变)可以看到验证数据集的准确度和损失都还不错的,比pad_size=32要好很多,至少可以知道模型的参数是在更新,且损失值也有下降的趋势。
在这里插入图片描述
混淆矩阵也还可以。 混淆矩阵参考


以上是文本序列长度pad_size对RNN训练的影响。现在我们来看下batch_size大小对RNN训练的影响。为了让模型收敛pad_szie统一取16

batch_size = 128 pad_size = 16 learning_rate = 1e-3

训练过程
在这里插入图片描述

batch_size变大为128更新次数少,每一次迭代考虑的样本更多。每次迭代考虑的样本大了以后,梯度优化的波动变小,下降更平滑。相比batch_size=64,损失图像下下降确实更平滑。混淆矩阵无太大差异。
在这里插入图片描述
batch_size = 256 pad_size = 16 learning_rate = 1e-3

训练过程

在这里插入图片描述
batch_size=256损失值下降更平滑,收敛速度更快,batch_size=64时训练时长在18min左右,而此参数下训练时长仅要5min左右。
在这里插入图片描述
batch_size = 1024 pad_size = 16 learning_rate = 1e-3

训练过程


batch_size=1024时收敛速度更快,而此参数下训练时长仅要2min左右。
在这里插入图片描述

混淆矩阵,可以看出在显存足够大的情况下适当增大batch_size可以达到两点效果1)加快训练的收敛的速度 2)梯度优化的波动减小,收敛过程更加平滑。

在这里插入图片描述

至此我们已经完成了RNN训练中两个比较重要的超参数batch_size和pad_size对训练过程的影响。还有很多其他的超参数这里就不实验了。

pad_size由32变成16时候,显然只用到了一半的数据信息,无论怎么进行超参数的优化都不可能达到最好的结果。如果使用32又会出现梯度消失,从而模型不收敛。LSTM模型就有效的改进了这个缺陷。下一篇文章我们使用同样的超参数和数据集构造一个LSTM模型实验这个改进有多大。

参考
https://github.com/649453932/Chinese-Text-Classification-Pytorch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/738519.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Bugku---web---source

1.题目描述 2.点开链接&#xff0c;查看源码&#xff0c;发现了一个长得很像flag的flag&#xff0c;拿去base64解码&#xff0c;发现不是flag 3.没办法只能在kali里面扫描一下目录 4.发现是.git源码泄露&#xff0c;在浏览器尝试访问/.git,发现点开文件并不能看到源码 5.在kali…

06多表查询

多表查询 多表查询&#xff0c;也称为关联查询&#xff0c;指两个或更多个表一起完成查询操作。前提条件&#xff1a;这些一起查询的表之间是有关系的&#xff08;一对一、一对多&#xff09;&#xff0c;它们之间一定是有关联字段&#xff0c;这个 关联字段可能建立了外键&am…

【Hadoop大数据技术】——HDFS分布式文件系统(学习笔记)

&#x1f4d6; 前言&#xff1a;Hadoop的核心是HDFS&#xff08;Hadoop Distributed File System&#xff0c;Hadoop分布式文件系统&#xff09;和MapReduce。其中&#xff0c;HDFS是解决海量大数据文件存储的问题&#xff0c;是目前应用最广泛的分布式文件系统。 目录 &#x…

稀碎从零算法笔记Day15-LeetCode:判断子序列

跑样例的时候LC炸了&#xff0c;以为今天回断更 题型&#xff1a;字符串、双指针 链接&#xff1a;392. 判断子序列 - 力扣&#xff08;LeetCode&#xff09; 来源&#xff1a;LeetCode 题目描述&#xff08;此题建议结合样例理解&#xff09; 给定字符串 s 和 t &#xf…

冰蝎的原理与安装使用

冰蝎的原理与安装使用 1、冰蝎原理 1.1简介 冰蝎是一款基于Java开发的动态加密通信流量的新型Webshell客户端&#xff0c;由于通信流量被加密&#xff0c;传统的WAF、IDS 设备难以检测&#xff0c;给威胁狩猎带来较大挑战。冰蝎其最大特点就是对交互流量进行对称加密&#x…

JVM 面试——G1和ZGC的区别

ZGC是一款JDK 11中新加入的具有实验性质的低延迟垃圾收集器ZGC的目标主要有4个 支持TB量级的堆。我们生产环境的硬盘还没有上TB呢&#xff0c;这应该可以满足未来十年内&#xff0c;所有JAVA应用的需求了吧。最大GC停顿时间不超10ms。目前一般线上环境运行良好的JAVA应用Minor …

【前端寻宝之路】学习和使用CSS的所有选择器

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法|MySQL| ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-blSAMs8NTfBKaPl8 {font-family:"trebuchet ms",verdana,arial,sans-serif;f…

如何选择AI项目:从任务自动化到社会价值的全面考虑

目录 前言1 任务自动化的首要选择1.1 公司痛点分析&#xff1a;深入挖掘潜在问题1.2 数据集的收集与大小考虑&#xff1a;确保数据质量和规模匹配 2 AI项目的商业潜力2.1 技术考察与性能目标&#xff1a;确保技术选择符合项目需求2.2 商业考虑与成本效益分析&#xff1a;全面评…

作用域链的理解(超级详细)

文章目录 一、作用域全局作用域函数作用域块级作用域 二、词法作用域三、作用域链 一、作用域 作用域&#xff0c;即变量&#xff08;变量作用域又称上下文&#xff09;和函数生效&#xff08;能被访问&#xff09;的区域或集合 换句话说&#xff0c;作用域决定了代码区块中变…

Spring之注入模型

前言 之前我写过一篇关于BeanDefinition的文章,讲述了各个属性的作用,其中有一个属性我没有提到,因为这个属性比较重要,所以这里单独开一篇文章来说明 上一篇博文链接Spring之BeanDefinitionhttps://blog.csdn.net/qq_38257958/article/details/134823169?spm1001.2014.3001…

【Datawhale学习笔记】从大模型到AgentScope

从大模型到AgentScope AgentScope是一款全新的Multi-Agent框架&#xff0c;专为应用开发者打造&#xff0c;旨在提供高易用、高可靠的编程体验&#xff01; 高易用&#xff1a;AgentScope支持纯Python编程&#xff0c;提供多种语法工具实现灵活的应用流程编排&#xff0c;内置…

pc端vue2项目使用uniapp组件

项目示例下载 运行实例&#xff1a; 这是我在pc端做移动端底代码时的需求&#xff0c;只能在vue2使用&#xff0c;vue3暂时不知道怎么兼容。 安装依赖包时可能会报&#xff1a;npm install Failed to set up Chromium r756035! Set “PUPPETEER_SKIP_DOWNLOAD” env variable …

数据治理实践——金融行业大数据治理的方向与实践

目录 一、证券数据治理服务化背景 1.1 金融数据治理发展趋势 1.2 证券行业数据治理建设背景 1.3 证券行业数据治理目标 1.4 证券行业数据治理痛点 二、证券数据治理服务化实践 2.1 国信证券数据治理建设框架 2.2 国信证券数据治理建设思路 2.3 数据模型管理 2.4 数据…

ChatGPT+MATLAB应用

MatGPT是一个由chatGPT类支持的MATLAB应用程序&#xff0c;由官方Toshiaki Takeuchi开发&#xff0c;允许您轻松访问OpenAI提供的chatGPT API。作为官方发布的内容&#xff0c;可靠性较高&#xff0c;而且也是完全免费开源的&#xff0c;全程自己配置&#xff0c;无需注册码或用…

SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition

摘要 我们提出了SpecAugment&#xff0c;这是一种用于语音识别的简单数据增强方法。SpecAugment直接应用于神经网络的特征输入&#xff08;即滤波器组系数&#xff09;。增强策略包括对特征进行变形、遮蔽频道块和遮蔽时间步块。我们在端到端语音识别任务中将SpecAugment应用于…

【SQL】601. 体育馆的人流量(with as 临时表;id减去row_number()思路)

前述 知识点学习&#xff1a; with as 和临时表的使用12、关于临时表和with as子查询部分 题目描述 leetcode题目&#xff1a;601. 体育馆的人流量 思路 关键&#xff1a;如何确定id是连续的三行或更多行记录 方法一&#xff1a; 多次连表&#xff0c;筛选查询方法二&…

vulhub中Weblogic SSRF漏洞复现

Weblogic中存在一个SSRF漏洞&#xff0c;利用该漏洞可以发送任意HTTP请求&#xff0c;进而攻击内网中redis、fastcgi等脆弱组件。 访问http://your-ip:7001/uddiexplorer/&#xff0c;无需登录即可查看uddiexplorer应用。 SSRF漏洞测试 SSRF漏洞存在于http://your-ip:7001/ud…

Python分支结构

我们刚开始写的Python代码都是一条一条语句顺序执行&#xff0c;这种代码结构通常称之为顺序结构。 然而仅有顺序结构并不能解决所有的问题&#xff0c;比如我们设计一个游戏&#xff0c;游戏第一关的通关条件是玩家在一分钟内跑完全程&#xff0c;那么在完成本局游戏后&#x…

js实现导出/下载excel文件

js实现导出/下载excel文件 // response 为导出接口返回数据&#xff0c;如上图 const exportExcel (response, fileName:string) >{const blob new Blob([response.data], {type: response.headers[content-type] //使用获取的excel格式});const downloadElement documen…

mysql5.6---windows和linux安装教程和忘记密码怎么办

一、windows安装 1.完成解压 解压完成之后将其放到你喜欢的地址当中去&#xff0c;这里我默认放在了D盘&#xff0c;这是我的根目录 2.配置环境变量 我的电脑->属性->高级->环境变量->系统变量 选择PATH,在其后面添加: (注意自己的安装地址) D:\mysql-5.6.49…