长短期记忆(LSTM)与RNN的比较:突破性的序列训练技术

长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

Why

LSTM提出的动机是为了解决「长期依赖问题」

长期依赖(Long Term Dependencies)

在深度学习领域中(尤其是RNN),“长期依赖“问题是普遍存在的。长期依赖产生的原因是当神经网络的节点经过许多阶段的计算后,之前比较长的时间片的特征已经被覆盖,例如下面例子

eg1: The cat, which already ate a bunch of food, was full.
      |   |     |      |     |  |   |   |   |     |   |
     t0  t1    t2      t3    t4 t5  t6  t7  t8    t9 t10
eg2: The cats, which already ate a bunch of food, were full.
      |   |      |      |     |  |   |   |   |     |    |
     t0  t1     t2     t3    t4 t5  t6  t7  t8    t9   t10

我们想预测'full'之前系动词的单复数情况,显然full是取决于第二个单词’cat‘的单复数情况,而非其前面的单词food。根据RNN的结构,随着数据时间片的增加,RNN丧失了学习连接如此远的信息的能力。

alt

LSTM vs. RNN

alt

相比RNN只有一个传递状态 ,LSTM有两个传输状态,一个 (cell state),和一个 (hidden state)。(Tips:RNN中的 对于LSTM中的

其中对于传递下去的 改变得很慢,通常输出的 是上一个状态传过来的 加上一些数值。

则在不同节点下往往会有很大的区别。

Model 详解

状态计算

首先使用LSTM的当前输入 和上一个状态传递下来的 拼接训练得到四个状态。

alt

其中, 是由拼接向量乘以权重矩阵之后,再通过一个 激活函数转换成0到1之间的数值,来作为一种门控状态。而 则是将结果通过一个 激活函数将转换成-1到1之间的值(这里使用 是因为这里是将其做为输入数据,而不是门控信号)。

计算过程

alt

⊙ 是Hadamard Product,也就是操作矩阵中对应的元素相乘,因此要求两个相乘矩阵是同型的。 ⊕ 则代表进行矩阵加法。

LSTM内部主要有三个阶段:

  1. 「忘记阶段」。这个阶段主要是对上一个节点传进来的输入进行 「选择性」忘记。简单来说就是会 “忘记不重要的,记住重要的”。

具体来说是通过计算得到的 (f表示forget)来作为忘记门控,来控制上一个状态的 哪些需要留哪些需要忘。

  1. 「选择记忆阶段」。这个阶段将这个阶段的输入有选择性地进行“记忆”。主要是会对输入 进行选择记忆。哪些重要则着重记录下来,哪些不重要,则少记一些。当前的输入内容由前面计算得到的 表示。而选择的门控信号则是由 (i代表information)来进行控制。

将上面两步得到的结果相加,即可得到传输给下一个状态的 。也就是上图中的第一个公式。

  1. 「输出阶段」。这个阶段将决定哪些将会被当成当前状态的输出。主要是通过 来进行控制的。并且还对上一阶段得到的 进行了放缩(通过一个tanh激活函数进行变化)。

与普通RNN类似,输出 往往最终也是通过 变化得到。

Code

现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 3235
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
  • 初始化模型参数

定义和初始化模型参数。 如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
  • 定义模型
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)
  • 训练和预测
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 5001
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

# perplexity 1.3, 17736.0 tokens/sec on cuda:0
# time traveller for so it will leong go it we melenot ir cove i s
# traveller care be can so i ngrecpely as along the time dime
alt

总结

  • 长短期记忆网络有三种类型的门:输入门、遗忘门和输出门。
  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。
  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

Ref

  1. https://zhuanlan.zhihu.com/p/32085405
  2. https://zhuanlan.zhihu.com/p/42717426
  3. https://zh.d2l.ai/chapter_recurrent-modern/lstm.html

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/147059.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

django理解02 前后端分离中的问题

前后端分离相对于传统方式的问题 前后端数据交换的问题跨域问题 页面js往自身程序(django服务)发送请求,这是浏览器默认接受响应 而请求其它地方是浏览器认为存在潜在危险。自动隔离请求!!! 跨域问题的解决…

springcloud整合nacos实现服务注册

Nacos是一个开源的分布式系统服务和基础设施解决方案,用于实现动态服务发现、配置管理和服务治理。它可以帮助开发人员和运维团队更好地管理微服务架构中的服务实例、配置信息和服务调用。 Nacos提供了服务注册与发现、动态配置管理、服务路由和负载均衡等功能&…

C++之set/multise容器

C之set/multise容器 set基本概念 set构造和赋值 #include <iostream> #include<set> using namespace std;void PrintfSet(set<int>&s) {for(set<int>::iterator it s.begin();it ! s.end();it){cout<<*it<<" ";}cout&l…

typora使用PicGo自动上传图片到chevereto图床

typora使用PicGo自动上传图片到chevereto图床 近期发现&#xff0c;gitee图床不能用了。github又涉及科学上网。搜索了开源图床方案&#xff0c;找到了chevereto&#xff0c;使用起来还不错。分享给大家。 文章目录 typora使用PicGo自动上传图片到chevereto图床chevereto图床安…

精密云工程:智能激活业务速率 ——华为云11.11联合大促倒计时 仅剩3日

现新客3.96元起&#xff0c;下单有机会抽HUAWEI P60 Art&#xff0c;福利仅限双十一&#xff0c;机会唾手可得&#xff0c;立即行动&#xff01; 双十一购物节来临倒计时&#xff0c;华为云备上多款增值产品&#xff0c;以最优品质迸发冬日技术热浪&#xff0c;满足行业技术应用…

Mac 安装 protobuf 和Android Studio 使用

