【动手学深度学习-pytorch】9.2长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997)。 它有许多与门控循环单元( 9.1节)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年.

门控记忆元 cell

  • 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)
  • 为了控制记忆元,我们需要许多门。输入门 输出门 遗忘门
  • 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。 让我们看看这在实践中是如何运作的。

输入门、忘记门和输出门

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中, 如 图9.2.1所示。 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在
的范围内。
在这里插入图片描述
在这里插入图片描述

候选记忆元

在这里插入图片描述

记忆元

在这里插入图片描述

隐状态

在这里插入图片描述

只有隐状态会传递到输出层,而记忆元完全属于内部信息

从零开始实现

现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差
的高斯分布初始化权重,并将偏置项设为0.

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元不直接参与输出计算。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化 8.5节中 引入的RNNModelScratch类来训练一个长短期记忆网络, 就如我们在 9.1节中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

总结

  • 长短期记忆网络,包含三个门:输入门、忘记门和遗忘门。其中遗忘门用于重置单元的内容,通过专用的机制决定什么时候记忆或者忽略状态中的输入。

  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。

  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/780996.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTX Ventures:为什么BounceBit可能成为新的BTC生态解决方案?

随着BTC现货ETF的通过,全球各大机构和个人都在不断加码对BTC的持仓,BTC价格也随之上升,目前已上升至全球市值排名前十的资产。在本轮市场周期中,BTC铭文和BTC扩容是两个被市场高度关注的细分赛道。BTC生态资产的多元化收益探索正在…

mydumper和myloader对MySQL数据备份和恢复

安装教程省略 一、mydumper数据备份 mydumper -u root -p 123456 -P 3306 -B db1 -o /data/20240329root:数据库用户名 123456:密码 3306:端口 db1:数据库库名 /data/20240329:导出的备份文件存放位置 导出的数据文…

Django详细教程(一)

文章目录 一、安装Django二、创建项目1.终端创建项目2.Pycharm创建项目(专业版才可以)3.默认文件介绍 三、创建app1.app介绍2.默认文件介绍 四、快速上手1.写一个网页步骤1:注册app 【settings.py】步骤2:编写URL和视图函数对应关…

mysql权限相关操作

创建mysql用户并开通某数据库的特定权限 CREATE USER username% IDENTIFIED BY 123456; GRANT INSERT,DELETE,UPDATE,SELECT ON xxxdb.* TO username%; GRANT ALL PRIVILEGES ON caieinstitute.* TO caie%;//给全部管理权限 修改某用户登录所需使用的IP select * from user w…

面试题:@Component, @Service, @Repository, @Controller 注解的区别与用途

在Spring框架中,Component, Service, Repository, 和 Controller 都是用来标记Bean并将其纳入Spring IoC容器管理的注解,它们的主要区别在于用途和语义上的强调,旨在提高代码的可读性和更好的组织架构。 1. Component - 用途这是Spring中定…

基于微信小程序的校园服务平台设计与实现(程序+论文)

本文以校园服务平台为研究对象,首先分析了当前校园服务平台的研究现状,阐述了本系统设计的意义和背景,运用微信小程序开发工具和云开发技术,研究和设计了一个校园服务平台,以满足学生在校园生活中的多样化需求。通过引…

最优算法100例之13-输出第n个丑数

专栏主页:计算机专业基础知识总结(适用于期末复习考研刷题求职面试)系列文章https://blog.csdn.net/seeker1994/category_12585732.html 题目描述 把只包含因子2、3和5的数称作丑数(Ugly Number)。例如6、8都是丑数,但14不是,因为它包含因子7。 习惯上我们把1当…

使用Hive对HDFS中数据查询的优点

目录 摘要一、Hive是什么二、HDFS是什么三、Hive与HDFS的关系四、什么是HiveQL五、什么是mapreduce六、Hive如何将查询转为mapreduce任务七、Hadoop生态系统中的高性能引擎八、使用Hadoop的优点 摘要 Hadoop生态系统中包含了多个关键组件,如Hive、HDFS、MapReduce等…

Typora:一款值得尝试的Markdown编辑器

