文本分类-RNN-LSTM

1.前言

        本节介绍RNN和LSTM,并采用它们在电影评论数据集上实现文本分类,会涉及以下几个知识点。

        1. 词表构建:包括数据清洗,词频统计,词频截断,词表构建。

        2. 预训练词向量应用:下载并加载Glove的预训练embedding进行训练,主要是如何把词向量放到nn.embedding层中的权重。

        3. RNN及LSTM构建:涉及nn.RNN和nn.LSTM的使用。

2.任务介绍

        本节采用的数据集是斯坦福大学的大型电影评论数据集(large movie review dataset) https://ai.stanford.edu/~amaas/data/sentiment/

        包含25000个训练样本,25000个测试样本,下载解压后得到aclImdb文件夹,aclImdb下有train和test,neg和pos下分别 有txt文件,txt中为电影评论文本。

         来看看一条具体的样本,train/pos/3_10.txt:

        本节任务就是对这样的一条文本进行处理,输出积极/消极的二分类概率向量。

3.数据模块

        文本任务与图像任务不同,输入不再是像素这样的数值,而是字符串,因此需要将字符串转为矩阵运算可接受的向量形 式。

         为此需要在数据处理模块完成以下步骤:

        a.分词:将一长串文本切分为一个个独立语义的词,英文可用空格来切分。

        b. 词嵌入:词嵌入通常分两步。首先将词字符串转为索引序号,然后索引序号根据词嵌入矩阵(embedding层)取对应的向量。其中词与索引之间的映射关系需要提前构建,这就是词表构建的过程。

        因此,代码开发整体流程:

        1. 编写分词功能函数

        2. 构建词表:对训练数据进行分词,统计词频,并构建词表。例如{'UNK': 0, 'PAD': 1, 'the': 2, '.': 3, 'and': 4, 'a': 5, 'of': 6, 'to': 7, ...}

        3. 编写PyTorch的Dataset,实现分词、词转序号、长度填充/截断序号转词向量的过程由模型的nn.Embedding层实现,因此数据模块只需将词变为索引序号即可,接下来一一解析各环节核心功能代码实现。

        序号转词向量的过程由模型的nn.Embedding层实现,因此数据模块只需将词变为索引序号即可,接下来一一解析各环节核心功能代码实现。

4.词表构建

        参考配套代码a_gen_vocabulary.py,首先编写分词功能函数,分词前做一些简单的数据清洗,例如在标点符号前加入空 格、去除掉不是大小写字母及 .!? 符号的数据。

        接着,写一个词表统计类实现词频统计,和词表字典的创建,代码注释非常详细,这里不赘述。 运行代码,即可完成词频统计,词表的构建,并保存到本地npy文件,在训练及推理过程中使用。

        在词表构建过程中有一个截断数量的超参数需要设置,这里设置为20000,即最多有20000个词的表示,不在字典中的词被归为UNK这个词。

         在这个数据集中,原始词表长度为74952,即通过split切分后,有7万多个不一样的字符串,通常可以通过降序排列,取前面一部分即可。

        代码会输出词频统计图,也可以观察出词频下降的速度以及高频词是哪些。

5.Dataset编写

        参考配套代码aclImdb_dataset.py,getitem中主要做两件事,首先获取label,然后获取文本预处理后的列表,列表中元素是词所对应的索引序号。

        在self.word2index.encode中需要注意设置文本最大长度self.max_len,这是由于需要将所有文本处理到相同长度,长度不足的用词填充,长度超出则截断。

6.模型模块——RNN

        模型的构建相对简单,理论知识在这里不介绍,需要了解和温习的推荐看看《动手学》。这里借助动手学的RNN图片讲解代码的实现。

        在构建的模型RNNTextClassifier中,需要三个子module,分别是:

                1. nn.Embedding:将词序号变为词向量,用于后续矩阵运算

                2. nn.RNN:循环神经网络的实现

                3. nn.Linear:最终分类输出层的实现

        在forward时,流程如下:

                1. 获取词向量

                2. 构建初始化隐藏层,默认为全0

                3. rnn推理获得输出层和隐藏层

                4. fc层输出分类概率:fc层的输入是rnn最后一个隐藏层

        更多关于nn.RNN的参数设置,可以参考官方文档:

        torch.nn.RNN(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)

