循环神经网络的变体模型-LSTM、GRU

一.LSTM(长短时记忆网络)

1.1基本介绍

长短时记忆网络(Long Short-Term Memory,LSTM)是一种深度学习模型,属于循环神经网络(Recurrent Neural Network,RNN)的一种变体。LSTM的设计旨在解决传统RNN中遇到的长序列依赖问题,以更好地捕捉和处理序列数据中的长期依赖关系。

下面是LSTM的内部结构图

LSTM

LSTM为了改善梯度消失,引入了一种特殊的存储单元,该存储单元被设计用于存储和提取长期记忆。与传统的RNN不同,LSTM包含三个关键的门(gate)来控制信息的流动,这些门分别是遗忘门(Forget Gate)、输入门(Input Gate)和输出门(Output Gate)。

LSTM的结构允许它有效地处理和学习序列中的长期依赖关系,这在许多任务中很有用,如自然语言处理、语音识别和时间序列预测。由于其能捕获长期记忆,LSTM成为深度学习中重要的组件之一。

1.2 主要组成部分和工作原理

首先我们先弄明白LSTM单元中的每个符号的含义。每个黄色方框表示一个神经网络层,由权值,偏置以及激活函数组成;每个粉色圆圈表示元素级别操作;箭头表示向量流向;相交的箭头表示向量的拼接;分叉的箭头表示向量的复制。
图中元素的节点信息

以下是LSTM的主要组成部分和工作原理:

  1. 细胞状态(Cell State):
    细胞状态是LSTM网络的主要存储单元,用于存储和传递长期记忆。细胞状态在序列的每一步都会被更新。在LSTM中,细胞状态负责保留网络需要记住的信息,以便更好地处理长期依赖关系。在每个时间步,LSTM通过一系列的操作来更新细胞状态。这些操作包括遗忘门、输入门和输出门的计算。细胞状态在这些门的帮助下动态地保留和遗忘信息。
    细胞状态

  2. 遗忘门(Forget Gate):
    遗忘门决定哪些信息应该被遗忘,从而允许网络丢弃不重要的信息。它通过一个sigmoid激活函数生成一个介于0和1之间的值,用于控制细胞状态中信息的丢失程度。
    遗忘门的计算过程如下:
    2.1 输入:
    上一时刻的隐藏状态(或者是输入数据的向量)
    当前时刻的输入数据
    2.2 计算遗忘门的值:
    将上一时刻的隐藏状态和当前时刻的输入数据拼接在一起。
    通过一个带有sigmoid激活函数的全连接层(通常称为遗忘门层)得到介于0和1之间的值。
    这个值表示细胞状态中哪些信息应该被保留(接近1),哪些信息应该被遗忘(接近0)。
    2.3 遗忘操作:
    将上一时刻的细胞状态与遗忘门的输出相乘,以决定保留哪些信息。
    2.4数学表达式如下:
    遗忘门的输出:
    遗忘门

其中:
W f 和 b f 是遗忘门的权重矩阵和偏置向量。 W_f 和 b_f是遗忘门的权重矩阵和偏置向量。 Wfbf是遗忘门的权重矩阵和偏置向量。
h t − 1 ​是上一时刻的隐藏状态。 h_{t−1}​ 是上一时刻的隐藏状态。 ht1是上一时刻的隐藏状态。
x t 是当前时刻的输入数据。 x_t是当前时刻的输入数据。 xt是当前时刻的输入数据。
σ 是 s i g m o i d 激活函数。 σ 是sigmoid激活函数。 σsigmoid激活函数。

遗忘门的输出 ft 决定了细胞状态中上一时刻信息的保留程度。这个机制允许LSTM网络在处理时间序列数据时更有效地记住长期依赖关系。

  1. 输入门(Input Gate):
    输入门负责确定在当前时间步骤中要添加到细胞状态的新信息。类似于遗忘门,输入门使用sigmoid激活函数产生一个介于0和1之间的值,表示要保留多少新信息,并使用tanh激活函数生成一个新的候选值。
    在这里插入图片描述输入门的计算过程如下:
