跟着问题学15——GRU网络结构详解及代码实战

   

1 RNN的缺陷——长期依赖的问题 (The Problem of Long-Term Dependencies)

前面一节我们学习了RNN神经网络,它可以用来处理序列型的数据,比如一段文字,视频等等。RNN网络的基本单元如下图所示,可以将前面的状态作为当前状态的输入。

但也有一些情况,我们需要更“长期”的上下文信息。比如预测最后一个单词“我在中国长大……我说一口流利的**。”“短期”的信息显示,下一个单词很可能是一种语言的名字,但如果我们想缩小范围,我们需要更长期语境——“我在中国长大”,但这个相关信息与需要它的点之间的距离完全有可能变得非常大。

不幸的是,随着这种距离的扩大,RNN无法学会连接这些信息。

从理论上讲,RNN绝对有能力处理这种“长期依赖性”。人们可以为他们精心选择参数,以解决这种形式的问题。遗憾的是,在实践中,RNN似乎无法学习它们。

幸运的是,GRU也没有这个问题!

2、GRU

什么是GRU

GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。

GRU和LSTM在很多情况下实际表现上相差无几,那么为什么我们要使用新人GRU(2014年提出)而不是相对经受了更多考验的LSTM(1997提出)呢。

用论文中的话说,相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

2.1总体结构框架

前面我们讲到,神经网络的各种结构都是为了挖掘变换数据特征的,所以下面我们也将结合数据特征的维度来对比介绍一下RNN&&LSTM的网络结构。

多层感知机(线性连接层)结构

从特征角度考虑:

输入特征:是n*1的单维向量(这也是为什么卷积神经网络在linear层前要把所有特征层展平),

隐藏层:然后根据隐藏层神经元的数量m将前层输入的特征用m*1的单维向量进行表示(对特征进行了提取变换,隐藏层的数据特征),单个隐藏层的神经元数量就代表网络参数,可以设置多个隐藏层;

输出特征:最终根据输出层的神经元数量y输出y*1的单维向量。

卷积神经网络结构

 从特征角度考虑:

输入特征:是(batch)*channel*width*height的张量,

卷积层(等):然后根据输入通道channel的数量c_in和输出通道channel的数量c_out会有c_out*c_in*k*k个卷积核将前层输入的特征进行卷积(对特征进行了提取变换,k为卷积核尺寸),卷积核的大小和数量c_out*c_in*k*k就代表网络参数,可以设置多个卷积层;每一个channel都代表提取某方面的一种特征,该特征用width*height的二维张量表示,不同特征层之间是相互独立的(可以进行融合)。

输出特征:根据场景的需要设置后面的输出,可以是多分类的单维向量等等。

循环神经网络RNN系列结构

从特征角度考虑:

输入特征:是(batch)*T_seq*feature_size的张量(T_seq代表序列长度,注意不是batch_size).

我们来详细对比一下卷积神经网络的输入特征,

(batch)*T_seq*feature_size

(batch)*channel*width*height,

逐个进行分析,RNN系列的基础输入特征表示是feature_size*1的单维向量,比如一个单词的词向量,比如一个股票价格的影响因素向量,而CNN系列的基础输入特征是width*height的二维张量;

再来看一下序列T_seq和通道channel,RNN系列的序列T_seq是指一个连续的输入,比如一句话,一周的股票信息,而且这个序列是有时间先后顺序且互相关联的,而CNN系列的通道channel则是指不同角度的特征,比如彩色图像的RGB三色通道,过程中每个通道代表提取了每个方面的特征,不同通道之间是没有强相关性的,不过也可以进行融合。

最后就是batch,两者都有,在RNN系列,batch就是有多个句子,在CNN系列,就是有多张图片(每个图片可以有多个通道)

隐藏层:明确了输入特征之后,我们再来看看隐藏层代表着什么。隐藏层有T_seq个隐状态H_t(和输入序列长度相同),每个隐状态H_t类似于一个channel,对应着T_seq中的t时刻的输入特征;而每个隐状态H_t是用hidden_size*1的单维向量表示的,所以一个隐含层是T_seq*hidden_size的张量;对应时刻t的输入特征由feature_size*1变为hidden_size*1的向量。如图中所示,同一个隐含层不同时刻的参数W_ih和W_hh是共享的;隐藏层可以有num_layers个(图中只有1个)

