【python,机器学习,nlp】RNN循环神经网络

RNN(Recurrent Neural Network),中文称作循环神经网络,它一般以序列数据为输入,通过网络内部的结构设计有效捕捉序列之间的关系特征,一般也是以序列形式进行输出。

因为RNN结构能够很好利用序列之间的关系,因此针对自然界具有连续性的输入序列,如人类的语言,语音等进行很好的处理,广泛应用于NLP领域的各项任务,如文本分类,情感分析,意图识别,机器翻译等.

RNN模型的分类:

这里我们将从两个角度对RNN模型进行分类.第一个角度是输入和输出的结构,第二个角度是RNN的内部构造.

按照输入和输出的结构进行分类:

N vs N-RNN

它是RNN最基础的结构形式,最大的特点就是:输入和输出序列是等长的.由于这个限制的存在,使其适用范围比较小,可用于生成等长度的合辙诗句.

N vs 1-RNN

有时候我们要处理的问题输入是一个序列,而要求输出是一个单独的值而不是序列,要在最后一个隐层输出h上进行线性变换。

大部分情况下,为了更好的明确结果,还要使用sigmoid或者softmax进行处理.这种结构经常被应用在文本分类问题上.

1 vs N-RNN

我们最常采用的一种方式就是使该输入作用于每次的输出之上.这种结构可用于将图片生成文字任务等.

N vs M-RNN

这是一种不限输入输出长度的RNN结构,它由编码器和解码器两部分组成,两者的内部结构都是某类RNN,它也被称为seq2seq架构。

输入数据首先通过编码器,最终输出一个隐含变量c,之后最常用的做法是使用这个隐含变量c作用在解码器进行解码的每一步上,以保证输入信息被有效利用。

按照RNN的内部构造进行分类:

传统RNN

内部计算函数

tanh的作用: 用于帮助调节流经网络的值,tanh函数将值压缩在﹣1和1之间。

传统RNN的优势:
由于内部结构简单,对计算资源要求低,相比之后我们要学习的RNN变体:LSTM和GRU模型参数总量少了很多,在短序列任务上性能和效果都表现优异。

传统rnn的缺点:
传统RNN在解决长序列之间的关联时,通过实践,证明经典RNN表现很差,原因是在进行反向传播的时候,过长的序列导致梯度的计算异常,发生梯度消失或爆炸。

LSTM

LSTM (Long Short-Term Memory)也称长短时记忆结构,它是传统RNN的变体,与经典RNN相比能够有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸现象,同时LSTM的结构更复杂。

LSTM缺点:由于内部结构相对较复杂,因此训练效率在同等算力下较传统RNN低很多.

LSTM优势:LSTM的门结构能够有效减缓长序列问题中可能出现的梯度消失或爆炸,虽然并不能杜绝这种现象,但在更长的序列问题上表现优于传统RNN.

 

它的核心结构可以分为四个部分去解析:

遗忘门

与传统RNN的内部结构计算非常相似,首先将当前时间步输入x(t)与上一个时间步隐含状态h(t-1)拼接, 得到[x(t), h(t-1)],然后通过一个全连接层做变换,最后通过sigmoid函数(变化到【0,1】)进行激活得到f(t),我们可以将f(t)看作是门值,好比一扇门开合的大小程度,门值都将作用在通过该扇门的张量,遗忘门门值将作用的上一层的细胞状态上,代表遗忘过去的多少信息,又因为遗忘门门值是由x(t), h(t-1)计算得来的,因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息.

输入门

输入门的计算公式有两个,第一个就是产生输入门门值的公式,它和遗忘门公式几乎相同,区别只是在于它们之后要作用的目标上,这个公式意味着输入信息有多少需要进行过滤.输入门的第二个公式是与传统RNN的内部结构计算相同.对于LSTM来讲,它得到的是当前的细胞状态,而不是像经典RNN一样得到的是隐含状态.

