【深度学习】LSTM、BiLSTM详解

文章目录

  • 1. LSTM简介:
  • 2. LSTM结构图:
  • 3. 单层LSTM详解
  • 4. 双层LSTM详解
  • 5. BiLSTM
  • 6. Pytorch实现LSTM示例
  • 7. nn.LSTM参数详解

1. LSTM简介:

    LSTM是一种循环神经网络,它可以处理和预测时间序列中间隔和延迟相对较长的重要事件。LSTM通过使用门控单元来控制信息的流动,从而缓解RNN中的梯度消失和梯度爆炸的问题。LSTM的核心是三个门:输入门遗忘门输出门

遗忘门: 遗忘门的作用是决定哪些信息从记忆单元中遗忘,它使用sigmoid激活函数,可以输出在0到1之间的值,可以理解为保留信息的比例。
输入门: 作用是决定哪些新信息被存储在记忆单元中
输出门: 输出门决定了下一个隐藏状态,即生成当前时间步的输出并传递到下一时间步
记忆单元:负责长期信息的存储,通过遗忘门和输入门的相互作用,记忆单元能够学习如何选择性地记住或忘记信息

2. LSTM结构图:

在这里插入图片描述

涉及到的计算公式如下:
在这里插入图片描述

3. 单层LSTM详解

(1)设定有3个字的序列【“早”“上”“好”】要经过LSTM处理,每个序列由20个元素组成的列向量构成,所以input size就为20。

(2)设定全连接层中有100个隐藏单元,LSTM的层数为1。

(3)因为是3个字的序列,所以LSTM需要3个时间步(即会自循环3次)才能处理完这个序列。

(4)nn.LSTM()每层也可以拆开写,这样每层的隐藏单元个数就可以分别设定。

在这里插入图片描述
    LSTM单元包含三个输入参数x、c、h;首先t1时刻作为第一个时间步,输入到第一个LSTM单元中,此时输入的初始从c(0)和h(0)都是0矩阵,计算完成后,第一个LSTM单元输出一组h(t1)\c(t1),作为本层LSTM的第二个时间步的输入参数;因此第二个时间步的输入就是h(t1),c(t1),x(t2),而输出是h(t2),c(t2);因此第三个时间步的输入就是h(t2),c(t2),x(t3),而输出是h(t3),c(t3)。

4. 双层LSTM详解

(1)设定有3个字的序列【“早”“上”“好”】要经过LSTM处理,每个序列由20个元素组成的列向量构成,所以input size就为20。

(2)设定全连接层中有100个隐藏单元,LSTM的层数为2。

(3)因为是3个字的序列,所以LSTM需要3个时间步(即会自循环3次)才能处理完这个序列。

(4)nn.LSTM()每层也可以拆开写,这样每层的隐藏单元个数就可以分别设定。

在这里插入图片描述
    第二层LSTM没有输入参数x(t1)、x(t2)、x(t3);所以我们将第一层LSTM输出的h(t1)、h(t2)、h(t3)作为第二层LSTM的输入x(t1)、x(t2)、x(t3)。第一个时间步输入的初始c(0)和h(0)都为0矩阵,计算完成后,第一个时间步输出新的一组h(t1)、c(t1),作为本层LSTM的第二个时间步的输入参数;因此第二个时间步的输入就是h(t1),c(t1),x(t2),而输出是h(t2),c(t2);因此第三个时间步的输入就是h(t2),c(t2),x(t3),而输出是h(t3),c(t3)。

5. BiLSTM

单层的BiLSTM其实就是2个LSTM,一个正向去处理序列,一个反向去处理序列,处理完后,两个LSTM的输出会拼接起来。
在这里插入图片描述

6. Pytorch实现LSTM示例

import torch 
import torch.nn as nndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dim  # 隐藏层维度self.num_layers = num_layers  # LSTM层的数量# LSTM网络层self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)# 全连接层,用于将LSTM的输出转换为最终的输出维度self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):# 初始化隐藏状态和细胞状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)# 前向传播LSTM,返回输出和最新的隐藏状态与细胞状态out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))# 将LSTM的最后一个时间步的输出通过全连接层out = self.fc(out[:, -1, :])return out

7. nn.LSTM参数详解

pytorch官方定义:

CLASS torch.nn.LSTM(
        input_size,
        hidden_size,
        num_layers=1,
        bias=True,
        batch_first=False,
        dropout=0.0,
        bidirectional=False,
        proj_size=0,
        device=None,
        dtype=None
    )

input_size – 输入 x 中预期的特征数量
hidden_size – 隐藏状态 h 中的特征数量
num_layers – 循环层的数量。例如,设置 num_layers=2 表示将两个 LSTM 堆叠在一起形成一个 stacked LSTM,其中第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1
bias – 如果 False,则该层不使用偏差权重 b_ih 和 b_hh。默认值:True
batch_first – 如果 True,则输入和输出张量将以 (batch, seq, feature) 而不是 (seq, batch, feature) 的形式提供。请注意,这并不适用于隐藏状态或单元状态。有关详细信息,请参见下面的输入/输出部分。默认值:False
dropout – 如果非零,则在除最后一层之外的每个 LSTM 层的输出上引入一个 Dropout 层,其 dropout 概率等于 dropout。默认值:0
bidirectional – 如果 True,则变为双向 LSTM。默认值:False
proj_size – 如果 > 0,则将使用具有相应大小的投影的 LSTM。默认值:0