(1)输入门的输出计算:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有sigmoid激活函数的全连接层得到介于0和1之间的值。这个值表示要保留的新信息的程度。
(2)生成新的候选值:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有tanh激活函数的全连接层得到一个新的候选值(介于-1和1之间)。
(3)更新细胞状态的操作:将输入门的输出与新的候选值相乘,得到要添加到细胞状态的新信息。
  1. 输出门(Output Gate):
    输出门(Output Gate)在LSTM中控制细胞在特定时间步上的输出。输出门使用sigmoid激活函数产生介于0和1之间的值,这个值决定了在当前时间步细胞状态中有多少信息被输出。同时,输出门的输出与细胞状态经过tanh激活函数后的值相乘,产生最终的LSTM输出。

输出门的计算过程如下:

输出门的输出计算:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有sigmoid激活函数的全连接层得到介于0和1之间的值。这个值表示在当前时间步细胞状态中有多少信息要输出。
生成最终的LSTM输出:将当前时刻的细胞状态经过tanh激活函数,得到介于-1和1之间的值。将输出门的输出与tanh激活函数的细胞状态相乘,产生最终的LSTM输出。

在这里插入图片描述

1.3 LSTM的基础代码实现

以下是一个基础的实现,其中包括多层双向LSTM的前向传播。请注意,这个实现仍然是一个简化版本,实际应用中可能需要更多的调整和优化。

import numpy as npdef sigmoid(x):return 1 / (1 + np.exp(-x))def tanh(x):return np.tanh(x)def lstm_cell(xt, a_prev, c_prev, parameters):# 从参数中提取权重和偏置Wf = parameters["Wf"]bf = parameters["bf"]Wi = parameters["Wi"]bi = parameters["bi"]Wo = parameters["Wo"]bo = parameters["bo"]Wc = parameters["Wc"]bc = parameters["bc"]# 合并输入和上一个时间步的隐藏状态concat = np.concatenate((a_prev, xt), axis=0)# 遗忘门ft = sigmoid(np.dot(Wf, concat) + bf)# 输入门it = sigmoid(np.dot(Wi, concat) + bi)# 更新细胞状态cct = tanh(np.dot(Wc, concat) + bc)c_next = ft * c_prev + it * cct# 输出门ot = sigmoid(np.dot(Wo, concat) + bo)# 更新隐藏状态a_next = ot * tanh(c_next)# 保存计算中间结果,以便反向传播cache = (xt, a_prev, c_prev, a_next, c_next, ft, it, ot, cct)return a_next, c_next, cachedef lstm_forward(x, a0, parameters):n_x, m, T_x = x.shapen_a = a0.shape[0]a = np.zeros((n_a, m, T_x))c = np.zeros_like(a)caches = []a_prev = a0c_prev = np.zeros_like(a_prev)for t in range(T_x):xt = x[:, :, t]a_next, c_next, cache = lstm_cell(xt, a_prev, c_prev, parameters)a[:,:,t] = a_nextc[:,:,t] = c_nextcaches.append(cache)a_prev = a_nextc_prev = c_nextreturn a, c, cachesdef lstm_model_forward(x, parameters):caches = []a = xc_list = []for layer in parameters:a, c, layer_cache = lstm_forward(a, np.zeros_like(a[:, :, 0]), layer)caches.append(layer_cache)c_list.append(c)return a, c_list, cachesdef dense_layer_forward(a, parameters):W = parameters["W"]b = parameters["b"]z = np.dot(W, a) + ba_next = sigmoid(z)return a_next, zdef model_forward(x, parameters_lstm, parameters_dense):a_lstm, c_list, caches_lstm = lstm_model_forward(x, parameters_lstm)a_dense = a_lstm[:, :, -1]z_dense_list = []for layer_dense in parameters_dense:a_dense, z_dense = dense_layer_forward(a_dense, layer_dense)z_dense_list.append(z_dense)return a_dense, c_list, caches_lstm, z_dense_list# 示例数据和参数
np.random.seed(1)
x = np.random.randn(10, 5, 3)  # 10个样本,每个样本5个时间步,每个时间步3个特征# LSTM参数
parameters_lstm = [{"Wf": np.random.randn(5, 8), "bf": np.random.randn(5, 1),"Wi": np.random.randn(5, 8), "bi": np.random.randn(5, 1),"Wo": np.random.randn(5, 8), "bo": np.random.randn(5, 1),"Wc": np.random.randn(5, 8), "bc": np.random.randn(5, 1)},{"Wf": np.random.randn(3, 8), "bf": np.random.randn(3, 1),"Wi": np.random.randn(3, 8), "bi": np.random.randn(3, 1),"Wo": np.random.randn(3, 8), "bo": np.random.randn(3, 1),"Wc": np.random.randn(3, 8), "bc": np.random.randn(3, 1)}
]# Dense层参数
parameters_dense = [{"W": np.random.randn(1, 5), "b": np.random.randn(1, 1)},{"W": np.random.randn(1, 5), "b": np.random.randn(1, 1)}
]# 进行正向传播
a_dense, c_list, caches_lstm, z_dense_list = model_forward(x, parameters_lstm, parameters_dense)# 打印输出形状
print("a_dense.shape:", a_dense.shape)