细胞状态

我们看到输入门的计算公式有两个,第一个就是产生输入门门值的公式,它和遗忘门公式几乎相同,区别只是在于它们之后要作用的目标上.这个公式意味着输入信息有多少需要进行过滤.输入门的第二个公式是与传统RNN的内部结构计算相同.对于LSTM来讲,它得到的是当前的细胞状态,而不是像经典RNN一样得到的是隐含状态。

输出门

输出门部分的公式也是两个,第一个即是计算输出门的门值,它和遗忘门,输入门计算方式相同.第二个即是使用这个门值产生隐含状态h(t),他将作用在更新后的细胞状态C(t)上,并做tanh激活,最终得到h(t)作为下一时间步输入的一部分.整个输出门的程,就是为了产生隐含状态h(t)。

Bi-LSTM

Bi-LSTM即双向LSTM,它没有改变LSTM本身任何的内部结构,只是将LSTM应用两次且方向不同,再将两次得到的LSTM结果进行拼接作为最终输出

GRU

GRU(Gated Recurrent Unit)也称门控循环单元结构,它也是传统RNN的变体,同LSTM一样能够有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸现象.同时它的结构和计算要比LSTM 更简单。

GRU的优势:GRU和LSTM作用相同,在捕捉长序列语义关联时,能有效抑制梯度消失或爆炸,效果都优于传统rnn且计算复杂度相比lstm要小.

GRU的缺点:GRU仍然不能完全解决梯度消失问题,同时其作用RNN的变体,有着RNN结构本身的一大弊端,即不可并行计算,这在数据量和模型体量逐步增大的未来,是RNN发展的关键瓶颈

它的核心结构可以分为两个部分去解析:

更新门 
重置门

Bi-GRU

Bi-GRU与Bi-LSTM的逻辑相同,都是不改变其内部结构,而是将模型应用两次且方向不同,再将两次得到的LSTM结果进行拼接作为最终输出.具体参见上小节中的Bi-LSTM。

注意力机制

注意力机制是注意力计算规则能够应用的深度学习网络的载体,同时包括一些必要的全连接层以及相关张量处理,使其与应用网络融为一体.使自注意力计算规则的注意力机制称为自注意力机制.

注意力计算规则

它需要三个指定的输入Q(query), K(key), V(value), 然后通过计算公式得到注意力的结果,这个结果代表query在key和value作用下的注意力表示.当输入的Q=K=V时,称作自注意力计算规则.

注意力机制的作用

在解码器端的注意力机制: 能够根据模型目标有效的聚焦编码器的输出结果,当其作为解码器的输入时提升效果,改善以往编码器输出是单一定长张量,无法存储过多信息的情况.

在编码器端的注意力机制:主要解决表征问题,相当于特征提取过程,得到输入的注意力表示.一般使用自注意力(self-attention).

注意力机制实现步骤

第一步:根据注意力计算规则,对Q,K,V进行相应的计算.

第二步:根据第一步采用的计算方法,如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接,如果是转置点积,一般是自注意力,Q与V相同,则不需要进行与Q的拼接.

第三步:最后为了使整个attention机制按照指定尺寸输出,使用线性层作用在第二步的结果上做一个线性变换,得到最终对Q的注意力表示.

代码实现

传统模型

import torch
import torch.nn as nn"""
nn.RNN类初始化主要参数解释
input_size:输入张量x中特征维度的大小
hidden_size:隐层张量h中特征维度的大小
num_layers: 隐含层的数量.
nonlinearity: 激活函数的选择,默认是tanh.
"""
rnn=nn.RNN(input_size=5,hidden_size=6,num_layers=1)"""
设定输入的张量x
第一个参数:sequence_length(输入序列的长度)
第二个参数:batch_size(批次的样本数)
第三个参数:input_size(输入张量x的维度)
"""
input=torch.randn(1,3,5)
"""
设定初始化的h0
第一个参数:num_layers *num_directions(层数*网络方向数)
第二个参数:batch_size(批次的样本数)
第三个参数:hiddeh_size(隐藏层的维度)
"""
h0=torch.randn(1,3,6)"""
nn.RNN类实例化对象主要参数解释
input: 输入张量x
h0:初始化的隐层张量h
"""
output,hn=rnn(input,h0)