以t时刻具体阐述一下:

X_t是t时刻的输入,是一个feature_size*1的向量

W_ih是输入层到隐藏层的权重矩阵

H_t是t时刻的隐藏层的值,是一个hidden_size*1的向量

W_hh是上一时刻的隐藏层的值传入到下一时刻的隐藏层时的权重矩阵

Ot是t时刻RNN网络的输出

从上右图中可以看出这个RNN网络在t时刻接受了输入Xt之后,隐藏层的值是St,输出的值是Ot。但是从结构图中我们可以发现St并不单单只是由Xt决定,还与t-1时刻的隐藏层的值St-1有关。

2.2 GRU的输入输出结构

GRU的输入输出结构与普通的RNN是一样的。有一个当前的输入xt,和上一个节点传递下来的隐状态(hidden state)ht-1 ,这个隐状态包含了之前节点的相关信息。结合xt和 ht-1,GRU会得到当前隐藏节点的输出yt 和传递给下一个节点的隐状态 ht。

图 GRU的输入输出结构

那么,GRU到底有什么特别之处呢?下面来对它的内部结构进行分析!

2.3 GRU的内部结构

不同于LSTM有3个门控,GRU仅有2个门控,

第一个是“重置门”(reset gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,用于将上一时刻隐状态ht-1重置为ht-1’,即ht-1’=ht-1*r。

再将ht-1’与输入xt进行拼接,再通过一个tanh激活函数来将数据放缩到-1~1的范围内。即得到如下图2-3所示的h’。

第一个是“更新门”(update gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,

最终的隐状态ht的更新表达式即为:

再次强调一下,门控信号(这里的z)的范围为0~1。门控信号越接近1,代表”记忆“下来的数据越多;而越接近0则代表”遗忘“的越多。

2.4 小结

GRU很聪明的一点就在于,使用了同一个门控z就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控)。与LSTM相比,GRU内部少了一个”门控“,参数比LSTM少,但是却也能够达到与LSTM相当的功能。考虑到硬件的计算能力时间成本,因而很多时候我们也就会选择更加”实用“的GRU。

3代码

import torch
import torch.nn as nndef my_gru(input,initial_states,w_ih,w_hh,b_ih,b_hh):h_prev=initial_statesbatch_size,T_seq,feature_size=input.shapehidden_size=w_ih.shape[0]//3batch_w_ih=w_ih.unsqueeze(0).tile(batch_size,1,1)batch_w_hh=w_hh.unsqueeze(0).tile(batch_size,1,1)output=torch.zeros(batch_size,T_seq,hidden_size)for t in range(T_seq):x=input[:,t,:]w_times_x=torch.bmm(batch_w_ih,x.unsqueeze(-1))w_times_x=w_times_x.squeeze(-1)# print(batch_w_hh.shape,h_prev.shape)# 计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m)# 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,# 对于剩下的则不做要求,输出维度 (b,h,m)# batch_w_hh=batch_size*(3*hidden_size)*hidden_size# h_prev=batch_size*hidden_size*1# w_times_x=batch_size*hidden_size*1##squeeze,在给定维度(维度值必须为1)上压缩维度,负数代表从后开始数w_times_h_prev=torch.bmm(batch_w_hh,h_prev.unsqueeze(-1))w_times_h_prev=w_times_h_prev.squeeze(-1)r_t=torch.sigmoid(w_times_x[:,:hidden_size]+w_times_h_prev[:,:hidden_size]+b_ih[:hidden_size]+b_hh[:hidden_size])z_t=torch.sigmoid(w_times_x[:,hidden_size:2*hidden_size]+w_times_h_prev[:,hidden_size:2*hidden_size]+b_ih[hidden_size:2*hidden_size]+b_hh[hidden_size:2*hidden_size])n_t=torch.tanh(w_times_x[:,2*hidden_size:3*hidden_size]+w_times_h_prev[:,2*hidden_size:3*hidden_size]+b_ih[2*hidden_size:3*hidden_size]+b_hh[2*hidden_size:3*hidden_size])h_prev=(1-z_t)*n_t+z_t*h_prevoutput[:,t,:]=h_prevreturn output,h_previf __name__=="__main__":fc=nn.Linear(12,6)batch_size=2T_seq=5feature_size=4hidden_size=3# output_feature_size=3input=torch.randn(batch_size,T_seq,feature_size)h_prev=torch.randn(batch_size,hidden_size)gru_layer=nn.GRU(feature_size,hidden_size,batch_first=True)output,h_final=gru_layer(input,h_prev.unsqueeze(0))# for k,v in gru_layer.named_parameters():#     print(k,v.shape)# print(output,h_final)my_output, my_h_final=my_gru(input,h_prev,gru_layer.weight_ih_l0,gru_layer.weight_hh_l0,gru_layer.bias_ih_l0,gru_layer.bias_hh_l0)# print(my_output, my_h_final)# print(torch.allclose(output,my_output))

参考资料

https://zhuanlan.zhihu.com/p/32481747

https://speech.ee.ntu.edu.tw/~tlkagk/courses/MLDS_2018/Lecture/Seq%20(v2).pdf

https://www.bilibili.com/video/BV1jm4y1Q7uh/?spm_id_from=333.788&vd_source=cf7630d31a6ad93edecfb6c5d361c659

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/888912.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pytest中使用conftest做测试前置和参数化

pytest中比较高阶的应用是,使用conftest去做测试前置工作、测试收尾工作和参数化。conftest是pytest的一个组件,用于配置测试环境和参数。通过conftest, 可以创建一个可复用的测试配置文件,以便在多个测试模块之间共享配置信息。即&#xff0…

04 创建一个属于爬虫的主虚拟环境

文章目录 回顾conda常用指令创建一个爬虫虚拟主环境Win R 调出终端查看当前conda的虚拟环境创建 spider_base 的虚拟环境安装完成查看环境是否存在 为 pycharm 配置创建的爬虫主虚拟环境选一个盘符来存储之后学习所写的爬虫文件用 pycharm 打开创建的文件夹pycharm 配置解释器…

mvn test 失败,单独运行单元测试成功

标题mvn test 失败,单独运行单元测试成功 使用junit4进行单元测试时是通过的,但是在执行maven的test与package时测试不通过 报错信息: parse data from Nacos error,dataId:guoyu-new-asset-dev.yml,data: ....... 配置文件内容 ....... o…

android 富文本及展示更多组件

模拟微博 #热贴 和 用户 的这种 富文本形式组件,不说了, 直接上代码 package com.tongtong.feat_watch.viewimport android.content.Context import android.graphics.Color import android.util.AttributeSet import android.view.LayoutInflater impo…

gitlab 生成并设置 ssh key

一、介绍 🎯 本文主要介绍 SSH Key 的生成方法,以及如何在GitLab上添加SSH Key。GitLab 使用SSH协议与Git 进行安全通信。当您使用 SSH密钥 对 GitLab远程服务器进行身份验证时,您不需要每次都提供您的用户名和密码。SSH使用两个密钥&#x…

保姆级教程Docker部署Nacos镜像

目录 1、创建挂载目录 2、拉取 Nacos 镜像 3、临时启动并复制文件 4、创建Nacos表结构 5、修改Nacos配置 6、正式启动 Nacos 7、登录Nacos 1、创建挂载目录 在宿主机上创建一个目录用于配置文件映射,这个目录将作为数据卷挂载到容器内部,使得我…

【北京迅为】iTOP-4412全能版使用手册-第六十七章 USB鼠标驱动详解

iTOP-4412全能版采用四核Cortex-A9,主频为1.4GHz-1.6GHz,配备S5M8767 电源管理,集成USB HUB,选用高品质板对板连接器稳定可靠,大厂生产,做工精良。接口一应俱全,开发更简单,搭载全网通4G、支持WIFI、蓝牙、…

【银河麒麟操作系统真实案例分享】内存黑洞导致服务器卡死分析全过程

了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer.kylinos.cn 文档中心:https://documentkylinos.cn 现象描述 机房显示器连接服务器后黑屏&#xff…

Java项目实战II基于微信小程序的旅游社交平台(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、核心代码 五、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。 一、前言 随着移动互联网的迅猛发展,旅游已经成为人…

【数据库】关系代数和SQL语句

一 对于教学数据库的三个基本表 学生S(S#,SNAME,AGE,SEX) 学习SC(S#,C#,GRADE) 课程(C#,CNAME,TEACHER) (1)试用关系代数表达式和SQL语句表示:检索WANG同学不学的课程号 select C# from C where C# not in(select C# from SCwhere S# in…

IS-IS二

目录 ISIS建立邻接关系的基本条件: 1、接口链路类型一致 2、广播型链路上,接口类型一致 3、Hello包级别和类型一致 4、L1区域的ID要一致,L2的邻居区域ID不做要求 5、L1-2在区域ID相同下,即建立L1也建立L2区域ID不同只能建立…

️ 在 Windows WSL 上部署 Ollama 和大语言模型的完整指南20241206

🛠️ 在 Windows WSL 上部署 Ollama 和大语言模型的完整指南 📝 引言 随着大语言模型(LLM)和人工智能的飞速发展,越来越多的开发者尝试在本地环境中部署大模型进行实验。然而,由于资源需求高、网络限制多…

设计模式の单例工厂原型模式

文章目录 前言一、单例模式1.1、饿汉式静态常量单例1.2、饿汉式静态代码块单例1.3、懒汉式单例(线程不安全)1.4、懒汉式单例(线程安全,同步代码块)1.5、懒汉式单例(线程不安全,同步代码块&#…

net.sf.jsqlparser.statement.select.SelectItem

今天一启动项目,出现了这个错误,仔细想了想,应该是昨天合并代码,导致的mybatis-plus版本冲突,以及分页PageHelper版本不兼容 可以看见这个我是最下边的Caused by 报错信息,这个地方提示我 net .sf.jsqlpar…

第427场周赛: 转换数组、用点构造面积最大的矩形 Ⅰ、长度可被 K 整除的子数组的最大元素和、用点构造面积最大的矩形 Ⅱ

Q1、转换数组 1、题目描述 给你一个整数数组 nums&#xff0c;它表示一个循环数组。请你遵循以下规则创建一个大小 相同 的新数组 result &#xff1a; 对于每个下标 i&#xff08;其中 0 < i < nums.length&#xff09;&#xff0c;独立执行以下操作&#xff1a; 如…

CV工程师专用键盘开源项目硬件分析

1、前言 作为一个电子发烧友&#xff0c;你是否有遇到过这样的问题呢。当我们去查看函数定义的时候&#xff0c;需要敲击鼠标右键之后选择go to definition。更高级一些&#xff0c;我们使用键盘的快捷键来查看定义&#xff0c;这时候可以想象一下&#xff0c;你左手按下ALT&a…

SpringBoot3配置文件

一、统一配置管理概述: SpringBoot工程下&#xff0c;进行统一的配置管理&#xff0c;你想设置的任何参数(端口号、项目根路径、数据库连接信息等等)都集中到一个固定位置和命名的配置文件(application.properties或application.yml)中 配置文件应该放置在Spring Boot工程的s…

【机器学习】任务十一:Keras 模块的使用

1.Keras简介 1.1 什么是Keras&#xff1f; Keras 是一个开源的深度学习框架&#xff0c;用 Python 编写&#xff0c;构建于 TensorFlow 之上。它以简单、快速和易于使用为主要设计目标&#xff0c;适合初学者和研究者。 Keras 提供了高层次的 API&#xff0c;帮助用户快速构…

【新品发布】ESP32-P4开发板 —— 启明智显匠心之作,为物联网及HMI产品注入强劲动力

核心亮点&#xff1a; ESP32-P4开发板&#xff0c;是启明智显精心打造的一款高性能物联网开发板。它专为物联网项目及HMI&#xff08;人机界面&#xff09;产品而设计&#xff0c;旨在为您提供卓越的性能和稳定可靠的运行体验。 强大硬件配置&#xff1a; 双核400MHz RISC-V处…

在Ubuntu22.04.5上安装Docker-CE

文章目录 1. 查看Ubuntu版本2. 安装Docker-CE2.1 安装必要的系统工具2.2 信任Docker的GPG公钥2.3 写入软件源信息2.4 安装Docker相关组件2.5 安装指定版本Docker-CE2.5.1 查找Docker-CE的版本2.5.2 安装指定版本Docker-CE 3. 启动与使用Docker3.1 启动Docker服务3.2 查看Docker…