【动手学深度学习-pytorch】9.2长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997)。 它有许多与门控循环单元( 9.1节)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年.

门控记忆元 cell

  • 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)
  • 为了控制记忆元,我们需要许多门。输入门 输出门 遗忘门
  • 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。 让我们看看这在实践中是如何运作的。

输入门、忘记门和输出门

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中, 如 图9.2.1所示。 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在
的范围内。
在这里插入图片描述
在这里插入图片描述

候选记忆元

在这里插入图片描述

记忆元

在这里插入图片描述

隐状态

在这里插入图片描述

只有隐状态会传递到输出层,而记忆元完全属于内部信息

从零开始实现

现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差
的高斯分布初始化权重,并将偏置项设为0.

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元不直接参与输出计算。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化 8.5节中 引入的RNNModelScratch类来训练一个长短期记忆网络, 就如我们在 9.1节中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

总结

  • 长短期记忆网络,包含三个门:输入门、忘记门和遗忘门。其中遗忘门用于重置单元的内容,通过专用的机制决定什么时候记忆或者忽略状态中的输入。

  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。

  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/780996.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTX Ventures:为什么BounceBit可能成为新的BTC生态解决方案?

随着BTC现货ETF的通过,全球各大机构和个人都在不断加码对BTC的持仓,BTC价格也随之上升,目前已上升至全球市值排名前十的资产。在本轮市场周期中,BTC铭文和BTC扩容是两个被市场高度关注的细分赛道。BTC生态资产的多元化收益探索正在…

Django详细教程(一)

文章目录 一、安装Django二、创建项目1.终端创建项目2.Pycharm创建项目(专业版才可以)3.默认文件介绍 三、创建app1.app介绍2.默认文件介绍 四、快速上手1.写一个网页步骤1:注册app 【settings.py】步骤2:编写URL和视图函数对应关…

基于微信小程序的校园服务平台设计与实现(程序+论文)

本文以校园服务平台为研究对象,首先分析了当前校园服务平台的研究现状,阐述了本系统设计的意义和背景,运用微信小程序开发工具和云开发技术,研究和设计了一个校园服务平台,以满足学生在校园生活中的多样化需求。通过引…

最优算法100例之13-输出第n个丑数

专栏主页:计算机专业基础知识总结(适用于期末复习考研刷题求职面试)系列文章https://blog.csdn.net/seeker1994/category_12585732.html 题目描述 把只包含因子2、3和5的数称作丑数(Ugly Number)。例如6、8都是丑数,但14不是,因为它包含因子7。 习惯上我们把1当…

目标检测评价标准

主要借鉴:https://github.com/rafaelpadilla/Object-Detection-Metrics?tabreadme-ov-file 主要评价指标、术语: Intersection Over Union (IOU):两个检测框交集面积与并集面积的比值 True Positive (TP):IOU大于阈值的检测框…

Elasticsearch入门及常用命令和Spring中的常用操作

入门 官网 简介 一个分布式的、Restful风格的搜索引擎。支持对各种类型的数据的检索。搜索速度快,可以提供实时的搜索服务。便于水平扩展,每秒可以处理PB级海量数据。 常用术语 索引:与MySQL数据库中的Database相对应类型:与…

Unity中如何实现草的LOD

1)Unity中如何实现草的LOD 2)用Compute Shader处理图像数据后在安卓机上不能正常显示渲染纹理 3)关于进游戏程序集加载的问题 4)预制件编辑模式一直在触发自动保存 这是第379篇UWA技术知识分享的推送,精选了UWA社区的热…

pycharm修改主题颜色和注释颜色

目录 一、修改主题颜色 二、修改注释颜色 一、修改主题颜色 总结的来说就是:File-Settings-Appearance-Theme。 有三种主题: Darcula:默认主题,可以看作是黑的: IntelliJ Light:可以看作是白的: High con…