LSTM模型

import torch
import torch.nn as nn"""
nn.LSTM类初始化主要参数解释:
input_size: 输入张量x中特征维度的大小.
hidden_size: 隐层张量h中特征维度的大小.
num_layers: 隐含层的数量.
bidirectional: 是否选择使用双向LSTM,如果为True,则使用;默认不使用.
"""
rnn=nn.LSTM(input_size=5,hidden_size=6,num_layers=2)"""
设定输入的张量x
第一个参数:sequence_length(输入序列的长度)
第二个参数:batch_size(批次的样本数)
第三个参数:input_size(输入张量x的维度)
"""
input=torch.randn(1,3,5)
"""
设定初始化的h0,c0
第一个参数:num_layers *num_directions(层数*网络方向数)
第二个参数:batch_size(批次的样本数)
第三个参数:hiddeh_size(隐藏层的维度)
"""
h0=torch.randn(2,3,6)
c0=torch.randn(2,3,6)"""
nn.LSTM类实例化对象主要参数解释
input: 输入张量x
h0:初始化的隐层张量h.
cO:初始化的细胞状态张量c.
"""
output,(hn,cn)=rnn(input,(h0,c0))

GRU模型

import torch
import torch.nn as nn"""
nn.GRU类初始化主要参数解释
Input_size: 输入张量x中特征维度的大小
hidden_size:隐层张量h中特征维度的大小
num_layers:隐含层的数量
bidirectional: 是否选择使用双向LSTM,如果为True,则使用;默认不使用
"""
rnn=nn.GRU(input_size=5,hidden_size=6,num_layers=2)"""
设定输入的张量x
第一个参数:sequence_length(输入序列的长度)
第二个参数:batch_size(批次的样本数)
第三个参数:input_size(输入张量x的维度)
"""
input=torch.randn(1,3,5)
"""
设定初始化的h0
第一个参数:num_layers *num_directions(层数*网络方向数)
第二个参数:batch_size(批次的样本数)
第三个参数:hiddeh_size(隐藏层的维度)
"""
h0=torch.randn(2,3,6)"""
nn.GRU类实例化对象主要参数解释
input: 输入张量x.
h0:初始化的隐层张量h.
"""
output,hn=rnn(input,h0)

注意力模型

import torch
import torch.nn as nn
import torch.nn.functional as F#建立attn类
class Attn(nn.Module):def __init__(self, query_size,key_size,value_size1,value_size2,output_size):"""_summary_Args:query_size (_type_): 代表的是Q的最后一个维度key_size (_type_): 代表的K的最后一个维度value_size1 (_type_): 代表value的导数第二维大小value_size2 (_type_): 代表value的倒数第一维大小output_size (_type_): 代表输出的最后一个维度的大小"""super(Attn, self).__init__()self.query_size = query_sizeself.key_size = key_sizeself.value_size1 = value_size1self.value_size2 = value_size2self.output_size = output_size# 初始化注意力机制self.attn=nn.Linear(self.query_size+self.key_size,self.value_size1)self.attn_combine=nn.Linear(self.query_size+self.value_size2,self.output_size)def forward(self,query,key,value):"""_summary_Args:query (_type_): 代表Qkey (_type_): 代表Kvalue (_type_): 代表VReturns:_type_: 返回注意力机制的输出"""# 计算注意力权重attn_weights=F.softmax(self.attn(torch.cat((query[0],key[0]),1)),dim=1)attn_applied=torch.bmm(attn_weights.unsqueeze(0),value)# 计算注意力机制的输出output=torch.cat((query[0],attn_applied[0]),1)output=self.attn_combine(output).unsqueeze(0)return output,attn_weightsquery_size=32
key_size=32
value_size1=32
value_size2=64
output_size=64#初始化attn
attn=Attn(query_size,key_size,value_size1,value_size2,output_size)
#使用attn实例
Q=torch.randn(1,1,32)
K=torch.randn(1,1,32)
V=torch.randn(1,32,64)
output=attn(Q,K,V)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/615514.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