1. 安装,执行命令 brew install protoc 2. Mac 错误提示&#xff1a;zsh: command not found: brew解决方法 解决方法&#xff1a;mac 安装homebrew&#xff0c; 用以下命令安装&#xff0c;序列号选择中科大&#xff08;1&#xff09;或 阿里云 /bin/zsh -c "$(curl…

MLC-LLM 支持RWKV-5推理以及对RWKV-5的一些思考

自从2023年3月左右&#xff0c;chatgpt火热起来之后&#xff0c;我把关注的一些知乎帖子都记录到了这个markdown里面&#xff0c;&#xff1a;https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/large-language-model-note &#xff0c;从2023年3月左右到现…

安装插件时Vscode XHR Failed 报错ERR_CERT_AUTHORITY_INVALID

安装插件时Vscode XHR Failed 报错ERR_CERT_AUTHORITY_INVALID 今天用vscode 安装python插件时报XHR failed,无法拉取应用商城的数据&#xff0c; 报的错如下&#xff1a; ERR_CERT_AUTHORITY_INVALID 翻译过来就是证书有问题 找错误代码的方法&#xff1a; 打开vscode, 按F1…

Swift 如何打造兼容新老系统的字符串分割(split)方法

0. 概览 在 Swift 的开发中&#xff0c;我们经常要与字符串打交道。其中一个常见的操作就是用特定的“分隔符”来分割字符串&#xff0c;这里分隔符可能不仅仅是字符&#xff0c;而是多字符组成的字符串。 从 iOS 16 开始&#xff0c; 新增了对应的方法来专注此事。不过&am…

HBase中的数据表是如何用CHAT进行分区的?

问CHA&#xff1a;HBase中的数据表是如何进行分区的&#xff1f; CHAT回复&#xff1a; 在HBase中&#xff0c;数据表是水平分区的。每一个分区被称为一个region。当一个region达到给定的大小限制时&#xff0c;它会被分裂成两个新的region。 因此&#xff0c;随着数据量的增…

mac苹果笔记本应用程序在哪?有什么快捷方式吗?

苹果笔记本电脑一直以来都被广泛使用&#xff0c;而苹果的操作系统 macOS 也非常受欢迎。一台好的笔记本电脑不仅仅依赖于硬件配置&#xff0c;还需要丰富多样的应用程序来满足用户的需求。苹果笔记本应用程序在哪&#xff0c;不少mac新手用户会有这个疑问。在这篇文章中&#…

2023.11.14 hivesql的容器,数组与映射

目录 https://blog.csdn.net/m0_49956154/article/details/134365327?spm1001.2014.3001.5501https://blog.csdn.net/m0_49956154/article/details/134365327?spm1001.2014.3001.5501 8.hive的复杂类型 9.array类型: 又叫数组类型,存储同类型的单数据的集合 10.struct类型…

Selenium操作已经打开的Chrome浏览器窗口

Selenium操作已经打开的Chrome浏览器窗口 0. 背景 在使用之前的代码通过selenium操作Chrome浏览器时&#xff0c;每次都要新打开一个窗口&#xff0c;觉得麻烦&#xff0c;所以尝试使用 Selenium 获取已经打开的浏览器窗口&#xff0c;在此记录下过程 本文使用 chrome浏览器来…

场景交互与场景漫游-osgGA库(5)

osgGA库 osgGA库是OSG的一个附加的工具库&#xff0c;它为用户提供各种事件处理及操作处理。通过osgGA库读者可以像控制Windows窗口一样来处理各种事件 osgGA的事件处理器主要由两大部分组成&#xff0c;即事件适配器和动作适配器。osgGA:GUIEventHandler类主要提供了窗口系统的…

系列九、对象的生命周期和GC

一、堆细分 Java堆从GC的角度还可以细分为&#xff1a;新生代&#xff08;eden【伊甸园区】、from【幸存者0区】、to【幸存者1区】&#xff09;和老年代。 二、MinorGC的过程 复制>清空》交换 1、eden、from区中的对象复制到to区&#xff0c;年龄1 首先&#xff0c;当eden区…

我认为除了HelloWorld之外,Python的三大数据转换实例可以作为开始学习Python的入门语言。

Python的三大数据转换实例 一、反转三位数 class Solution:def funtcion(self,number):hint(number/100)tint(number%100/10)zint(number%10)return 100*z10*th if __name____main__:solution Solution()num123new_num solution.funtcion(num)print("输入:{}".fo…

【仿真动画】ABB IRB 8700 机器人搬运(ruckig在线轨迹生成)动画欣赏

场景 动画 一、IRB 8700简介 二、动画脚本重点分析 2.1 sim.moveToPose 通过在两个 poses 之间执行插值&#xff0c;使用 Ruckig 在线轨迹生成器生成对象运动数据。该函数可以通过处理 4 个运动变量&#xff08;x、y、z 和两个姿势之间的角度&#xff09;或单个运动变量&#…

深度学习数据集—细胞、微生物、显微图像数据集大合集

最近收集了一大波关于细胞、微生物、显微图像数据集&#xff0c;有细胞、微生物&#xff0c;细菌等。 接下来是每个数据的详细介绍&#xff01;&#xff01; 1、12500张血细胞增强图像&#xff08;JPEG&#xff09;数据集 该数据集包含12500张血细胞增强图像&#xff08;JPE…

vscode终端npm install报错

报错如下&#xff1a; npm WARN read-shrinkwrap This version of npm is compatible with lockfileVersion1, but package-lock.json was generated for lockfileVersion2. Ill try to do my best with it! npm ERR! code EPERM npm ERR! syscall open npm ERR! errno -4048…