7.模型模块——LSTM

        RNN是神经网络中处理时序任务最为经典的设计,但是其也存在一些缺点,例如梯度消失和梯度爆炸,以及长期依赖问 题。

        当序列很长时,RNN模型很难捕捉到远距离的依赖关系,导致模型预测不准确。

        为此,带门控机制的RNN涌现,包括GRU(Gated Recurrent Unit,门控循环单元)和LSTM(Long Short-Term Memory,长短期记忆网络),其中LSTM应用最广,这里直接跳过GRU。         LSTM模型引入了三个门(input gate、forget gate和output gate),用于控制输入、输出和遗忘的流动,允许模型有选择性地忘记或记住一些信息。

        input gate用于控制输入的流动

        forget gate用于控制遗忘的流动

        output gate用于控制输出的流动

        相较于RNN,除了输出隐藏层向量h,还输出记忆层向量c,不过对于下游使用,不需要关心向量c的存在。 同样地,借助《动手学》中的LSTM示意图来理解代码。

        在这里,借鉴《动手学》的代码,采用的LSTM为双向LSTM,这里简单介绍双向循环神经网络的概念。

         双向循环神经网络(Bidirectional Recurrent Neural Network,Bi-RNN)同时考虑前向和后向的上下文信息,前向层和后向层的输出在每个时间步骤上都被连接起来,形成了一个综合的输出,这样可以更好地捕捉序列中的上下文信息。

        在pytorch代码中,只需要将bidirectional设置为True即可,

        nn.LSTM(embed_size, num_hiddens, num_layers=num_layers, bidirectional=True)。

        当采用双向时,需要注意output矩阵的shape为 [ sequence length , batch size ,2×hidden size]

        更多关于nn.LSTM的参数设置,可以参考官方文档:torch.nn.LSTM(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)

        详细参考:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

8.embedding预训练加载

        模型构建好之后,词向量的embedding层是随机初始化的,要从头训练具备一定逻辑关系的词向量表示是费时费力的, 通常可以采用在大规模预料上训练好的词向量矩阵。

        这里可以参考斯坦福大学的GloVe(Global Vectors for Word Representation)预训练词向量。

        GloVe是一种无监督学习算法,用于获取单词的向量表示,GloVe预训练词向量可以有效地捕捉单词之间的语义关系,被广泛应用于自然语言处理领域的各种任务,例如文本分类、命名实体识别和机器翻译等。

        Glove有四大类,根据数据量不同进行区分,相同数据下又根据向量长度分

        a.Wikipedia 2014 + Gigaword 5 (6B tokens, 400K vocab, uncased, 50d, 100d, 200d, & 300d vectors, 822 MB download): glove.6B.zip

        b.Common Crawl (42B tokens, 1.9M vocab, uncased, 300d vectors, 1.75 GB download): glove.42B.300d.zip

        c.Common Crawl (840B tokens, 2.2M vocab, cased, 300d vectors, 2.03 GB download): glove.840B.300d.zip

        d.Twitter (2B tweets, 27B tokens, 1.2M vocab, uncased, 25d, 50d, 100d, & 200d vectors, 1.42 GB download): glove.twitter.27B.zip

         在这里,采用Wikipedia 2014 + Gigaword 5 中的100d,即词向量长度为100,向量的token数量有6B。

        下载好的GloVe词向量矩阵是一个txt文件,一行是一个词和词向量,中间用空格隔开,因此加载该预训练词向量矩阵可以这样。

        原始GloVe预训练词向量有40万个词,在这里只关心词表中有的词,因此可以在加载字典时加一行过滤,即在词表中的词,才去获取它的词向量。

        在本案例中,词表大小是2万,根据匹配,只有19720个词在GloVe中找到了词向量,其余的词向量就需要随机初始化。

        获取GloVe预训练词向量字典后,需要把词向量放到embedding层中的矩阵,对弈embedding层来说,一行是一个词的词向量,因此通过词表的序号找到对应的行,然后把预训练词向量放进去即可,代码如下:

