循环神经网络的变体模型-LSTM、GRU

一.LSTM(长短时记忆网络)

1.1基本介绍

长短时记忆网络(Long Short-Term Memory,LSTM)是一种深度学习模型,属于循环神经网络(Recurrent Neural Network,RNN)的一种变体。LSTM的设计旨在解决传统RNN中遇到的长序列依赖问题,以更好地捕捉和处理序列数据中的长期依赖关系。

下面是LSTM的内部结构图

LSTM

LSTM为了改善梯度消失,引入了一种特殊的存储单元,该存储单元被设计用于存储和提取长期记忆。与传统的RNN不同,LSTM包含三个关键的门(gate)来控制信息的流动,这些门分别是遗忘门(Forget Gate)、输入门(Input Gate)和输出门(Output Gate)。

LSTM的结构允许它有效地处理和学习序列中的长期依赖关系,这在许多任务中很有用,如自然语言处理、语音识别和时间序列预测。由于其能捕获长期记忆,LSTM成为深度学习中重要的组件之一。

1.2 主要组成部分和工作原理

首先我们先弄明白LSTM单元中的每个符号的含义。每个黄色方框表示一个神经网络层,由权值,偏置以及激活函数组成;每个粉色圆圈表示元素级别操作;箭头表示向量流向;相交的箭头表示向量的拼接;分叉的箭头表示向量的复制。
图中元素的节点信息

以下是LSTM的主要组成部分和工作原理:

  1. 细胞状态(Cell State):
    细胞状态是LSTM网络的主要存储单元,用于存储和传递长期记忆。细胞状态在序列的每一步都会被更新。在LSTM中,细胞状态负责保留网络需要记住的信息,以便更好地处理长期依赖关系。在每个时间步,LSTM通过一系列的操作来更新细胞状态。这些操作包括遗忘门、输入门和输出门的计算。细胞状态在这些门的帮助下动态地保留和遗忘信息。
    细胞状态

  2. 遗忘门(Forget Gate):
    遗忘门决定哪些信息应该被遗忘,从而允许网络丢弃不重要的信息。它通过一个sigmoid激活函数生成一个介于0和1之间的值,用于控制细胞状态中信息的丢失程度。
    遗忘门的计算过程如下:
    2.1 输入:
    上一时刻的隐藏状态(或者是输入数据的向量)
    当前时刻的输入数据
    2.2 计算遗忘门的值:
    将上一时刻的隐藏状态和当前时刻的输入数据拼接在一起。
    通过一个带有sigmoid激活函数的全连接层(通常称为遗忘门层)得到介于0和1之间的值。
    这个值表示细胞状态中哪些信息应该被保留(接近1),哪些信息应该被遗忘(接近0)。
    2.3 遗忘操作:
    将上一时刻的细胞状态与遗忘门的输出相乘,以决定保留哪些信息。
    2.4数学表达式如下:
    遗忘门的输出:
    遗忘门

其中:
W f 和 b f 是遗忘门的权重矩阵和偏置向量。 W_f 和 b_f是遗忘门的权重矩阵和偏置向量。 Wfbf是遗忘门的权重矩阵和偏置向量。
h t − 1 ​是上一时刻的隐藏状态。 h_{t−1}​ 是上一时刻的隐藏状态。 ht1是上一时刻的隐藏状态。
x t 是当前时刻的输入数据。 x_t是当前时刻的输入数据。 xt是当前时刻的输入数据。
σ 是 s i g m o i d 激活函数。 σ 是sigmoid激活函数。 σsigmoid激活函数。

遗忘门的输出 ft 决定了细胞状态中上一时刻信息的保留程度。这个机制允许LSTM网络在处理时间序列数据时更有效地记住长期依赖关系。

  1. 输入门(Input Gate):
    输入门负责确定在当前时间步骤中要添加到细胞状态的新信息。类似于遗忘门,输入门使用sigmoid激活函数产生一个介于0和1之间的值,表示要保留多少新信息,并使用tanh激活函数生成一个新的候选值。
    在这里插入图片描述输入门的计算过程如下:
(1)输入门的输出计算:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有sigmoid激活函数的全连接层得到介于0和1之间的值。这个值表示要保留的新信息的程度。
(2)生成新的候选值:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有tanh激活函数的全连接层得到一个新的候选值(介于-1和1之间)。
(3)更新细胞状态的操作:将输入门的输出与新的候选值相乘,得到要添加到细胞状态的新信息。
  1. 输出门(Output Gate):
    输出门(Output Gate)在LSTM中控制细胞在特定时间步上的输出。输出门使用sigmoid激活函数产生介于0和1之间的值,这个值决定了在当前时间步细胞状态中有多少信息被输出。同时,输出门的输出与细胞状态经过tanh激活函数后的值相乘,产生最终的LSTM输出。

输出门的计算过程如下:

输出门的输出计算:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有sigmoid激活函数的全连接层得到介于0和1之间的值。这个值表示在当前时间步细胞状态中有多少信息要输出。
生成最终的LSTM输出:将当前时刻的细胞状态经过tanh激活函数,得到介于-1和1之间的值。将输出门的输出与tanh激活函数的细胞状态相乘,产生最终的LSTM输出。

在这里插入图片描述

1.3 LSTM的基础代码实现

以下是一个基础的实现,其中包括多层双向LSTM的前向传播。请注意,这个实现仍然是一个简化版本,实际应用中可能需要更多的调整和优化。

import numpy as npdef sigmoid(x):return 1 / (1 + np.exp(-x))def tanh(x):return np.tanh(x)def lstm_cell(xt, a_prev, c_prev, parameters):# 从参数中提取权重和偏置Wf = parameters["Wf"]bf = parameters["bf"]Wi = parameters["Wi"]bi = parameters["bi"]Wo = parameters["Wo"]bo = parameters["bo"]Wc = parameters["Wc"]bc = parameters["bc"]# 合并输入和上一个时间步的隐藏状态concat = np.concatenate((a_prev, xt), axis=0)# 遗忘门ft = sigmoid(np.dot(Wf, concat) + bf)# 输入门it = sigmoid(np.dot(Wi, concat) + bi)# 更新细胞状态cct = tanh(np.dot(Wc, concat) + bc)c_next = ft * c_prev + it * cct# 输出门ot = sigmoid(np.dot(Wo, concat) + bo)# 更新隐藏状态a_next = ot * tanh(c_next)# 保存计算中间结果,以便反向传播cache = (xt, a_prev, c_prev, a_next, c_next, ft, it, ot, cct)return a_next, c_next, cachedef lstm_forward(x, a0, parameters):n_x, m, T_x = x.shapen_a = a0.shape[0]a = np.zeros((n_a, m, T_x))c = np.zeros_like(a)caches = []a_prev = a0c_prev = np.zeros_like(a_prev)for t in range(T_x):xt = x[:, :, t]a_next, c_next, cache = lstm_cell(xt, a_prev, c_prev, parameters)a[:,:,t] = a_nextc[:,:,t] = c_nextcaches.append(cache)a_prev = a_nextc_prev = c_nextreturn a, c, cachesdef lstm_model_forward(x, parameters):caches = []a = xc_list = []for layer in parameters:a, c, layer_cache = lstm_forward(a, np.zeros_like(a[:, :, 0]), layer)caches.append(layer_cache)c_list.append(c)return a, c_list, cachesdef dense_layer_forward(a, parameters):W = parameters["W"]b = parameters["b"]z = np.dot(W, a) + ba_next = sigmoid(z)return a_next, zdef model_forward(x, parameters_lstm, parameters_dense):a_lstm, c_list, caches_lstm = lstm_model_forward(x, parameters_lstm)a_dense = a_lstm[:, :, -1]z_dense_list = []for layer_dense in parameters_dense:a_dense, z_dense = dense_layer_forward(a_dense, layer_dense)z_dense_list.append(z_dense)return a_dense, c_list, caches_lstm, z_dense_list# 示例数据和参数
np.random.seed(1)
x = np.random.randn(10, 5, 3)  # 10个样本,每个样本5个时间步,每个时间步3个特征# LSTM参数
parameters_lstm = [{"Wf": np.random.randn(5, 8), "bf": np.random.randn(5, 1),"Wi": np.random.randn(5, 8), "bi": np.random.randn(5, 1),"Wo": np.random.randn(5, 8), "bo": np.random.randn(5, 1),"Wc": np.random.randn(5, 8), "bc": np.random.randn(5, 1)},{"Wf": np.random.randn(3, 8), "bf": np.random.randn(3, 1),"Wi": np.random.randn(3, 8), "bi": np.random.randn(3, 1),"Wo": np.random.randn(3, 8), "bo": np.random.randn(3, 1),"Wc": np.random.randn(3, 8), "bc": np.random.randn(3, 1)}
]# Dense层参数
parameters_dense = [{"W": np.random.randn(1, 5), "b": np.random.randn(1, 1)},{"W": np.random.randn(1, 5), "b": np.random.randn(1, 1)}
]# 进行正向传播
a_dense, c_list, caches_lstm, z_dense_list = model_forward(x, parameters_lstm, parameters_dense)# 打印输出形状
print("a_dense.shape:", a_dense.shape)