螺杆冷水机组工作原理

螺杆冷水机组主要由螺杆压缩机、冷凝器、蒸发器、膨胀阀及电控系统组成。水冷单螺杆冷水机组制冷原图如下: (一)双螺杆制冷压缩机 双螺杆制冷压缩机是一种能量可调式喷油压缩机。它的吸气、压缩、排气三个连续过程是靠机体内的一对相互啮合的…

软考高级选择考哪个好?

📒软考高级总共5个科目,同样是高级证书,认可度也有区别! 大家一般在「信息系统项目管理师」✔️和「系统架构设计师」✔️二选一 1️⃣信息系统项目管理师 ❤️信息系统项目管理师也叫「高项」,考试内容主要是「项目管理」相关&am…

【思扬赠书 | 第1期】教你如何一站式解决OpenCV工程化开发痛点

⛳️ 写在前面参与规则!!! ✅参与方式:关注博主、点赞、收藏、评论,任意评论(每人最多评论三次) ⛳️本次送书1~3本【取决于阅读量,阅读量越多,送的越多】 思扬赠书 | 第…

Sublime Text 3配置 Python 开发环境

Sublime Text 3配置 Python 开发环境 一、引言二、主要内容1. 初识 Sublime Text 32. 初识 Python2. 接入 Python2.1 下载2.2 安装和使用 python2.2 环境变量配置 3. 配置 Python 开发环境4. 编写 Python 代码5. 运行 Python 代码 三、总结 一、引言 Python 是一种简洁但功能强…

AI时代下的智能商品计划如何助力服装企业实现库存精准优化

在AI时代,智能商品计划为服装企业实现库存精准优化提供了强大的支持。以下是AI在这方面的关键作用和助力手段: 1. 数据驱动的需求预测: AI利用大数据和机器学习技术,分析历史销售数据、市场趋势、季节性变化等多方面信息&#x…

uniapp微信小程序投票系统实战 (SpringBoot2+vue3.2+element plus ) -我参与的投票列表实现

锋哥原创的uniapp微信小程序投票系统实战: uniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )_哔哩哔哩_bilibiliuniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )共计21条视频…

给高中生的一些建议

背景 总分300,各科60分左右 基本原理 破罐子破摔,集中力量办大事 分析 破罐子破摔 从高一到现在高二上学期的成绩来看,如果继续保持目前这种状态,到高考也是稳在300左右。即,如果不改变,就是咸鱼一条。既…

打破硬件壁垒:TVM 助力 AI技术跨平台部署

文章目录 《TVM编译器原理与实践》编辑推荐内容简介作者简介目录前言/序言获取方式 随着人工智能(Artificial Intelligence,AI)在全世界信息产业中的广泛应用,深度学习模型已经成为推动AI技术革命的关键。TensorFlow、PyTorch、MX…

【习题】应用程序框架