9.训练及实验记录

        准备好了数据和模型,接下来按照常规模型训练即可。

        这里将会做一些对比实验,包括模型对比:

         a.RNN vs LSTM

        b.有预训练词向量 vs 无预训练词向量

       c. 冻结预训练词向量 vs 放开预训练词向量

        具体指令如下,推荐放到bash文件中,一次性跑

        实验结果如下所示:

        1. RNN整体不work,经过分析发现设置的文本token长度太长,导致RNN梯度消失,以至于无法训练。调整 text_max_len为50后,train acc=0.8+, val=0.62,整体效果较差。

         2. 有了预训练词向量要比没有预训练词向量高出10多个点。

         3. 放开词向量训练,效果会好一些,但是不明显。

        补充实验:将RNN模型的文本最长token数量设置为50,其余保持不变,得到的三种embedding方式的结果如下:

        结论:

        1. LSTM较RNN在长文本处理上效果更好

        2. 预训练词向量在小样本数据集上很关键,有10多个点的提升

        3. 放开与冻结embedding层训练,效果差不多

10.小结

        本小节通过电影影评数据集实现文本分类任务,通过该任务可以了解:

        1. 文本预处理机制:包括清洗、分词、词频统计、词表构建、词表截断、UNK与PAD特殊词设定等。

        2. 预训练词向量使用:包括GloVe的下载及加载、nn.embedding层的设置 。

        3. RNN系列网络模型使用:大致了解循环神经网络的输入/输出是如何构建,如何配合fc层实现文本分类。

         4. RNN可接收的文本长度有限:文本过长,导致梯度消失,文本过短,导致无法捕获更多文本信息,因此推荐采用 LSTM等门控机制的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/36253.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙星河NEXT学习笔记

1.1 字符串 // 变量的存储和修改(string number boolean) // 1. 变量存储 // 1.1 字符串 string 类型 // 注意点1:字符串需要用引号引起来(单引双引号)字符串 "字符串" // 注意点2:存储的时候&a…

Elasticsearch开启认证|为ES设置账号密码|ES账号密码设置|ES单机开启认证|ES集群开启认证

文章目录 前言单节点模式开启认证生成节点证书修改ES配置文件为内置账号添加密码Kibana修改配置验证 ES集群开启认证验证 前言 ES安装完成并运行,默认情况下是允许任何用户访问的,这样并不安全,可以为ES开启认证,设置账号密码。 …

FPGA无网络芯片实现千兆TCP/IP协议栈,基于1G/2.5G Ethernet PCS/PMA or SGMII方案,提供18套工程源码和技术支持

目录 1、前言工程概述免责声明 2、相关方案推荐我这里已有的以太网方案网络芯片版本-->千兆网 TCP-->服务器 方案网络芯片版本-->千兆网 TCP-->客户端 方案10G 万兆网 TCP-->服务器客户端 方案1G/2.5G Ethernet PCS/PMA or SGMII 方案AXI 1G/2.5G Ethernet Subs…

AI大模型安全挑战和安全要求解读

引言 随着人工智能技术的飞速发展,大模型技术以其卓越的性能和广泛的应用前景,正在重塑人工智能领域的新格局。然而,任何技术都有两面性,大模型在带来前所未有便利的同时,也引发了深刻的安全和伦理挑战。 大模型&…

vscode安装lean4

本教程演示在Windows系统下如何安装Lean 4正式版。Linux和MacOS版本请参考Lean Manual。 如果你身在中国,在运行安装程序前需要做如下准备: 在系统目录C:\Windows\System32\drivers\etc文件夹下找到hosts文件。对于其它系统用户也都是找到各自系统的host…

字节码编程ASM之两数之和