二.GRU(门控循环单元)

GRU

2.1 GRU的基本介绍

门控循环单元(GRU,Gated Recurrent Unit)是一种用于处理序列数据的循环神经网络(RNN)变体,旨在解决传统RNN中的梯度消失问题,并提供更好的长期依赖建模。GRU引入了门控机制,类似于LSTM,但相对于LSTM,GRU结构更加简单。

GRU包含两个门:更新门(Update Gate)和重置门(Reset Gate)。这两个门允许GRU网络决定在当前时间步更新细胞状态的程度以及如何利用先前的隐藏状态。

重置门(Reset Gate)的计算:

通过一个sigmoid激活函数计算重置门的输出。重置门决定了在当前时间步,应该忽略多少先前的隐藏状态信息。

更新门(Update Gate)的计算:

通过一个sigmoid激活函数计算更新门的输出。更新门决定了在当前时间步,应该保留多少先前的隐藏状态信息。

候选隐藏状态的计算:

通过tanh激活函数计算一个候选的隐藏状态。

新的隐藏状态的计算:

通过更新门和候选隐藏状态计算新的隐藏状态。

2.2 GRU的代码实现

以下是使用PyTorch库实现基本的门控循环单元(GRU)的代码。PyTorch提供了GRU的高级API,可以轻松实现和使用。下面是一个简单的例子:

import torch
import torch.nn as nn# 定义GRU模型
class SimpleGRU(nn.Module):def __init__(self, input_size, hidden_size):super(SimpleGRU, self).__init__()self.gru = nn.GRU(input_size, hidden_size)def forward(self, x, hidden=None):output, hidden = self.gru(x, hidden)return output, hidden# 示例数据和模型参数
input_size = 3
hidden_size = 5
seq_len = 1  # 序列长度
batch_size = 1# 创建GRU模型
gru_model = SimpleGRU(input_size, hidden_size)# 将输入数据转换为PyTorch的Tensor
x = torch.randn(seq_len, batch_size, input_size)# 前向传播
output, hidden = gru_model(x)# 打印输出形状
print("Output shape:", output.shape)
print("Hidden shape:", hidden.shape)

以下是使用NumPy库实现基本的门控循环单元(GRU)的代码。这个实现是一个简化版本,其中包含更新门和重置门的计算,以及候选隐藏状态和新的隐藏状态的计算。

import numpy as npdef sigmoid(x):return 1 / (1 + np.exp(-x))def tanh(x):return np.tanh(x)def gru_cell(a_prev, x, parameters):# 从参数中提取权重和偏置W_r = parameters["W_r"]b_r = parameters["b_r"]W_z = parameters["W_z"]b_z = parameters["b_z"]W_a = parameters["W_a"]b_a = parameters["b_a"]# 计算重置门r_t = sigmoid(np.dot(W_r, np.concatenate([a_prev, x])) + b_r)# 计算更新门z_t = sigmoid(np.dot(W_z, np.concatenate([a_prev, x])) + b_z)# 计算候选隐藏状态tilde_a_t = tanh(np.dot(W_a, np.concatenate([r_t * a_prev, x])) + b_a)# 计算新的隐藏状态a_t = (1 - z_t) * a_prev + z_t * tilde_a_t# 保存计算中间结果,以便反向传播cache = (a_prev, x, r_t, z_t, tilde_a_t, a_t)return a_t, cache# 示例数据和参数
np.random.seed(1)
a_prev = np.random.randn(5, 1)  # 上一时刻的隐藏状态
x = np.random.randn(3, 1)  # 当前时刻的输入数据# GRU参数
parameters = {"W_r": np.random.randn(5, 8),"b_r": np.random.randn(5, 1),"W_z": np.random.randn(5, 8),"b_z": np.random.randn(5, 1),"W_a": np.random.randn(5, 8),"b_a": np.random.randn(5, 1)
}# 单个GRU单元的前向传播
a_t, cache = gru_cell(a_prev, x, parameters)# 打印输出形状
print("a_t.shape:", a_t.shape)