对于输入序列每一个元素,每一层都会进行以下计算:
在这里插入图片描述
网络输入:
在这里插入图片描述

网络输出:
在这里插入图片描述

本文参考:https://blog.csdn.net/qq_34486832/article/details/134898868
https://pytorch.ac.cn/docs/stable/generated/torch.nn.LSTM.html#

LSTM每层的输出都要经过全连接层吗,还是直接对隐藏层进行输出?
通过在代码中对lstm的输出进行print输出:

import torch 
import torch.nn as nndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dim  # 隐藏层维度self.num_layers = num_layers  # LSTM层的数量# LSTM网络层self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)# 全连接层,用于将LSTM的输出转换为最终的输出维度self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):# 初始化隐藏状态和细胞状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)# 前向传播LSTM,返回输出和最新的隐藏状态与细胞状态out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))print(out)print(hn)print(cn)# 将LSTM的最后一个时间步的输出通过全连接层out = self.fc(out[:, -1, :])return out
if __name__ == "__main__":input_dim = 3        # 输入特征的维度hidden_dim = 4       # 隐藏层的维度num_layers = 1       # LSTM 层的数量output_dim = 1       # 输出特征的维度lstm = LSTM(input_dim, hidden_dim, num_layers, output_dim).to(device)batch_size = 1seq_length = 10input_tensor = torch.randn(batch_size, seq_length, input_dim).to(device)output = lstm(input_tensor)

通过对LSTM网络的输出我们可以看到,out的最后一层与最后一层隐藏层hn一致,说明并未经过全连接层,而是直接输出隐藏层
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/60778.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode hot100【LeetCode 236.二叉树的最近公共祖先】java实现