二.GRU(门控循环单元)

GRU

2.1 GRU的基本介绍

门控循环单元(GRU,Gated Recurrent Unit)是一种用于处理序列数据的循环神经网络(RNN)变体,旨在解决传统RNN中的梯度消失问题,并提供更好的长期依赖建模。GRU引入了门控机制,类似于LSTM,但相对于LSTM,GRU结构更加简单。

GRU包含两个门:更新门(Update Gate)和重置门(Reset Gate)。这两个门允许GRU网络决定在当前时间步更新细胞状态的程度以及如何利用先前的隐藏状态。

重置门(Reset Gate)的计算:

通过一个sigmoid激活函数计算重置门的输出。重置门决定了在当前时间步,应该忽略多少先前的隐藏状态信息。

更新门(Update Gate)的计算:

通过一个sigmoid激活函数计算更新门的输出。更新门决定了在当前时间步,应该保留多少先前的隐藏状态信息。

候选隐藏状态的计算:

通过tanh激活函数计算一个候选的隐藏状态。

新的隐藏状态的计算:

通过更新门和候选隐藏状态计算新的隐藏状态。

2.2 GRU的代码实现

以下是使用PyTorch库实现基本的门控循环单元(GRU)的代码。PyTorch提供了GRU的高级API,可以轻松实现和使用。下面是一个简单的例子:

import torch
import torch.nn as nn# 定义GRU模型
class SimpleGRU(nn.Module):def __init__(self, input_size, hidden_size):super(SimpleGRU, self).__init__()self.gru = nn.GRU(input_size, hidden_size)def forward(self, x, hidden=None):output, hidden = self.gru(x, hidden)return output, hidden# 示例数据和模型参数
input_size = 3
hidden_size = 5
seq_len = 1  # 序列长度
batch_size = 1# 创建GRU模型
gru_model = SimpleGRU(input_size, hidden_size)# 将输入数据转换为PyTorch的Tensor
x = torch.randn(seq_len, batch_size, input_size)# 前向传播
output, hidden = gru_model(x)# 打印输出形状
print("Output shape:", output.shape)
print("Hidden shape:", hidden.shape)

以下是使用NumPy库实现基本的门控循环单元(GRU)的代码。这个实现是一个简化版本,其中包含更新门和重置门的计算,以及候选隐藏状态和新的隐藏状态的计算。

import numpy as npdef sigmoid(x):return 1 / (1 + np.exp(-x))def tanh(x):return np.tanh(x)def gru_cell(a_prev, x, parameters):# 从参数中提取权重和偏置W_r = parameters["W_r"]b_r = parameters["b_r"]W_z = parameters["W_z"]b_z = parameters["b_z"]W_a = parameters["W_a"]b_a = parameters["b_a"]# 计算重置门r_t = sigmoid(np.dot(W_r, np.concatenate([a_prev, x])) + b_r)# 计算更新门z_t = sigmoid(np.dot(W_z, np.concatenate([a_prev, x])) + b_z)# 计算候选隐藏状态tilde_a_t = tanh(np.dot(W_a, np.concatenate([r_t * a_prev, x])) + b_a)# 计算新的隐藏状态a_t = (1 - z_t) * a_prev + z_t * tilde_a_t# 保存计算中间结果,以便反向传播cache = (a_prev, x, r_t, z_t, tilde_a_t, a_t)return a_t, cache# 示例数据和参数
np.random.seed(1)
a_prev = np.random.randn(5, 1)  # 上一时刻的隐藏状态
x = np.random.randn(3, 1)  # 当前时刻的输入数据# GRU参数
parameters = {"W_r": np.random.randn(5, 8),"b_r": np.random.randn(5, 1),"W_z": np.random.randn(5, 8),"b_z": np.random.randn(5, 1),"W_a": np.random.randn(5, 8),"b_a": np.random.randn(5, 1)
}# 单个GRU单元的前向传播
a_t, cache = gru_cell(a_prev, x, parameters)# 打印输出形状
print("a_t.shape:", a_t.shape)

