【深度学习--RNN 循环神经网络--附LSTM情感文本分类】

deep learning 系列 --RNN 循环神经网络

什么是序列模型

包括了RNN LSTM GRU等网络模型,主要用途是自然语言处理、语音识别等方面,比如生成乐曲,音频转换为文字,文本情感分类,机器翻译等等

标准模型的缺陷

以往的标准模型比如CNN,每次的输入不影响下次的输出,也就是说每次输入的图片都是独立的,没有任何关联,但是很多情况下,我们建立的模型与前项甚至后项的输入是相关的。举个例子,我们要从这两句话中识别出人名:

President Teddy was …
Teddy bear was…

这其中都有一个关键且同样的词Teddy,但是这个Teddy可以是个人名,也可以是泰迪,究竟是哪个呢?通常的识别方法是:

1.从之前学习过的人名识别

我们之前的位置可能提取过Teddy是人名, 但CNN网络,并不能共享从不同位置提取到到的特征,因此不可行

2.从本次输入的上下文出发

比如这里下文有 bear,President,但是CNN也不具备这种序列特性,因此也不可行

另外,CNN的网络对于输入的模型的数据长度都是固定的,但是不同的句子长度并不一致,当然我们可以padding,但是不那么好。
因此才有了循环神经网络架构,它可以克服以上的这些问题

循环神经网络

架构

在这里插入图片描述
图示是循环神经网络架构, 是有循环的,这个图形是最长用于展示的,但实际并不好理解,这个环实际上是多次的输入和输出,下面这种展开的方式更容易理解
在这里插入图片描述
RNN的结构就包括三类输入的权重参数
1.激活值的参数,水平方向的值,每个时间步的激活参数是相同的
2.输入到隐藏层的参数,也就是x到A方向,这个也是每个时间步相同
3.用来预测输出的参数,A到h的方向
在这里插入图片描述
其中y的输出可以用如下表示:
在这里插入图片描述

在这里插入图片描述
预测值y<t>包括了激活值a<t>,而a<t>包括了x<t-1>(注意此处t表示时间步),也就是说每一次的预测输出不仅包括了本次x<t>也包括了上个时间步x<t-1>,依次类推,此次时间步的输出包括了之前所有输入

BPTT 反向传播

实际上在工具中,反向传播都是自动进行的,原理跟普通的反向传播一致
输入序列后,假设预测的值是0.9,但实际是1,产生损失,这个损失可以用交叉熵损失函数来估量
RNN中的反向传播时将之前所有的箭头都反过来,计算出合适的变量,通过导数相关的计算,利用梯度下降算法更新参数,也就是图示:
在这里插入图片描述
这个传播过程中最重要的就是水平方向时间步的反向,因此又叫做穿越时间的反向传播Back Propagate Through time

不同类型的循环神经网络

  • 多对多
    比如机器翻译,输入是多个,输出也是多个,且并不对等,此时经常使用的是encode-decoder,encoder编码器获取输入,并输出,而decoder则使用encode编码的输出,执行decoder,这样输入x和输出y的长度就可以不相同了
    在这里插入图片描述
    当然输入和输出对等的情况也很多
    比如在句子中寻找人名,预测输出y就可以是每个x的对应的每个位置输出y(0表示非人名,1表示是人名)

  • 多对一
    在这里插入图片描述
    比如进行文本的情感分类,这样输入就有可能是一段文本,而输出我们只需要最后一层最后一个时间步的输出即可,一个0和1就足以标识这段文本是positive还是negative

  • 一对多
    在这里插入图片描述
    比如生成类的,输入一个音符或者不输入,就可以产生多个输出

RNN的缺陷 梯度消失和梯度爆炸

  • 现象:

    实际上深度比较大的网络都可能梯度消失或者爆炸,这种现象在在RNN中更加明显
    当我们输入的序列为1000的时候,拿最简单的模型举例 y = wx 经过1000次的传播,y1000 的变化
    在这里插入图片描述
    w仅仅变化一点,经过1000次的传播,变化非常的大

  • 原因:

    发生这样的根本原因是RNN中每一次的输出都将被前面的数据彻底的清洗,而经过长时间的传输,很前面步的影响都后面的影响已经很微弱了,损失的反向调整同样也是,经过长距离的调整,差错已经很难反馈很多个时间步之后了

  • 解决办法:

    梯度爆炸:进行梯度裁剪即可,比如我们发现输出有很多超大的值的时候,进行裁剪
    梯度消失:它很难察觉,也在标准的RNN结构中,可以用GRU或者LSTM解决,而解决梯度消失的实质是通过保留一些前期输入的记忆

LSTM