写在前面 源码 。 看下如何使用ASM来写如下的类: package com.dahuyou.demo.asm;public class AsmSumOfTwo {public AsmSumOfTwo() {}public static void main(String[] var0) {int var1 (new AsmSumOfTwo()).sum(1, 2);System.out.println(var1);}public int su…

Batch学习及应用案例

一、介绍 Batch是一种Windows操作系统中使用的批处理脚本语言,用于自动化执行一系列命令和操作。通过编写批处理脚本,可以实现自动化完成重复性或繁琐的任务,提高工作效率。 Batch脚本可以使用内置的命令和命令行工具,以及调用其…

使用飞书多维表格实现推送邮件

一、为什么用飞书? 在当今竞争激烈的商业环境中,选择一款高效、智能的办公工具至关重要。了解飞书的朋友应该都知道,飞书的集成能力是很强大的,能够与各种主流的办公软件无缝衔接,实现数据交互,提升工作效…

竞赛选题 python区块链实现 - proof of work工作量证明共识算法

文章目录 0 前言1 区块链基础1.1 比特币内部结构1.2 实现的区块链数据结构1.3 注意点1.4 区块链的核心-工作量证明算法1.4.1 拜占庭将军问题1.4.2 解决办法1.4.3 代码实现 2 快速实现一个区块链2.1 什么是区块链2.2 一个完整的快包含什么2.3 什么是挖矿2.4 工作量证明算法&…

vue3中通过vditor插件实现自定义上传图片、录入echarts、脑图、markdown语法的编辑器

1、下载Vditor插件 npm i vditor 我的vditor版本是3.10.2,大家可以自行选择下载最新版本 官网:Vditor 一款浏览器端的 Markdown 编辑器,支持所见即所得(富文本)、即时渲染(类似 Typora)和分屏 …

消息队列选型之 Kafka vs RabbitMQ

在面对众多的消息队列时,我们往往会陷入选择的困境:“消息队列那么多,该怎么选啊?Kafka 和 RabbitMQ 比较好用,用哪个更好呢?”想必大家也曾有过类似的疑问。对此本文将在接下来的内容中以 Kafka 和 Rabbit…

阿里AI-Spring Cloud Alibaba AI:快速搭建自己的通义千问

本文基于官方文档。 Spring AI 官方文档:Spring AI :: Spring AI Reference 中文文档:Spring AI 简介 - spring 中文网 (springdoc.cn) Spring AI 是 Spring 官方社区项目,旨在简化 Java AI 应用程序开发,让 Java 开发者像使用…

流光卡片,生成炫酷文字,开源API

https://fireflycard.shushiai.com/ 这是我的一个网站,流光卡片,主要功能是帮你制作酷炫的文字卡片,用精美的卡片让你的文字生动起来。 展示效果如下: 你可以用它制作卡片,来记录自己的表达。支持设定卡片背景、LOGO…

梗图生成器突然爆红;ElevenLabs发布IOS APP 高质量语音朗读手机各种文本内容;开源工作流架构ControlFlow

✨ 1: 梗图生成器 fabianstelzer 在Glif做的一个超强meme生成器 Glif 是一个工作流,能生成文字图片和视频,用工作流的形式可以完成很多的花样来。 最近爆红的梗图生成器,WOJAK MEME GENERATOR ,也是用工作流的形式来生成这些有…

宝塔面板之 wwwroot修改不了权限

宝塔使用Apache环境,搭建网站出现 You don’t have permission to access this resource.Server unable to read h出错时的解决办法 今天由于某些原因导致我宝塔 在Apache和Nginx运行环境下不断切换,结果我网站全部不能正常打不开了 结果我发现原本宝塔…

boss直聘招聘数据可视化分析

boss直聘招聘数据可视化分析 一、数据预处理二、数据可视化三、完整代码一、数据预处理 在 上一篇博客中,笔者已经详细介绍了使用selenium爬取南昌市web前端工程师的招聘岗位数据,数据格式如下: 这里主要对薪水列进行处理,为方便处理,将日薪和周薪的数据删除,将带有13薪…

自媒体内容创作者必备:ChatGPT助你提升文章质量

随着自媒体的迅猛发展,越来越多的人加入到内容创作的行列。然而,要在这个竞争激烈的领域脱颖而出,不仅需要创意和独特的观点,更需要高质量的文章内容。在这方面,ChatGPT作为一个智能写作助手,能够帮助自媒体…

靠!AI绘画月入过万!是否现实?

前言 AI人工智能已经出现在了越来越多领域中,比如最近一段时间,AI绘画就受到了许多人的关注,一来,其背后隐藏的版权问题、替代性问题引发了人们的广泛讨论,再者,AI绘画在短期时间内成为了流量密码&#xf…

暑假追高必备:ChildLife全新钙镁锌小绿钙

2024年暑假将至,家长们对于孩子的健康关注再次提升,其中补钙成为许多家长关注的重点。暑假期间,孩子有更多时间进行户外活动,加上高温流汗多,身体的钙更容易流失,因此需要额外地补充。为此,美国…

工业AIoT竞赛流程

不要点到重置!!!要刷新虚拟机就点重启 xshell连接虚拟机:ssh rootPublic IP 环境构建 vim /etc/hosts 按 i 进入插入模式,加内网ip和主机名,按esc,按 : ,按wq 三个虚拟机都这样配 …