引言: 随着博客的兴起,越来越多的人开始写博客。而Markdown作为一种轻量级标记语言,因其简洁、易读、易写、易转换等特点而被广泛使用。Markdown的语法简单易学,使用起来也比较方便。但是,为了更好地使用Markdown&…

3.滑行。

3.滑行 - 蓝桥云课 (lanqiao.cn) 问题描述 小蓝准备在一个空旷的场地里面滑行,这个场地的高度不一小蓝用一个n行m列的矩阵来表示场地,矩阵中的数值表示场地的高度 如果小蓝在某个位置,而他上、下、左、右中有一个位置的高度(严格)低于当前的高…

目标检测评价标准

主要借鉴:https://github.com/rafaelpadilla/Object-Detection-Metrics?tabreadme-ov-file 主要评价指标、术语: Intersection Over Union (IOU):两个检测框交集面积与并集面积的比值 True Positive (TP):IOU大于阈值的检测框…

Elasticsearch入门及常用命令和Spring中的常用操作

入门 官网 简介 一个分布式的、Restful风格的搜索引擎。支持对各种类型的数据的检索。搜索速度快,可以提供实时的搜索服务。便于水平扩展,每秒可以处理PB级海量数据。 常用术语 索引:与MySQL数据库中的Database相对应类型:与…

Unity中如何实现草的LOD

1)Unity中如何实现草的LOD 2)用Compute Shader处理图像数据后在安卓机上不能正常显示渲染纹理 3)关于进游戏程序集加载的问题 4)预制件编辑模式一直在触发自动保存 这是第379篇UWA技术知识分享的推送,精选了UWA社区的热…

pycharm修改主题颜色和注释颜色

目录 一、修改主题颜色 二、修改注释颜色 一、修改主题颜色 总结的来说就是:File-Settings-Appearance-Theme。 有三种主题: Darcula:默认主题,可以看作是黑的: IntelliJ Light:可以看作是白的: High con…

DeepFaceLive换脸小白教程,看这一篇就玩了

先官网下个软件DeepFaceLive - DeepfakeVFX.com 解压安装程序,准备安装, 解压,注意不要有中文路径!

海外媒体发稿:如何选择适合自己的海外媒体推广发稿平台-华媒舍

在数字化时代,海外媒体推广成为企业扩大国际影响力的重要方式之一。海外媒体平台琳琅满目,如何选择适合自己的平台成为了一个需要深入理解和研究的问题。本文将以科普的方式介绍如何选择适合自己的海外媒体推广发稿平台。 1. 形象建立 要选择能够准确展…

SpringBoot 使用【AOP 切面+注解】实现在请求调用 Controller 方法前修改请求参数和在结果返回之前修改返回结果

前情提要 在项目中需要实现 在请求调用 Controller 方法前修改请求参数和在结果返回之前修改返回结果。 我们可以使用 AOP 切面注解的形式实现。这样我们就可以在不修改原始代码的情况下,通过切面类在方法调用前后插入额外的逻辑。 解决方案 自定义注解 PreProc…

vue3源码解析——ref和reactive定义响应式的区别

ref 和 reactive 是 Vue 3.0 中用于定义响应式数据的两个新 API。它们有以下区别: ref 定义单个响应式数据 数据类型可以是任意类型。它通常用于定义原始数据类型为响应式数据。返回一个响应式对象,该对象包含一个 .value 属性,可用于获取和设…

【全栈小5】我的创作纪念日

目录 前言机缘收获粉丝和原创个人成就六边形战士 回顾文章原代码代码优化 憧憬 前言 全栈小5 ,有幸再次遇见你: 还记得 2019 年 03 月 29 日吗? 你撰写了第 1 篇技术博客: 《前端 - 仿动态效果 - 展开信息图标》 在这平凡的一天&…

SpringBoot -- Profiles

Profiles具备环境隔离能力,可以将我们的项目快速切换开发、测试、生产环境 我们的使用步骤也很简单: 1. 标识环境:指定哪些组件、配置在哪个环境生效 2. 切换环境:这个环境对应的所有组件和配置就应该生效 接下来就进行详细的…