标准的RNN结构是这样的:
在这里插入图片描述
而标准的LSTM加入了四个门控单元
在这里插入图片描述
这四个门单元分别是:
it, ft,gt, ot 分别是输入门、遗忘门、cell(记忆)门,以及输出门
他们控制 是否输入,对输入的遗忘和记忆,也控制是否输出,从而控制重要信息的传递,不重要信息遗忘

GRU

它比LSTM诞生更晚,是LSTM的变形版本,由于门单元更少,计算简单些,因此训练时间更短一些。
在这里插入图片描述

LSTM实例

本实例以IMDB数据集为例,代码篇幅过长,本文仅列示其中LSTM使用相关的重点,后续会有专门的博客详细解析代码。

  1. 预处理数据
    读取IMDB数据集
def read_imdb(datafolder ='train', dataroot=imdb_zip_path):data=[]for label in ['pos', 'neg']:filepath = os.path.join(imdb_zip_path, datafolder, label)for file in tqdm(os.listdir(filepath)):with open(os.path.join(filepath,file), 'rb') as f:content = f.read().decode('utf-8').replace('\n', ' ').lower()data.append([content, 1 if label == 'pos' else 0])random.shuffle(data)return data

IMDB中的数据分词

def get_tokenized(data):def tokenizer(text):filters = ['!', '"', '#', '$', '%', '&', '\(', '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';', '<', '=', '>','\?', '@', '\[', '\\', '\]', '^', '_', '`', '\{', '\|', '\}', '~', '\t', '\n', '\x97', '\x96', '”', '“', ]text = re.sub("<.*?>", " ", text, flags=re.S)text = re.sub("|".join(filters), " ", text, flags=re.S)return [i.strip().lower() for i in text.split()]return [tokenizer(context) for context, _ in data]

创建分词后的词典

def get_vocab(data):counter = collections.Counter(_flatten(data))return vocab.vocab(counter)

封装dataloader

class ImdbLoader(object):def __init__(self, set_name='train', batch_size='64'):super(ImdbLoader, self).__init__()self.data_set = set_nameself.batch_size = batch_sizedef get_data_loader(self):# train_data = [['"dick tracy" is one of our"', 1],#               ['arguably this is a  the )', 1],#               ["i don't  just to warn anyone ", 0]]train_data = read_imdb(self.data_set)data = preprocess(train_data)#print(data)data_set = Data.TensorDataset(*data)data_loader = Data.DataLoader(data_set, self.batch_size, shuffle=True)return data_loader
  1. 创建模型
    此处创建了一个模型,包括双向的LSTM层和一个全连接层
 class BiRNN(nn.Module):def __init__(self, vocabulary, embed_len, hidden_len, num_layer):super(BiRNN, self).__init__()self.embedding = nn.Embedding(len(vocabulary), embed_len)self.encoder = nn.LSTM(input_size=embed_len,hidden_size=hidden_len,num_layers=num_layer,bidirectional=True,dropout = 0.3)# 本次使用起始和最终时间步的隐藏状态座位全连接层的输入self.decoder = nn.Linear(2*2*hidden_len, 2)def forward(self, inputs):#print('rnn model py: input_shape: ', inputs.shape)embeddings = self.embedding(inputs)glove_vab = getGlove()net.embedding.weight.data.copy_(load_pretrained_embedding(vo.get_itos(), glove_vab))net.embedding.weight.requires_grad = False#print('after embed input shape:', embeddings.shape)embeddings = embeddings.permute(1, 0, 2)output_sequence, _ = self.encoder(embeddings)concat_out = torch.cat((output_sequence[0], output_sequence[-1]), -1)outputs = self.decoder(concat_out)return outputs

3.模型训练

   def train(epoch, imdb_model, lr, train_batch_size):imdb_model_device = imdb_model.to(device)# 过滤掉不需要计算梯度的embedding的参数optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, imdb_model_device.parameters()), lr=lr)loader = ImdbLoader('train', train_batch_size)data_loader = loader.get_data_loader()for i in range(epoch):for idx, (inputs, target) in enumerate(data_loader):target = target.to(device)inputs = inputs.to(device)#print('train.py input shape:', inputs.shape)optimizer.zero_grad()output = imdb_model(inputs)#print('ouput.shape', output.shape)criterion = nn.CrossEntropyLoss()loss = criterion(output, target)loss.backward()optimizer.step()if idx % 10 == 0:predict = torch.max(output, dim=-1, keepdim=False)[-1]acc = predict.eq(target.data).cpu().numpy().mean() * 100print('train Epoch:{} processed:[{} / {} ({:.0f}%) Loss: {:.6f}, ACC: {:.6f}]'.format(i,idx * len(inputs),len(data_loader.dataset),100. * idx / len(data_loader),loss.item(),acc))torch.save(imdb_model.state_dict(), '../../resources/model_save/imdb_net.pkl')torch.save(optimizer.state_dict(), '../../resources/model_save/imdb_optimizer.pkl')

4.模型评估

 def test(imdb_model, test_batch_size):imdb_model.eval()imdb_model = imdb_model.to(device)loader = ImdbLoader('test', test_batch_size)data_loader = loader.get_data_loader()with torch.no_grad():for idx, (inputs, target) in enumerate(data_loader):target = target.to(device)inputs = inputs.to(device)#print(inputs)output = imdb_model(inputs)criterion = nn.CrossEntropyLoss()loss = criterion(output, target)predict = torch.max(output, dim=-1, keepdim=False)[-1]correct = predict.eq(target.data).sum()acc = 100. * predict.eq(target.data).cpu().numpy().mean()print('idx: {} loss : {}, accurate: {}/{} {:.2f}'.format(idx,  loss, correct, target.size(0), acc))

最终效果,当我们执行4个epoch后,准确率基本稳定在80%以上

train Epoch:3 processed:[23680 / 25000 (95%) Loss: 0.325240, ACC: 84.375000]
train Epoch:3 processed:[24320 / 25000 (97%) Loss: 0.449456, ACC: 75.000000]
train Epoch:3 processed:[15600 / 25000 (100%) Loss: 0.438567, ACC: 80.000000]
train Epoch:4 processed:[0 / 25000 (0%) Loss: 0.353131, ACC: 85.937500]
train Epoch:4 processed:[640 / 25000 (3%) Loss: 0.345814, ACC: 89.062500]
train Epoch:4 processed:[1280 / 25000 (5%) Loss: 0.195520, ACC: 93.750000]
train Epoch:4 processed:[1920 / 25000 (8%) Loss: 0.269773, ACC: 87.500000]
train Epoch:4 processed:[2560 / 25000 (10%) Loss: 0.287010, ACC: 85.937500]
train Epoch:4 processed:[3200 / 25000 (13%) Loss: 0.291449, ACC: 90.625000]

参考

colah https://colah.github.io/posts/2015-08-Understanding-LSTMs/
吴恩达 deep learning

https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/39323.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

flutter 常见的状态管理器

flutter 常见的状态管理器 前言一、Provider二、Bloc三、Redux四、GetX总结 前言 当我们构建复杂的移动应用时&#xff0c;有效的状态管理是至关重要的&#xff0c;因为应用的不同部分可能需要共享数据、相应用户交互并保持一致的状态。Flutter 中有多种状态管理解决方案&#…

0143 串

目录 4.串 4.1串的定义和实现 4.2串的模式匹配 部分习题 4.串 4.1串的定义和实现 4.2串的模式匹配 部分习题 1.设有两个串S1和S2&#xff0c;求S2在S1中首次出现的位置的运算称为&#xff08;&#xff09; A.求字串 B.判断是否相等 C.模式匹配 D.连…

Vue2(组件开发)

目录 前言一&#xff0c;组件的使用二&#xff0c;插槽slot三&#xff0c;refs和parent四&#xff0c;父子组件间的通信4.1&#xff0c;父传子 &#xff1a;父传子的时候&#xff0c;通过属性传递4.2&#xff0c;父组件监听自定义事件 五&#xff0c;非父子组件的通信六&#x…

麦肯锡发布《2023年度科技报告》!

在经历了 2022 年技术投资和人才的动荡之后&#xff0c;2023 年上半年&#xff0c;人们对技术促进商业和社会进步的潜力重新燃起了热情。生成式人工智能&#xff08;Generative AI&#xff09;在这一复兴过程中功不可没&#xff0c;但它只是众多进步中的一个&#xff0c;可以推…

总说绿幕直播抠像抠不干净?很有可能是你不知道这个神器!

在绿幕直播的时候&#xff0c;你是不是座位、绿幕、灯光都摆对了&#xff0c;但主播轮廓仍然有绿边和虚化的情况发生&#xff1f;这种很大可能就是你使用的直播抠像软件有问题。今天小编把市面上的常见直播软件来和vLive虚拟直播的抠像做一个对比&#xff0c;让你直观感受下他们…

机器学习笔记 - 基于PyTorch + 类似ResNet的单目标检测

一、获取并了解数据 我们将处理年龄相关性黄斑变性 (AMD) 患者的眼部图像。 数据集下载地址,从下面的地址中,找到iChallenge-AMD,然后下载。 Baidu Research Open-Access Dataset - DownloadDownload Baidu Research Open-Access Datasethttps://ai.baidu.com/bro…

基于ACF,AMDF算法的语音编码matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 .......................................................................... plotFlag …

函数递归专题(案例超详解一篇讲通透)

函数递归 前言1.递归案例:案例一&#xff1a;取球问题案例二&#xff1a;求斐波那契额数列案例三&#xff1a;函数实现n的k次方案例四&#xff1a;输入一个非负整数&#xff0c;返回组成它的数字之和案例五&#xff1a;元素逆置案例六&#xff1a;实现strlen案例七&#xff1a;…

服务器遭受攻击之后的常见思路

哈喽大家好&#xff0c;我是咸鱼 不知道大家有没有看过这么一部电影&#xff1a; 这部电影讲述了男主是一个电脑极客&#xff0c;在计算机方面有着不可思议的天赋&#xff0c;男主所在的黑客组织凭借着超高的黑客技术去入侵各种国家机构的系统&#xff0c;并引起了德国秘密警察…

Mac如何打开隐藏文件中Redis的配置文件redis.conf

Redis下载(通过⬇️博客下载的Redis默认路径为&#xff1a;/usr/local/etc) Redis下载 1.打开终端进入/usr文件夹 cd /usr 2.打开/local/文件夹 open local 3.找到redis.conf并打开,即可修改配置信息

讯飞星火认知大模型全新升级,全新版本、多模交互—测评结果超预期

写在前面 版本新功能 1 体验介绍 登录注册 申请体验 2 具体使用 2.1 多模态能力 2.1.1 多模理解 2.1.2 视觉问答 2.1.3 多模生成 2.2 代码能力 2.2.1 代码生成 2.2.2 代码解释 2.2.3 代码纠错 2.2.4 单元测试 2.3 插件功能 2.3.1 PPT生成 2.3.2 简历生成 2.3.4 文档问答 3 其他…

Android学习之路(3) 布局

线性布局LinearLayout 前几个小节的例程中&#xff0c;XML文件用到了LinearLayout布局&#xff0c;它的学名为线性布局。顾名思义&#xff0c;线性布局 像是用一根线把它的内部视图串起来&#xff0c;故而内部视图之间的排列顺序是固定的&#xff0c;要么从左到右排列&#xf…

Android之版本号、版本别名、API等级对应关系(全)(一百六十二)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…

HTML详解连载(4)

HTML详解连载&#xff08;4&#xff09; 专栏链接 [link](http://t.csdn.cn/xF0H3)下面进行专栏介绍 开始喽CSS定义书写位置示例注意 CSS引入方式内部样式表&#xff1a;学习使用 外部演示表&#xff1a;开发使用代码示例行内样式代码示例 选择器作用基础选择器标签选择器举例特…

RISC-V公测平台发布 · 7-zip 测试

简介 7-Zip 是一个开源的压缩和解压缩工具&#xff0c;具有高压缩比和快速解压缩的特点。除了普通的文件压缩和解压缩功能之外&#xff0c;7-Zip 还提供了基准测试功能&#xff0c;通过压缩和解压缩大型文件来评估系统的处理能力和性能。 7-Zip 提供了一种在不同压缩级别和多…

BUUCTF [MRCTF2020]Ezpop解题思路

题目代码 Welcome to index.php <?php //flag is in flag.php //WTF IS THIS? //Learn From https://ctf.ieki.xyz/library/php.html#%E5%8F%8D%E5%BA%8F%E5%88%97%E5%8C%96%E9%AD%94%E6%9C%AF%E6%96%B9%E6%B3%95 //And Crack It! class Modifier {protected $var;publi…

运维监控学习笔记7

Zabbix的安装&#xff1a; 1、基础环境准备&#xff1a; 安装zabbix的yum源&#xff0c;阿里的yum源提供了zabbix3.0。 rpm -ivh http://mirrors.aliyun.com/zabbix/zabbix/3.0/rhel/7/x86_64/zabbix-release-3.0-1.el7.noarch.rpm 这个文件就是生成了一个zabbix.repo 2、安…

流程挖掘in汽车丨宝马的流程效能提升实例

汽车行业在未来10年里&#xff0c;可能会面临比过去50年更多的变化。电动化、智能化、共享化和自动驾驶等方面的趋势可能给企业流程带来以下挑战&#xff1a; 供应链管理-电动化和智能化的发展可能导致供应链中的零部件和系统结构发生变化&#xff0c;企业需要重新评估和优化供…

zookeeperAPI操作与写数据原理

要执行API操作需要在idea中创建maven项目 &#xff08;改成自己的阿里仓库&#xff09;导入特定依赖 添加日志文件 上边操作做成后就可以进行一些API的实现了 目录 导入maven依赖&#xff1a; 创建日志文件&#xff1a; 创建API客户端&#xff1a; &#xff08;1&#xff09…

Springboot 实践(5)springboot添加资源访问目录及目录测试

前文讲解了swagger测试服务控制器&#xff0c;实现了数据库数据访问&#xff0c;这些功能都是运行在后台服务器上&#xff0c;实际用户并不能直接调用接口获取数据&#xff0c;即使用户能够利用接口获取到数据&#xff0c;数据也是结构化数据&#xff0c;不能争取转化成用户使用…