本文参考了以下链接:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/630306.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android NDK Crash信息收集捕获和日志异常定位分析(addr2line)

Android NDK 闪退日志收集与分析 我们在开发过程中,Android JNI层Crash问题或者我们引用的第三方.so库文件报错,都是一个比较头疼的问题。相对Java层来说,由于c/c++造成的crash没有输出如同Java的Exception Strace堆栈信息,所以定位问题也是个比较艰难的事情。 Google Br…

HCIA的路由协议

动态路由协议/静态路由协议 静态路由协议和动态路由协议的区别: 静态路由协议的缺点: 配置繁琐 针对拓扑的变化不能够自动收敛 只适用于小型网络 静态路由协议优点: 占用资源少 安全 稳定 动态路由协议的优点: 配置简单 针对拓…

前端项目配置 Dockerfile 打包后镜像部署无法访问

Dockerfile 配置如下: FROM node:lts-alpineWORKDIR /app COPY . . RUN npm install RUN npm run buildEXPOSE 3001CMD ["npm", "run", "preview"]构建镜像 docker build -t vite-clarity-project .启动镜像容器 docker run -p 30…

进程(一) 进程概念

文章目录 什么是进程呢? 描述进程-PCBtask_struct-PCB的一种task_struct内容分类 查看进程通过系统目录查看通过ps命令查看通过系统调用获取进程的PID和PPID通过系统调用创建进程- fork()函数 fork()函数fork函数做了什么?fork之后…

Vue加载序列帧动图

解读方法 使用<img :src"currentFrame" alt"加载中" /> 加载图片动态更改src的值使用 requestAnimationFrame 定时更新在需要的页面调用封装的组件 <LoadToast v-if"showLoading" /> 封装组件 <template><div class"…

CTF CRYPTO 密码学-1

题目名称&#xff1a;enc 题目描述&#xff1a; 压缩包中含两个文件&#xff1a;一个秘钥d.dec&#xff0c;一个密文flag.enc 解题过程&#xff1a; Step1&#xff1a;这题是一个解密他题目&#xff0c;尝试openssl去ras解密 工具简介 在Kali Linux系统中&#xff0c;openss…

React16源码: React中的异步调度scheduler模块的源码实现

React Scheduler 1 ) 概述 react当中的异步调度&#xff0c;称为 React Scheduler发布成单独的一个 npm 包就叫做 scheduler这个包它做了什么&#xff1f; A. 首先它维护时间片B. 然后模拟 requestIdleCallback 这个API 因为现在浏览器的支持不是特别的多所以在浏览当中只是去…

【计算机图形学】习题课:Viewing

【计算机图形学】Viewing 部分问题与解答 CS100433 Computer Graphics Assignment 21 Proof the composed transformations defined in global coordinate frame is equivalent to the composed transformations defined in local coordinate frame but in different composing…

1月14-17日为技术测试期!字节与腾讯上演“大和解”,抖音全面开放《王者荣耀》直播

综合整理&#xff5c;TesterHome社区 来源&#xff5c;《王者荣耀》官方、界面新闻 北京商报、IT之家 1月13日&#xff0c;腾讯游戏《王者荣耀》官方微博发布消息宣布&#xff0c;从1月21日起&#xff0c;《王者荣耀》抖音直播将全面开放。 为了筛查开播期间可能遇到的所有技…

几何_直线方程 Ax + By + C = 0 的系数A,B,C几何含义是?

参考&#xff1a; 直线方程 Ax By C 0 的系数A&#xff0c;B&#xff0c;C有什么几何含义&#xff1f;_设直线 l 的方程为axbyc0 怎么理解-CSDN博客 1. A B的含义&#xff1a;组成一个与直线垂直的向量 我们先来看A和B有什么含义。 在直线上取任意两点 P1:&#xff08;x1…