本文参考了以下链接:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/630306.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android NDK Crash信息收集捕获和日志异常定位分析(addr2line)

Android NDK 闪退日志收集与分析 我们在开发过程中,Android JNI层Crash问题或者我们引用的第三方.so库文件报错,都是一个比较头疼的问题。相对Java层来说,由于c/c++造成的crash没有输出如同Java的Exception Strace堆栈信息,所以定位问题也是个比较艰难的事情。 Google Br…

HCIA的路由协议

动态路由协议/静态路由协议 静态路由协议和动态路由协议的区别: 静态路由协议的缺点: 配置繁琐 针对拓扑的变化不能够自动收敛 只适用于小型网络 静态路由协议优点: 占用资源少 安全 稳定 动态路由协议的优点: 配置简单 针对拓…

前端项目配置 Dockerfile 打包后镜像部署无法访问

Dockerfile 配置如下: FROM node:lts-alpineWORKDIR /app COPY . . RUN npm install RUN npm run buildEXPOSE 3001CMD ["npm", "run", "preview"]构建镜像 docker build -t vite-clarity-project .启动镜像容器 docker run -p 30…

进程(一) 进程概念

文章目录 什么是进程呢? 描述进程-PCBtask_struct-PCB的一种task_struct内容分类 查看进程通过系统目录查看通过ps命令查看通过系统调用获取进程的PID和PPID通过系统调用创建进程- fork()函数 fork()函数fork函数做了什么?fork之后…

centos 编译升级内核

一.离线编译并升级内核 1.下载内核 https://mirrors.ustc.edu.cn/centos-vault/7.9.2009/updates/Source/SPackages/ kernel-3.10.0-1160.105.1.el7.src.rpm 2.解压内核 (1)安装rpmrebuild yum install rpmrebuild; (2&#xf…

Vue加载序列帧动图

解读方法 使用<img :src"currentFrame" alt"加载中" /> 加载图片动态更改src的值使用 requestAnimationFrame 定时更新在需要的页面调用封装的组件 <LoadToast v-if"showLoading" /> 封装组件 <template><div class"…

CTF CRYPTO 密码学-1

题目名称&#xff1a;enc 题目描述&#xff1a; 压缩包中含两个文件&#xff1a;一个秘钥d.dec&#xff0c;一个密文flag.enc 解题过程&#xff1a; Step1&#xff1a;这题是一个解密他题目&#xff0c;尝试openssl去ras解密 工具简介 在Kali Linux系统中&#xff0c;openss…

React16源码: React中的异步调度scheduler模块的源码实现

React Scheduler 1 ) 概述 react当中的异步调度&#xff0c;称为 React Scheduler发布成单独的一个 npm 包就叫做 scheduler这个包它做了什么&#xff1f; A. 首先它维护时间片B. 然后模拟 requestIdleCallback 这个API 因为现在浏览器的支持不是特别的多所以在浏览当中只是去…

【计算机图形学】习题课:Viewing

【计算机图形学】Viewing 部分问题与解答 CS100433 Computer Graphics Assignment 21 Proof the composed transformations defined in global coordinate frame is equivalent to the composed transformations defined in local coordinate frame but in different composing…

1月14-17日为技术测试期!字节与腾讯上演“大和解”,抖音全面开放《王者荣耀》直播

综合整理&#xff5c;TesterHome社区 来源&#xff5c;《王者荣耀》官方、界面新闻 北京商报、IT之家 1月13日&#xff0c;腾讯游戏《王者荣耀》官方微博发布消息宣布&#xff0c;从1月21日起&#xff0c;《王者荣耀》抖音直播将全面开放。 为了筛查开播期间可能遇到的所有技…

