长短期记忆(LSTM)与RNN的比较:突破性的序列训练技术

长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

Why

LSTM提出的动机是为了解决「长期依赖问题」

长期依赖(Long Term Dependencies)

在深度学习领域中(尤其是RNN),“长期依赖“问题是普遍存在的。长期依赖产生的原因是当神经网络的节点经过许多阶段的计算后,之前比较长的时间片的特征已经被覆盖,例如下面例子

eg1: The cat, which already ate a bunch of food, was full.
      |   |     |      |     |  |   |   |   |     |   |
     t0  t1    t2      t3    t4 t5  t6  t7  t8    t9 t10
eg2: The cats, which already ate a bunch of food, were full.
      |   |      |      |     |  |   |   |   |     |    |
     t0  t1     t2     t3    t4 t5  t6  t7  t8    t9   t10

我们想预测'full'之前系动词的单复数情况,显然full是取决于第二个单词’cat‘的单复数情况,而非其前面的单词food。根据RNN的结构,随着数据时间片的增加,RNN丧失了学习连接如此远的信息的能力。

alt

LSTM vs. RNN

alt

相比RNN只有一个传递状态 ,LSTM有两个传输状态,一个 (cell state),和一个 (hidden state)。(Tips:RNN中的 对于LSTM中的

其中对于传递下去的 改变得很慢,通常输出的 是上一个状态传过来的 加上一些数值。

则在不同节点下往往会有很大的区别。

Model 详解

状态计算

首先使用LSTM的当前输入 和上一个状态传递下来的 拼接训练得到四个状态。

alt

其中, 是由拼接向量乘以权重矩阵之后,再通过一个 激活函数转换成0到1之间的数值,来作为一种门控状态。而 则是将结果通过一个 激活函数将转换成-1到1之间的值(这里使用 是因为这里是将其做为输入数据,而不是门控信号)。

计算过程

alt

⊙ 是Hadamard Product,也就是操作矩阵中对应的元素相乘,因此要求两个相乘矩阵是同型的。 ⊕ 则代表进行矩阵加法。

LSTM内部主要有三个阶段:

  1. 「忘记阶段」。这个阶段主要是对上一个节点传进来的输入进行 「选择性」忘记。简单来说就是会 “忘记不重要的,记住重要的”。

具体来说是通过计算得到的 (f表示forget)来作为忘记门控,来控制上一个状态的 哪些需要留哪些需要忘。

  1. 「选择记忆阶段」。这个阶段将这个阶段的输入有选择性地进行“记忆”。主要是会对输入 进行选择记忆。哪些重要则着重记录下来,哪些不重要,则少记一些。当前的输入内容由前面计算得到的 表示。而选择的门控信号则是由 (i代表information)来进行控制。

将上面两步得到的结果相加,即可得到传输给下一个状态的 。也就是上图中的第一个公式。

  1. 「输出阶段」。这个阶段将决定哪些将会被当成当前状态的输出。主要是通过 来进行控制的。并且还对上一阶段得到的 进行了放缩(通过一个tanh激活函数进行变化)。

与普通RNN类似,输出 往往最终也是通过 变化得到。

Code

现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 3235
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
  • 初始化模型参数

定义和初始化模型参数。 如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
  • 定义模型
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)
  • 训练和预测
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 5001
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

# perplexity 1.3, 17736.0 tokens/sec on cuda:0
# time traveller for so it will leong go it we melenot ir cove i s
# traveller care be can so i ngrecpely as along the time dime
alt

总结

  • 长短期记忆网络有三种类型的门:输入门、遗忘门和输出门。
  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。
  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

Ref

  1. https://zhuanlan.zhihu.com/p/32085405
  2. https://zhuanlan.zhihu.com/p/42717426
  3. https://zh.d2l.ai/chapter_recurrent-modern/lstm.html

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/147059.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C# using语句使用介绍