OceanBase集群部署

我认为学习一个中间件比较好的方式是&#xff0c;先了解它的架构和运行原理&#xff0c;然后动手部署一遍&#xff0c;加深对它的了解&#xff0c;再使用它&#xff0c;最后进行总结和分享 本篇介绍OceanBase部署前提配置和集群部署 1.使用开源免费的社区版&#xff0c;企业版…

[Android] Android架构体系(1)

文章目录 Android 的框架Dalvik 虚拟机JNI原生二进制可执行文件Android NDK中的binutils Bionic谷歌考虑到的版权问题Bionic与传统的C标准库&#xff08;如glibc&#xff09;的一些不同 参考 Android 的框架 Android 取得成功的关键因素之一就是它丰富的框架集。 没有这些框架…

架构08- 理解架构的模式2-管理和监控

大使模式&#xff1a;构建一个辅助服务&#xff0c;代表消费者使用服务或应用程序发送网络请求。 进程外的代理服务&#xff08;之前介绍中间件的时候也提到了&#xff0c;很多框架层面的事情可以以软件框架的形式寄宿在进程内&#xff0c;也可以以独立的代理形式做一个网络中…

AI绘图制作红包封面教程

注意&#xff1a;有不懂的话可加入QQ群聊一起交流&#xff1a;901944946欢迎大家关注微信公众号【程序猿代码之路】&#xff0c;每天都会不定时的发送一些红包封面&#xff01;&#xff01; 2024的春节即将到来&#xff0c;而在这春节到来之前&#xff0c;就有一个非常爆火的小…

黑马程序员 Java设计模式学习笔记(一)

目录 一、设计模式概述 1.1、23种设计模式有哪些&#xff1f; 1.2、软件设计模式的概念 1.3、学习设计模式的必要性 1.4、设计模式分类 二、UML图 2.1、类图概述 2.2、类图的作用 2.3、类图表示法 类的表示方式 类与类之间关系的表示方式 关联关系 聚合关系 组合…

陀螺仪LSM6DSV16X与AI集成(6)----检测自由落体

陀螺仪LSM6DSV16X与AI集成.6--检测自由落体 概述视频教学样品申请源码下载生成STM32CUBEMX串口配置IIC配置CS和SA0设置串口重定向参考程序初始换管脚获取ID复位操作BDU设置 概述 本文介绍如何初始化传感器并配置其参数&#xff0c;以便在检测到自由落体事件时发送通知。 最近…

显示报错: nmap.nmap.PortScannerError: ‘nmap program was not found in path‘

解决方案&#xff1a; 《关于想在Pycharm下使用nmap然后报错nmap.nmap.PortScannerError: ‘nmap program was not found in path.然后解决的那些事》 文章中进行了详尽的描述&#xff0c;总结一下就是下载一个nmap.exe&#xff0c;然后在nmap.py中引入nmap.exe所在的路径&…

RabbitMQ常见问题之消息堆积

文章目录 一、介绍二、使用惰性队列1. 基于Bean2. 基于RabbitListener 一、介绍 当生产者发送消息的速度超过了消费者处理消息的速度,就会导致队列中的消息堆积,直到队列存储消息达到上限。最 早接收到的消息&#xff0c;可能就会成为死信&#xff0c;会被丢弃&#xff0c;这就…

Pod控制器:

Pod控制器&#xff1a; Pv pvc 动态PV Pod控制器&#xff1a;工作负载。WordLoad&#xff0c;用于管理pod的中间层 &#xff0c;确保pod资源符合预期的状态 预期状态&#xff1a; 副本数容器的重启策略镜像的拉取策略 Pod出现故障时的重启等等 Pod控制器的类型&#xff1a…

【大数据】Flink 详解(八):SQL 篇 Ⅰ

《Flink 详解》系列&#xff08;已完结&#xff09;&#xff0c;共包含以下 10 10 10 篇文章&#xff1a; 【大数据】Flink 详解&#xff08;一&#xff09;&#xff1a;基础篇【大数据】Flink 详解&#xff08;二&#xff09;&#xff1a;核心篇 Ⅰ【大数据】Flink 详解&…