LeetCode 236.二叉树的最近公共祖先 题目描述 给定一个二叉树, 找到该树中两个给定节点的最近公共祖先。 节点可以表示为它在树中的路径,其中路径的第一个节点是根节点,每个后续节点是其父节点的直接子节点。 示例 1: 输入: root [3,5,1,6,2,0,8,n…

【云原生系列--Longhorn的部署】

Longhorn部署手册 1.部署longhorn longhorn架构图: 1.1部署环境要求 kubernetes版本要大于v1.21 每个节点都必须装open-iscsi ,Longhorn依赖于 iscsiadm主机为 Kubernetes 提供持久卷。 apt-get install -y open-iscsiRWX 支持要求每个节点都安装 N…

跟李笑来学美式俚语(Most Common American Idioms): Part 02

Most Common American Idioms: Part 02 前言 本文是学习李笑来的Most Common American Idioms这本书的学习笔记,自用。 Github仓库链接:https://github.com/xiaolai/most-common-american-idioms 使用方法: 直接下载下来(或者clone到本地…

Molecular signatures database (MSigDB) 3.0

文献翻译和解读(解读在最后) 摘要 动机: 良好注释的基因集代表了生物学过程的全貌,对于大规模基因组数据的有意义和深入解读至关重要。分子特征数据库(MSigDB)是最广泛使用的此类基因集资源库之一。 结果…

【Hadoop】【hdfs】【大数据技术基础】实验三 HDFS 基础编程实验

实验三: HDFS Java API编程实践 实验题目 HDFS Java API编程实践 实验目的 熟悉HDFS操作常用的Java API。 实验平台 操作系统:Linux Hadoop版本:2.6.0或以上版本 JDK版本:1.6或以上版本 Java IDE:Eclipse 实验…

Flowable 构建后端服务(后端以及数据库搭建) Flowable Modeler 设计器搭建(前端)

案例地址&#xff1a;xupengboo-flowable-example Flowable 构建后端服务&#xff08;后端以及数据库搭建&#xff09; 以 Spring Boot 项目为例&#xff1a; 引入 Flowable 必要依赖。 <!-- flowable 依赖 --> <dependency><groupId>org.flowable</gr…

2022 年 9 月青少年软编等考 C 语言二级真题解析

目录 T1. 统计误差范围内的数思路分析 T2. 全在其中思路分析 T3. Lab 杯思路分析 T4. 有趣的跳跃思路分析 T5. 反反复复思路分析 T1. 统计误差范围内的数 统计一个整数序列中与指定数字 m m m 误差范围小于等于 x x x 的数的个数。 时间限制&#xff1a;1 s 内存限制&#…

ssm114基于SSM框架的网上拍卖系统的设计与实现+vue(论文+源码)_kaic

摘 要 随着科学技术的飞速发展&#xff0c;各行各业都在努力与现代先进技术接轨&#xff0c;通过科技手段提高自身的优势&#xff0c;商品拍卖当然也不能排除在外&#xff0c;随着商品拍卖管理的不断成熟&#xff0c;它彻底改变了过去传统的经营管理方式&#xff0c;不仅使商品…

智慧农业的前世今生

智慧农业是将现代信息技术与传统农业相结合的新型农业生产方式&#xff0c;其发展历程如下&#xff1a; 20世纪70年代末&#xff0c;以美国为代表的欧美国家率先开始农业信息化、智能化的应用研究&#xff0c;以农业专家系统为代表的农业信息化应用开始在农业生产领域萌芽。我…

BERT模型核心组件详解及其实现

摘要 BERT&#xff08;Bidirectional Encoder Representations from Transformers&#xff09;是一种基于Transformer架构的预训练模型&#xff0c;在自然语言处理领域取得了显著的成果。本文详细介绍了BERT模型中的几个关键组件及其实现&#xff0c;包括激活函数、变量初始化…

Transformer中的算子:其中Q,K,V就是算子

目录 Transformer中的算子 其中Q,K,V就是算子 一、数学中的算子 二、计算机科学中的算子 三、深度学习中的算子 四、称呼的由来 Transformer中的算子 其中Q,K,V就是算子 “算子”这一称呼源于其在数学、计算机科学以及深度学习等多个领域中的广泛应用和特定功能。以下是…

ElementPlus el-upload上传组件on-change只触发一次

ElementPlus el-upload上传组件on-change只触发一次 主要运用了:on-exceed方法 废话不多说&#xff0c;直接上代码 <el-uploadclass"avatar-uploader"action"":on-change"getFilesj":limit"1":auto-upload"false"accep…

厦大南洋理工最新开源,一种面向户外场景的特征-几何一致性无监督点云配准方法

导读 本文提出了INTEGER&#xff0c;一种面向户外点云数据的无监督配准方法&#xff0c;通过整合高层上下文和低层几何特征信息来生成更可靠的伪标签。该方法基于教师-学生框架&#xff0c;创新性地引入特征-几何一致性挖掘&#xff08;FGCM&#xff09;模块以提高伪标签的准确…

Conda环境与Ubuntu环境移植详解

Conda环境与Ubuntu环境移植详解 在计算机科学中&#xff0c;环境迁移是一项常见的任务&#xff0c;特别是对于使用Anaconda等工具进行数据科学和机器学习的开发人员。迁移环境不仅能够帮助开发者在不同设备间无缝切换&#xff0c;还能确保项目依赖的一致性&#xff0c;从而避免…

【深度学习基础】PyCharm anaconda PYTorch python CUDA cuDNN 环境配置

这里写目录标题 PyCharm 安装anaconda安装PYTorch安装确定python版本CUDA安装cuDNN安装检验环境是否配置成功参照:PyCharm 安装 官网下载 anaconda安装 官网下载 :https://www.anaconda.com/download 配置环境变量,增加 D:\WorkSoftware\Install\Anaconda3 D:\WorkSoftw…

生产环境中AI调用的优化:AI网关高价值应用实践

随着越来越多的组织将生成式AI引入生产环境&#xff0c;他们面临的挑战已经超出了初步实施的范畴。如果管理不当&#xff0c;扩展性限制、安全漏洞和性能瓶颈可能会阻碍AI应用的推广。实际问题如用户数据的安全性、固定容量限制、成本管理和延迟优化等&#xff0c;需要创新的解…

Redis 概 述 和 安 装

安 装 r e d i s: 1. 下 载 r e dis h t t p s : / / d o w n l o a d . r e d i s . i o / r e l e a s e s / 2. 将 redis 安装包拷贝到 /opt/ 目录 3. 解压 tar -zvxf redis-6.2.1.tar.gz 4. 安装gcc yum install gcc 5. 进入目录 cd redis-6.2.1 6. 编译 make …

SpringBoot 2.2.10 无法执行Test单元测试

很早之前的项目今天clone现在&#xff0c;想执行一个业务订单的检查&#xff0c;该检查的代码放在test单元测试中&#xff0c;启动也是好好的&#xff0c;当点击对应的方法执行Test的时候就报错 tip&#xff1a;已添加spring-boot-test-starter 所以本身就引入了junit5的库 No…

Dubbo 3.2 源码导读

Dubbo 是一个高性能的 Java RPC 框架&#xff0c;广泛用于构建分布式服务。Dubbo 3.2 版本引入了一些新的特性和改进&#xff0c;是一个值得深入研究的版本。以下是对 Dubbo 3.2 源码的导读&#xff0c;帮助你理解其架构和设计。 1. 源码获取 从 GitHub 上获取 Dubbo 3.2 的源…

[项目代码] YOLOv5 铁路工人安全帽安全背心识别 [目标检测]

YOLOv5是一种单阶段&#xff08;one-stage&#xff09;检测算法&#xff0c;它将目标检测问题转化为一个回归问题&#xff0c;能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法&#xff08;如Faster R-CNN&#xff09;&#xff0c;YOLOv5具有更高的…