跟着问题学15——GRU网络结构详解及代码实战

   

1 RNN的缺陷——长期依赖的问题 (The Problem of Long-Term Dependencies)

前面一节我们学习了RNN神经网络,它可以用来处理序列型的数据,比如一段文字,视频等等。RNN网络的基本单元如下图所示,可以将前面的状态作为当前状态的输入。

但也有一些情况,我们需要更“长期”的上下文信息。比如预测最后一个单词“我在中国长大……我说一口流利的**。”“短期”的信息显示,下一个单词很可能是一种语言的名字,但如果我们想缩小范围,我们需要更长期语境——“我在中国长大”,但这个相关信息与需要它的点之间的距离完全有可能变得非常大。

不幸的是,随着这种距离的扩大,RNN无法学会连接这些信息。

从理论上讲,RNN绝对有能力处理这种“长期依赖性”。人们可以为他们精心选择参数,以解决这种形式的问题。遗憾的是,在实践中,RNN似乎无法学习它们。

幸运的是,GRU也没有这个问题!

2、GRU

什么是GRU

GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。

GRU和LSTM在很多情况下实际表现上相差无几,那么为什么我们要使用新人GRU(2014年提出)而不是相对经受了更多考验的LSTM(1997提出)呢。

用论文中的话说,相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

2.1总体结构框架

前面我们讲到,神经网络的各种结构都是为了挖掘变换数据特征的,所以下面我们也将结合数据特征的维度来对比介绍一下RNN&&LSTM的网络结构。

多层感知机(线性连接层)结构

从特征角度考虑:

输入特征:是n*1的单维向量(这也是为什么卷积神经网络在linear层前要把所有特征层展平),

隐藏层:然后根据隐藏层神经元的数量m将前层输入的特征用m*1的单维向量进行表示(对特征进行了提取变换,隐藏层的数据特征),单个隐藏层的神经元数量就代表网络参数,可以设置多个隐藏层;

输出特征:最终根据输出层的神经元数量y输出y*1的单维向量。

卷积神经网络结构

 从特征角度考虑:

输入特征:是(batch)*channel*width*height的张量,

卷积层(等):然后根据输入通道channel的数量c_in和输出通道channel的数量c_out会有c_out*c_in*k*k个卷积核将前层输入的特征进行卷积(对特征进行了提取变换,k为卷积核尺寸),卷积核的大小和数量c_out*c_in*k*k就代表网络参数,可以设置多个卷积层;每一个channel都代表提取某方面的一种特征,该特征用width*height的二维张量表示,不同特征层之间是相互独立的(可以进行融合)。

输出特征:根据场景的需要设置后面的输出,可以是多分类的单维向量等等。

循环神经网络RNN系列结构

从特征角度考虑:

输入特征:是(batch)*T_seq*feature_size的张量(T_seq代表序列长度,注意不是batch_size).

我们来详细对比一下卷积神经网络的输入特征,

(batch)*T_seq*feature_size

(batch)*channel*width*height,

逐个进行分析,RNN系列的基础输入特征表示是feature_size*1的单维向量,比如一个单词的词向量,比如一个股票价格的影响因素向量,而CNN系列的基础输入特征是width*height的二维张量;

再来看一下序列T_seq和通道channel,RNN系列的序列T_seq是指一个连续的输入,比如一句话,一周的股票信息,而且这个序列是有时间先后顺序且互相关联的,而CNN系列的通道channel则是指不同角度的特征,比如彩色图像的RGB三色通道,过程中每个通道代表提取了每个方面的特征,不同通道之间是没有强相关性的,不过也可以进行融合。

最后就是batch,两者都有,在RNN系列,batch就是有多个句子,在CNN系列,就是有多张图片(每个图片可以有多个通道)

隐藏层:明确了输入特征之后,我们再来看看隐藏层代表着什么。隐藏层有T_seq个隐状态H_t(和输入序列长度相同),每个隐状态H_t类似于一个channel,对应着T_seq中的t时刻的输入特征;而每个隐状态H_t是用hidden_size*1的单维向量表示的,所以一个隐含层是T_seq*hidden_size的张量;对应时刻t的输入特征由feature_size*1变为hidden_size*1的向量。如图中所示,同一个隐含层不同时刻的参数W_ih和W_hh是共享的;隐藏层可以有num_layers个(图中只有1个)