在C#中,using语句有两种主要用途:一是引入命名空间,二是提供一种简便的方式来处理资源的清理(主要用于实现了 IDisposable 接口的对象)。 引入命名空间:using 语句用于引入命名空间,从而可以在代…

django理解02 前后端分离中的问题

前后端分离相对于传统方式的问题 前后端数据交换的问题跨域问题 页面js往自身程序(django服务)发送请求,这是浏览器默认接受响应 而请求其它地方是浏览器认为存在潜在危险。自动隔离请求!!! 跨域问题的解决…

springcloud整合nacos实现服务注册

Nacos是一个开源的分布式系统服务和基础设施解决方案,用于实现动态服务发现、配置管理和服务治理。它可以帮助开发人员和运维团队更好地管理微服务架构中的服务实例、配置信息和服务调用。 Nacos提供了服务注册与发现、动态配置管理、服务路由和负载均衡等功能&…

C++之set/multise容器

C之set/multise容器 set基本概念 set构造和赋值 #include <iostream> #include<set> using namespace std;void PrintfSet(set<int>&s) {for(set<int>::iterator it s.begin();it ! s.end();it){cout<<*it<<" ";}cout&l…

typora使用PicGo自动上传图片到chevereto图床

typora使用PicGo自动上传图片到chevereto图床 近期发现&#xff0c;gitee图床不能用了。github又涉及科学上网。搜索了开源图床方案&#xff0c;找到了chevereto&#xff0c;使用起来还不错。分享给大家。 文章目录 typora使用PicGo自动上传图片到chevereto图床chevereto图床安…

精密云工程:智能激活业务速率 ——华为云11.11联合大促倒计时 仅剩3日

现新客3.96元起&#xff0c;下单有机会抽HUAWEI P60 Art&#xff0c;福利仅限双十一&#xff0c;机会唾手可得&#xff0c;立即行动&#xff01; 双十一购物节来临倒计时&#xff0c;华为云备上多款增值产品&#xff0c;以最优品质迸发冬日技术热浪&#xff0c;满足行业技术应用…

Mac 安装 protobuf 和Android Studio 使用

1. 安装,执行命令 brew install protoc 2. Mac 错误提示&#xff1a;zsh: command not found: brew解决方法 解决方法&#xff1a;mac 安装homebrew&#xff0c; 用以下命令安装&#xff0c;序列号选择中科大&#xff08;1&#xff09;或 阿里云 /bin/zsh -c "$(curl…

VB.net webbrowser 自定义下载接口实现

使用《VB.net webbrowser 如何实现自定义下载 IDownloadManager》中的控件ExtendedWebBrowser&#xff08;下载控件&#xff09;&#xff0c;并扩展了NewWindow2。 使用ExtendedWebBrowser_1过程中&#xff0c;遇到很多问题&#xff0c;花了几天时间&#xff0c;终于解决了所有…

MLC-LLM 支持RWKV-5推理以及对RWKV-5的一些思考

自从2023年3月左右&#xff0c;chatgpt火热起来之后&#xff0c;我把关注的一些知乎帖子都记录到了这个markdown里面&#xff0c;&#xff1a;https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/large-language-model-note &#xff0c;从2023年3月左右到现…

漏洞利用工具的编写

预计更新网络扫描工具的编写漏洞扫描工具的编写Web渗透测试工具的编写密码破解工具的编写漏洞利用工具的编写拒绝服务攻击工具的编写密码保护工具的编写情报收集工具的编写 漏洞利用工具是一种常见的安全工具&#xff0c;它可以利用系统或应用程序中的漏洞来获取系统权限或者窃…

一起Talk Android吧(第五百五十二回:Retrofit的基本用法)

文章目录 1. 概念介绍2. 使用方法2.1 创建请求接口2.2 创建retrofit对象2.3 创建请求接口的对象2.4 发起请求3. 内容总结各位看官们大家好,上一回中咱们说的例子是"如何自定义SplashScreen",本章回中介绍的例子是" Retrofit的基本用法"。闲话休提,言归正…