DeepFaceLive换脸小白教程,看这一篇就玩了

先官网下个软件DeepFaceLive - DeepfakeVFX.com 解压安装程序,准备安装, 解压,注意不要有中文路径!

海外媒体发稿:如何选择适合自己的海外媒体推广发稿平台-华媒舍

在数字化时代,海外媒体推广成为企业扩大国际影响力的重要方式之一。海外媒体平台琳琅满目,如何选择适合自己的平台成为了一个需要深入理解和研究的问题。本文将以科普的方式介绍如何选择适合自己的海外媒体推广发稿平台。 1. 形象建立 要选择能够准确展…

vue3源码解析——ref和reactive定义响应式的区别

ref 和 reactive 是 Vue 3.0 中用于定义响应式数据的两个新 API。它们有以下区别: ref 定义单个响应式数据 数据类型可以是任意类型。它通常用于定义原始数据类型为响应式数据。返回一个响应式对象,该对象包含一个 .value 属性,可用于获取和设…

【全栈小5】我的创作纪念日

目录 前言机缘收获粉丝和原创个人成就六边形战士 回顾文章原代码代码优化 憧憬 前言 全栈小5 ,有幸再次遇见你: 还记得 2019 年 03 月 29 日吗? 你撰写了第 1 篇技术博客: 《前端 - 仿动态效果 - 展开信息图标》 在这平凡的一天&…

【JS】null和undefined有什么区别

前言 JS的作者Brendan Eich曾说过两者的区别: null means “no object”, undefined > “no value”.Really it’s an abstraction leak:null and objects shared a Mocha type tag. 翻译后: null 表示“没有对象”,undefined…

STM32学习笔记(9_3)- USART串口代码

无人问津也好,技不如人也罢,都应静下心来,去做该做的事。 最近在学STM32,所以也开贴记录一下主要内容,省的过目即忘。视频教程为江科大(改名江协科技),网站jiangxiekeji.com 本期介…

Memcached 教程之Memcached介绍(一)

Memcached 教程 Memcached是一个自由开源的,高性能,分布式内存对象缓存系统。 Memcached是以LiveJournal旗下Danga Interactive公司的Brad Fitzpatric为首开发的一款软件。现在已成为mixi、hatena、Facebook、Vox、LiveJournal等众多服务中提高Web应用…

POSIX信号量

1.快速认识信号量接口 POSIX信号量和SystemV信号量作用相同,都是用于同步操作,达到无冲突的访问共享资源目的。 但POSIX可以用于线程间同步。我们之前认识SystemV信号量时有这样三个结论: 1.信号量的本质是一把计数器 2.申请信号量本质就是预…

进程调度算法

进程调度算法 进程调度算法先来先服务调度基于优先级调度(Priority Scheduling)短进程优先 / 最短剩余时间优先轮转法(Round-Robin Scheduling)高响应比优先调度算法(Highest Response Ratio Next)多级反馈…

jupyter 设置工作目录

本博客主要介绍: 如何为jupyter设置工作目录 1.打开 anaconda prompt , 执行 jupyter notebook --generate-config 执行这个命令后会生成一个配置文件 2. 打开jupyter_notebook_config.py文件编辑 搜索notebook_dir,把这行代码的注释取消,…

stm32再实现感应开关盖垃圾桶

一、项目需求 检测靠近时,垃圾桶自动开盖并伴随滴一声,2秒后关盖 发生震动时,垃圾桶自动开盖并伴随滴一声,2秒后关盖 按下按键时,垃圾桶自动开盖并伴随滴一声,2秒后关盖 硬件清单 SG90 舵机,…

HTTP 与 HTTPS 的区别

基本概念 HTTP(HyperText Transfer Protocol:超文本传输协议)是一种应用层协议,主要用于在网络上进行信息的传递,特别是用于Web浏览器和服务器之间的通信。 它使用明文方式发送数据,这意味着传输的内容可…