几何_直线方程 Ax + By + C = 0 的系数A,B,C几何含义是?

参考&#xff1a; 直线方程 Ax By C 0 的系数A&#xff0c;B&#xff0c;C有什么几何含义&#xff1f;_设直线 l 的方程为axbyc0 怎么理解-CSDN博客 1. A B的含义&#xff1a;组成一个与直线垂直的向量 我们先来看A和B有什么含义。 在直线上取任意两点 P1:&#xff08;x1…

OceanBase集群部署

我认为学习一个中间件比较好的方式是&#xff0c;先了解它的架构和运行原理&#xff0c;然后动手部署一遍&#xff0c;加深对它的了解&#xff0c;再使用它&#xff0c;最后进行总结和分享 本篇介绍OceanBase部署前提配置和集群部署 1.使用开源免费的社区版&#xff0c;企业版…

MetaGPT task1学习

基础知识学习了解&#xff1a; 安装环境&#xff1a; 获取MetaGPT 使用pip获取MetaGPT pip install -i https://pypi.tuna.tsinghua.edu.cn/simple metagpt0.5.2 配置MetaGPT 完成MetaGPT后&#xff0c;我们还需要完成一些配置才能开始使用这个强力的框架&#xff0c;包括配…

RHCE: web服务器+nfs服务器搭建

网站需求&#xff1a; 1.基于域名[www.openlab.com](http://www.openlab.com)可以访问网站内容为 welcome to openlab!!! web服务器准备工作&#xff1a; #添加多ip(也可以用nmtui命令在图形界面配置) [root192 ~]# nmcli connection modify ens32 ipv4.method manual ipv4.…

[面试题~] Golang

1. 逃逸分析 1.1 逃逸分析是什么&#xff1f; 在编译原理中&#xff0c;分析指针动态范围的方法称之为逃逸分析。在Go中的表现是&#xff0c;如果一个对象的指针被多个方法或线程引用时&#xff0c;则称这个指针发生了逃逸。 所以&#xff0c;我认为逃逸分析指的是&#xff0…

[leetcode~数位动态规划] 2719. 统计整数数目 hard

给你两个数字字符串 num1 和 num2 &#xff0c;以及两个整数 max_sum 和 min_sum 。如果一个整数 x 满足以下条件&#xff0c;我们称它是一个好整数&#xff1a; num1 < x < num2 min_sum < digit_sum(x) < max_sum. 请你返回好整数的数目。答案可能很大&#xff…

极客时间-《左耳听风》文章笔记 + 个人思考

极客时间-《左耳听风》文章笔记 个人思考 分布式架构21 | 分布式系统架构的冰与火 分布式架构 21 | 分布式系统架构的冰与火 比较流行的高并发框架&#xff1a; Node.js&#xff1a;是一个基于Chrome V8引擎的JavaScript运行环境&#xff0c;它使用事件驱动、非阻塞I/O模型…

2024.1.17

今天我已经回家了&#xff0c;感觉家就像我的温柔乡一样&#xff0c;一到了家&#xff0c;就不想学习了&#xff0c;这是很不对的事情&#xff0c;不该如此堕落&#xff0c;还是要像在学校一样该干什么干什么&#xff0c;所以说还是复习和写了一下曾经写过的代码。 #define _C…

[Android] Android架构体系(1)

文章目录 Android 的框架Dalvik 虚拟机JNI原生二进制可执行文件Android NDK中的binutils Bionic谷歌考虑到的版权问题Bionic与传统的C标准库&#xff08;如glibc&#xff09;的一些不同 参考 Android 的框架 Android 取得成功的关键因素之一就是它丰富的框架集。 没有这些框架…

第2章 法律法规

文章目录 2.1.1 《电信条例》概述2.1.2 《电信条例》关于电信市场的规定1、电信业务许可规定2、电信资费规定3、电信资费规定 2.1.3《电信条例》关于电信服务的规定1、电信业务经营者的义务2、电信用户的相关义务3、电信用户申诉及受理的规定4、电信业务经营者不正当行为的规定…