SELinux零知识学习十三、SELinux策略语言之客体类别和许可(7)

接前一篇文章&#xff1a;SELinux零知识学习十二、SELinux策略语言之客体类别和许可&#xff08;6&#xff09; 一、SELinux策略语言之客体类别和许可 4. 客体类别许可实例 为了更好地理解许可是如何控制对系统资源的访问的&#xff0c;下面进一步讨论以下两个客体类别和许可…

python项目源码基于django的宿舍管理系统dormitory+mysql数据库文件

基于Django的宿舍管理系统 运行效果 个人亲自制作python项目源码基于django的宿舍管理系统dormitorymysql数据库文件 1. 介绍 宿舍管理系统是一个基于Django框架开发的项目&#xff0c;旨在简化和优化宿舍管理的流程。该系统包括学生和管理员两个角色&#xff0c;学生可以通过…

安装插件时Vscode XHR Failed 报错ERR_CERT_AUTHORITY_INVALID

安装插件时Vscode XHR Failed 报错ERR_CERT_AUTHORITY_INVALID 今天用vscode 安装python插件时报XHR failed,无法拉取应用商城的数据&#xff0c; 报的错如下&#xff1a; ERR_CERT_AUTHORITY_INVALID 翻译过来就是证书有问题 找错误代码的方法&#xff1a; 打开vscode, 按F1…

Swift 如何打造兼容新老系统的字符串分割(split)方法

0. 概览 在 Swift 的开发中&#xff0c;我们经常要与字符串打交道。其中一个常见的操作就是用特定的“分隔符”来分割字符串&#xff0c;这里分隔符可能不仅仅是字符&#xff0c;而是多字符组成的字符串。 从 iOS 16 开始&#xff0c; 新增了对应的方法来专注此事。不过&am…

HBase中的数据表是如何用CHAT进行分区的?

问CHA&#xff1a;HBase中的数据表是如何进行分区的&#xff1f; CHAT回复&#xff1a; 在HBase中&#xff0c;数据表是水平分区的。每一个分区被称为一个region。当一个region达到给定的大小限制时&#xff0c;它会被分裂成两个新的region。 因此&#xff0c;随着数据量的增…

【华为OD题库-026】通过软盘拷贝文件-java

题目 有一名科学家想要从一台古董电脑中拷贝文件到自己的电脑中加以研究。但此电脑除了有一个3.5寸软盘驱动器以外&#xff0c;没有任何手段可以将文件拷贝出来&#xff0c;而且只有一张软盘可以使用。因此这一张软盘是唯一可以用来拷贝文件的载体。科学家想要尽可能多地将计算…

WPF拖拽相关的类

WPF的VisualTreeHelper类是一组静态方法&#xff0c;主要用于在WPF的VisualTree&#xff08;可视化树&#xff09;中进行遍历和查找操作。VisualTreeHelper类提供的方法可以帮助开发人员轻松地访问和操作VisualTree中的元素。 以下是VisualTreeHelper类的一些主要功能&#xf…

mac苹果笔记本应用程序在哪?有什么快捷方式吗?

苹果笔记本电脑一直以来都被广泛使用&#xff0c;而苹果的操作系统 macOS 也非常受欢迎。一台好的笔记本电脑不仅仅依赖于硬件配置&#xff0c;还需要丰富多样的应用程序来满足用户的需求。苹果笔记本应用程序在哪&#xff0c;不少mac新手用户会有这个疑问。在这篇文章中&#…

Golang抓包:实现网络数据包捕获与分析

介绍 在网络通信中&#xff0c;网络数据包是信息传递的基本单位。抓包是一种监控和分析网络流量的方法&#xff0c;用于获取网络数据包并对其进行分析。在Golang中&#xff0c;我们可以借助现有的库来实现抓包功能&#xff0c;进一步对网络数据进行分析和处理。 本文将介绍如…