判断题 1. 一个应用只能有一个UIAbility。错误(False) 正确(True)错误(False) 2. 创建的Empty Ability模板工程,初始会生成一个UIAbility文件。正确(True) 正确(True)错误(False) 3. 每调用一次router.pushUrl()方法,页面路由栈数量均会加1。错误(Fal…

开放式耳机品牌排行榜,2024开放式耳机选购攻略

我在选后开放式耳机的路上可以说是花了不少米,前前后后也下了不少的功夫去做功课了解开放式耳机,包括市面上目前最火的西圣、南卡、cleer等热门型号我都有用过了,可以说是很有发言权了吧。 开放式耳机现在越来越涌现在大众的视野上了&#x…

如何构建Prompt,帮我生成QA,作为召回率检索的测试集?

最近在做搜索召回率的提升工作。粮草未动兵马先行!在做之前应该先有一把尺子来衡量召回率的好坏。所以应该先构建测试数据集,然后去做标准化测试。 有了测试机集以后。再去做搜索优化,才能看出来效果。 当然可以选择一些开源的测试集。如果可…

POI:对Excel的基本读操作 整理2

1 简单读取操作 public class ExcelRead {String PATH "D:\\Idea-projects\\POI\\POI_projects";// 读取的一系列方法// ...... } 因为07版本和03版本操作流程大差不差,所以这边就以03版本为例 Testpublic void testRead03() throws IOException {//获取…

可拖拽表单比传统表单好在哪里?

随着行业的进步和发展,可拖拽表单的应用价值越来越高,在推动企业实现流程化办公和数字化转型的过程中发挥了重要价值和作用,是提质增效的办公利器,也是众多行业客户朋友理想的合作伙伴。那么,可拖拽表单的优势特点表单…

【MySQL】聚合函数与分组查询

聚合函数与分组查询 一、聚合函数1、常见的聚合函数2、实例 二、分组查询1、group by子句2、准备工作3、实例4、having 条件 一、聚合函数 说明:聚合函数用来计算一组数据的集合并返回单个值,通常用这些函数完成:个数的统计,某列…

Dubbo 框架揭秘:分布式架构的精髓与魔法【一】

欢迎来到我的博客,代码的世界里,每一行都是一个故事 Dubbo 框架揭秘:分布式架构的精髓与魔法【一】 前言Dubbo是什么Dubbo的核心概念整体设计 前言 在数字时代,分布式架构正成为应对大规模流量和复杂业务场景的标配。Dubbo&#…

【快刊录用】ABS一星,2区,仅2个月15天录用!

2023年12月30日-2024年1月5日 进展喜讯 经核实,由我处Unionpub学术推荐的论文中,新增2篇论文录用、3篇上线见刊、1篇数据库检索: 录用通知 FA20107 FA20181 — 见刊通知 FB20805 FA20269 FA20797 检索通知 FA20199 — — 计算机…

配网故障定位技术的发展与应用:保障电力供应安全稳定的重要支撑

在现代社会,电力供应安全稳定对于国家经济发展和民生福祉至关重要。然而,随着电网规模的不断扩大,配网故障问题也日益突出。为了确保电力供应的连续性和可靠性,人们不断探索和研发各种故障定位技术。本文将介绍一种基于行波测距技…

[Linux 进程(二)] Linux进程状态

文章目录 1、进程各状态的概念1.1 运行状态1.2 阻塞状态1.3 挂起状态 2、Linux进程状态2.1 运行状态 R2.2 睡眠状态 S2.3 深度睡眠 D2.4 停止状态 T2.5 僵尸状态 Z 与 死亡状态 X孤儿进程 Linux内核中,进程状态,就是PCB中的一个字段,是PCB中的…

智慧食堂管理方式,究竟改变了什么?

随着科技的迅速发展,餐饮业也在不断地迎来新的挑战和机遇。为了提升食堂管理效率、改善用户体验以及提高收益,许多食堂纷纷引入智慧收银系统。 客户案例 企业食堂改革 石家庄某大型企业食堂由于员工数量庞大,传统的收银方式难以满足快速就餐…

大话 JavaScript(Speaking JavaScript):第二十一章到第二十五章

第二十一章:数学 原文:21. Math 译者:飞龙 协议:CC BY-NC-SA 4.0 Math对象用作多个数学函数的命名空间。本章提供了一个概述。 数学属性 Math的属性如下: Math.E 欧拉常数(e) Math.LN2 2 …