以t时刻具体阐述一下:

X_t是t时刻的输入,是一个feature_size*1的向量

W_ih是输入层到隐藏层的权重矩阵

H_t是t时刻的隐藏层的值,是一个hidden_size*1的向量

W_hh是上一时刻的隐藏层的值传入到下一时刻的隐藏层时的权重矩阵

Ot是t时刻RNN网络的输出

从上右图中可以看出这个RNN网络在t时刻接受了输入Xt之后,隐藏层的值是St,输出的值是Ot。但是从结构图中我们可以发现St并不单单只是由Xt决定,还与t-1时刻的隐藏层的值St-1有关。

2.2 GRU的输入输出结构

GRU的输入输出结构与普通的RNN是一样的。有一个当前的输入xt,和上一个节点传递下来的隐状态(hidden state)ht-1 ,这个隐状态包含了之前节点的相关信息。结合xt和 ht-1,GRU会得到当前隐藏节点的输出yt 和传递给下一个节点的隐状态 ht。

图 GRU的输入输出结构

那么,GRU到底有什么特别之处呢?下面来对它的内部结构进行分析!

2.3 GRU的内部结构

不同于LSTM有3个门控,GRU仅有2个门控,

第一个是“重置门”(reset gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,用于将上一时刻隐状态ht-1重置为ht-1’,即ht-1’=ht-1*r。

再将ht-1’与输入xt进行拼接,再通过一个tanh激活函数来将数据放缩到-1~1的范围内。即得到如下图2-3所示的h’。

第一个是“更新门”(update gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,

最终的隐状态ht的更新表达式即为:

再次强调一下,门控信号(这里的z)的范围为0~1。门控信号越接近1,代表”记忆“下来的数据越多;而越接近0则代表”遗忘“的越多。

2.4 小结

GRU很聪明的一点就在于,使用了同一个门控z就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控)。与LSTM相比,GRU内部少了一个”门控“,参数比LSTM少,但是却也能够达到与LSTM相当的功能。考虑到硬件的计算能力时间成本,因而很多时候我们也就会选择更加”实用“的GRU。

3代码

import torch
import torch.nn as nndef my_gru(input,initial_states,w_ih,w_hh,b_ih,b_hh):h_prev=initial_statesbatch_size,T_seq,feature_size=input.shapehidden_size=w_ih.shape[0]//3batch_w_ih=w_ih.unsqueeze(0).tile(batch_size,1,1)batch_w_hh=w_hh.unsqueeze(0).tile(batch_size,1,1)output=torch.zeros(batch_size,T_seq,hidden_size)for t in range(T_seq):x=input[:,t,:]w_times_x=torch.bmm(batch_w_ih,x.unsqueeze(-1))w_times_x=w_times_x.squeeze(-1)# print(batch_w_hh.shape,h_prev.shape)# 计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m)# 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,# 对于剩下的则不做要求,输出维度 (b,h,m)# batch_w_hh=batch_size*(3*hidden_size)*hidden_size# h_prev=batch_size*hidden_size*1# w_times_x=batch_size*hidden_size*1##squeeze,在给定维度(维度值必须为1)上压缩维度,负数代表从后开始数w_times_h_prev=torch.bmm(batch_w_hh,h_prev.unsqueeze(-1))w_times_h_prev=w_times_h_prev.squeeze(-1)r_t=torch.sigmoid(w_times_x[:,:hidden_size]+w_times_h_prev[:,:hidden_size]+b_ih[:hidden_size]+b_hh[:hidden_size])z_t=torch.sigmoid(w_times_x[:,hidden_size:2*hidden_size]+w_times_h_prev[:,hidden_size:2*hidden_size]+b_ih[hidden_size:2*hidden_size]+b_hh[hidden_size:2*hidden_size])n_t=torch.tanh(w_times_x[:,2*hidden_size:3*hidden_size]+w_times_h_prev[:,2*hidden_size:3*hidden_size]+b_ih[2*hidden_size:3*hidden_size]+b_hh[2*hidden_size:3*hidden_size])h_prev=(1-z_t)*n_t+z_t*h_prevoutput[:,t,:]=h_prevreturn output,h_previf __name__=="__main__":fc=nn.Linear(12,6)batch_size=2T_seq=5feature_size=4hidden_size=3# output_feature_size=3input=torch.randn(batch_size,T_seq,feature_size)h_prev=torch.randn(batch_size,hidden_size)gru_layer=nn.GRU(feature_size,hidden_size,batch_first=True)output,h_final=gru_layer(input,h_prev.unsqueeze(0))# for k,v in gru_layer.named_parameters():#     print(k,v.shape)# print(output,h_final)my_output, my_h_final=my_gru(input,h_prev,gru_layer.weight_ih_l0,gru_layer.weight_hh_l0,gru_layer.bias_ih_l0,gru_layer.bias_hh_l0)# print(my_output, my_h_final)# print(torch.allclose(output,my_output))

参考资料

https://zhuanlan.zhihu.com/p/32481747

https://speech.ee.ntu.edu.tw/~tlkagk/courses/MLDS_2018/Lecture/Seq%20(v2).pdf

https://www.bilibili.com/video/BV1jm4y1Q7uh/?spm_id_from=333.788&vd_source=cf7630d31a6ad93edecfb6c5d361c659

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/888912.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pytest中使用conftest做测试前置和参数化

pytest中比较高阶的应用是,使用conftest去做测试前置工作、测试收尾工作和参数化。conftest是pytest的一个组件,用于配置测试环境和参数。通过conftest, 可以创建一个可复用的测试配置文件,以便在多个测试模块之间共享配置信息。即&#xff0…

04 创建一个属于爬虫的主虚拟环境

文章目录 回顾conda常用指令创建一个爬虫虚拟主环境Win R 调出终端查看当前conda的虚拟环境创建 spider_base 的虚拟环境安装完成查看环境是否存在 为 pycharm 配置创建的爬虫主虚拟环境选一个盘符来存储之后学习所写的爬虫文件用 pycharm 打开创建的文件夹pycharm 配置解释器…

mvn test 失败,单独运行单元测试成功

标题mvn test 失败,单独运行单元测试成功 使用junit4进行单元测试时是通过的,但是在执行maven的test与package时测试不通过 报错信息: parse data from Nacos error,dataId:guoyu-new-asset-dev.yml,data: ....... 配置文件内容 ....... o…

基于gitlab API刷新MR的commit的指定status

场景介绍 自己部署的gitlab Jenkins,并已经设置好联动(如何设置可以在网上很容易搜到)每个MergeRequest都可以触发多个Jenkins pipeline,pipeline结束后会将状态更新到gitlab这个MR上希望可以跳过pipeline运行,直接将指定的MR的指定pipeline状态刷新为…

android 富文本及展示更多组件

模拟微博 #热贴 和 用户 的这种 富文本形式组件,不说了, 直接上代码 package com.tongtong.feat_watch.viewimport android.content.Context import android.graphics.Color import android.util.AttributeSet import android.view.LayoutInflater impo…

YourPHPCMS checkEmail SQL注入漏洞复现

0x01 产品简介 Yourphp企业网站管理系统是一款完全开源免费的PHP+MYSQL系统。核心采用了ThinkPHP框架,同时也作为开源软件发布。 集众多开源项目于一身的特点,使本系统从安全,效率,易用及可扩展性上更加突出。程序内置

gitlab 生成并设置 ssh key

一、介绍 🎯 本文主要介绍 SSH Key 的生成方法,以及如何在GitLab上添加SSH Key。GitLab 使用SSH协议与Git 进行安全通信。当您使用 SSH密钥 对 GitLab远程服务器进行身份验证时,您不需要每次都提供您的用户名和密码。SSH使用两个密钥&#x…

保姆级教程Docker部署Nacos镜像

目录 1、创建挂载目录 2、拉取 Nacos 镜像 3、临时启动并复制文件 4、创建Nacos表结构 5、修改Nacos配置 6、正式启动 Nacos 7、登录Nacos 1、创建挂载目录 在宿主机上创建一个目录用于配置文件映射,这个目录将作为数据卷挂载到容器内部,使得我…

Python+onlyoffice 实现在线word编辑

onlyoffice部署 version: "3" services:onlyoffice:image: onlyoffice/documentserver:7.5.1container_name: onlyofficerestart: alwaysenvironment:- JWT_ENABLEDfalse#- USE_UNAUTHORIZED_STORAGEtrue#- ONLYOFFICE_HTTPS_HSTS_ENABLEDfalseports:- "8080:8…

【北京迅为】iTOP-4412全能版使用手册-第六十七章 USB鼠标驱动详解

iTOP-4412全能版采用四核Cortex-A9,主频为1.4GHz-1.6GHz,配备S5M8767 电源管理,集成USB HUB,选用高品质板对板连接器稳定可靠,大厂生产,做工精良。接口一应俱全,开发更简单,搭载全网通4G、支持WIFI、蓝牙、…

【银河麒麟操作系统真实案例分享】内存黑洞导致服务器卡死分析全过程

了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer.kylinos.cn 文档中心:https://documentkylinos.cn 现象描述 机房显示器连接服务器后黑屏&#xff…

docker系统详解哟 以及相关命令 Centos Kali安装相关详解 Docker-Compose 亲测

目录 who Is Docker 概念 centos7 安装docker kali安装docker docker安装nginx Docker常用命令 容器得常用命令 Docker-Compose install 常用docker-compose命令 who Is Docker 软件的打包技术,就是将算乱的多个文件打包为一个整体,打包技术在没…

Java项目实战II基于微信小程序的旅游社交平台(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、核心代码 五、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。 一、前言 随着移动互联网的迅猛发展,旅游已经成为人…

Windows 和 Linux 系统命令行操作详解:从文件管理到进程监控

1.切换盘符与目录操作 在命令行中,切换盘符和目录是最常见的操作。尽管 DOS 和 Linux 在这些操作上有所不同,但它们都能实现相似的功能。 (1)切换盘符 ①DOS命令:在 DOS 中,切换盘符非常简单,使用 盘符名:&#xff…

【数据库】关系代数和SQL语句

一 对于教学数据库的三个基本表 学生S(S#,SNAME,AGE,SEX) 学习SC(S#,C#,GRADE) 课程(C#,CNAME,TEACHER) (1)试用关系代数表达式和SQL语句表示:检索WANG同学不学的课程号 select C# from C where C# not in(select C# from SCwhere S# in…

IS-IS二

目录 ISIS建立邻接关系的基本条件: 1、接口链路类型一致 2、广播型链路上,接口类型一致 3、Hello包级别和类型一致 4、L1区域的ID要一致,L2的邻居区域ID不做要求 5、L1-2在区域ID相同下,即建立L1也建立L2区域ID不同只能建立…

echarts全屏,vue

echarts实现全屏并且不失真&#xff0c;全屏图片需要自己换 html&#xff1a; <!-- 图表全屏盒子 --> <div style"position: relative;" ref"charts_orders"><!-- 图表 --><div class"chart_box" v-show"sho…

杂谈随笔-关于unity开发游戏

最近有在做unity的游戏开发&#xff0c;都是自学&#xff0c;甚至没有完整的课程体系…… 在犹豫要不要出系列教程&#xff0c;帮助新手快速入门的同时算是巩固一下基础知识。 那这篇文章先谈谈我对于引擎开发游戏的一些小观点&#xff0c;算是做了这么十几个星期的微不足道的…

️ 在 Windows WSL 上部署 Ollama 和大语言模型的完整指南20241206

&#x1f6e0;️ 在 Windows WSL 上部署 Ollama 和大语言模型的完整指南 &#x1f4dd; 引言 随着大语言模型&#xff08;LLM&#xff09;和人工智能的飞速发展&#xff0c;越来越多的开发者尝试在本地环境中部署大模型进行实验。然而&#xff0c;由于资源需求高、网络限制多…

[光源控制] UI调节光源亮度参数失效

📢博客主页:https://loewen.blog.csdn.net📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正!📢本文由 丶布布原创,首发于 CSDN,转载注明出处🙉📢现在的付出,都会是一种沉淀,只为让你成为更好的人✨文章预览: 一. 前言二. 串口调试助手辅助排查接线问题二. …