【动手学深度学习】--长短期记忆网络LSTM

文章目录

  • 长短期记忆网络LSTM
    • 1.门控记忆元
      • 1.1输入门、忘记门、输出门
      • 1.2候选记忆元
      • 1.3记忆元
      • 1.4隐状态
    • 2.从零实现
      • 2.1加载数据集
      • 2.2初始化模型参数
      • 2.3定义模型
      • 2.4 训练与预测
    • 3.简洁实现

长短期记忆网络LSTM

学习视频:长短期记忆网络(LSTM)【动手学深度学习v2】

官方笔记:长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题,解决这一问题的最早方法之一是长短期存储器(LSTM),它有许多与GRU一样的属性,有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年。

1.门控记忆元

可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。 有些文献认为记忆元是隐状态的一种特殊类型, 它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。 为了控制记忆元,我们需要许多门。 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。

1.1输入门、忘记门、输出门

  • 忘记门:将值朝0减少
  • 输入门:决定不是忽略掉输入数据
  • 输出门:决定是不是使用隐状态

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中,如下图所示, 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在(0,1)的范围内。

image-20230909153058460

1.2候选记忆元

image-20230909153136156

1.3记忆元

image-20230909153206499

1.4隐状态

image-20230909153227773

2.从零实现

2.1加载数据集

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

2.2初始化模型参数

接下来,我们需要定义和初始化模型参数。 如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

2.3定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元 C t C_t Ct不直接参与输出计算。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

2.4 训练与预测

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

image-20230909153543972

3.简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

image-20230909153613692

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/77457.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

redis 常用数据结构2

目录 list LPUSH LRANGE LPUSHX RPUSH RPUSHX LPOP RPOP LINDEX LINSERT LLEN LREM LTRIM LSET BLPOP / BRPOP 编码方式 set SADD SMEMBERS SISMEMBER SPOP SCARD SRANDMEMBER SMOVE SREM SINTER SINTERSTORE SUNION SUNIONSTORE SDIFF SDIFFSTO…

leetcode面试题:交换和(三种方法实现)

交换和: 给定两个整数数组,请交换一对数值(每个数组中取一个数值),使得两个数组所有元素的和相等。 返回一个数组,第一个元素是第一个数组中要交换的元素,第二个元素是第二个数组中要交换的元…

白鲸开源 X SelectDB 金融大数据联合解决方案公布!从源头解决大数据开发挑战

业务挑战与痛点 随着互联网技术的发展、云计算技术的成熟、人工智能技术的兴起和数字化经济的崛起,数据已成为企业的核心资产。在金融行业中,数字化已成为了支撑各类业务场景的核心力量,包括个人理财、企业融资、股票交易、保险理赔、贷款服…

如何构建一个简单的前端框架

先让我来解释一下什么是前端框架。所谓的前端框架&#xff0c;就是一种能够让我们避免去写常规的HTML和JavaScript代码 <p id"cool-para"></p> <script>const coolPara Test;const el document.getElementById(cool-para);el.innerText coolPa…

TensorFlow 02(张量)

一、张量 张量Tensor 张量是一个多维数组。与NumPy ndarray对象类似&#xff0c;tf.Tensor对象也具有数据类型和形状。如下图所示: 此外&#xff0c;tf.Tensors可以保留在GPU中。TensorFlow提供了丰富的操作库 (tf.add&#xff0c;tf.matmul,tf.linalg.inv等)&#xff0c;它们…

文字点选验证码识别(下)-训练一个孪生神经网络模型

声明 本文以教学为基准、本文提供的可操作性不得用于任何商业用途和违法违规场景。 本人对任何原因在使用本人中提供的代码和策略时可能对用户自己或他人造成的任何形式的损失和伤害不承担责任。 如有侵权,请联系我进行删除。 文章中没有代码,只有过程思路,请大家谨慎订阅。…

HarmonyOS Codelab 优秀样例——溪村小镇(ArkTS)

一、介绍 溪村小镇是一款展示溪流背坡村园区风貌的应用&#xff0c;包括园区内的导航功能&#xff0c;小火车行车状态查看&#xff0c;以及各区域的风景展览介绍&#xff0c;主要用于展示HarmonyOS的ArkUI能力和动画效果。具体包括如下功能&#xff1a; 打开应用时进入启动页&a…

哪种IP更适合你的数据抓取需求?

程序员大佬们好&#xff01;今天我要和大家分享一个关于数据抓取的话题&#xff0c;那就是Socks5爬虫ip和动态IP之间的比较。在进行数据抓取时&#xff0c;选择适合自己需求的工具和技术是非常重要的。Socks5爬虫ip和动态IP都是常见的网络工具&#xff0c;它们在数据抓取方面都…

Spring Boot 中的参数验证和自定义响应处理,使用 @Valid 注解

&#x1f60a; 作者&#xff1a; 一恍过去 &#x1f496; 主页&#xff1a; https://blog.csdn.net/zhuocailing3390 &#x1f38a; 社区&#xff1a; Java技术栈交流 &#x1f389; 主题&#xff1a; Spring Boot 中的参数验证和自定义响应处理&#xff0c;使用 Valid 注解…

【sgCreateAPI】自定义小工具:敏捷开发→自动化生成API接口脚本(接口代码生成工具)

<template><div :class"$options.name"><div class"sg-head">接口代码生成工具</div><div class"sg-container"><div class"sg-start "><div style"margin-bottom: 10px;">接口地…

如何进行函数的递归调用?

函数的递归调用是一种强大而常见的编程技巧&#xff0c;它允许函数在自身内部调用自己。递归在解决问题的分而治之策略中非常有用&#xff0c;可以将大问题分解为更小的、相同或类似的子问题&#xff0c;然后通过逐步解决这些子问题来解决原始问题。在本文中&#xff0c;我将详…

信息化发展39

IT 服务管理 1 、IT 服务管理是通过主动管理和流程的持续改进来确保IT 服务交付有效且高效的一组活动。 2 、IT 服务管理由若干不同的活动组成&#xff1a; 服务台、事件管理、问题管理、变更管理、配置管理、发布管理、服务级别管理、财务管理、容量管理、服务连续性管理和可…

做期权卖方一般会怎么选择合约?

我们知道期权有多种获利方式&#xff0c;其中靠时间能赚钱的是做期权卖方策略&#xff0c;虽然赚得慢&#xff0c;但可以稳稳地收入权利金&#xff0c;适合某些稳健风格的投资者&#xff0c;胜率对比买方也是高了很多&#xff0c;那么做期权卖方一般会怎么选择合约&#xff1f;…

Unity——JSON的读取

一、读取JSON 在实际中&#xff0c;读取JSON比保存JSON重要得多。因为存档、发送数据包往往可以采用其他序列化方法&#xff0c;但游戏的配置文件使用JSON格式比较常见。游戏的配置数据不属于动态数据&#xff0c;属于游戏资源&#xff0c;但很适合用JSON表示。 下面以一个简…

Mybatis---第二篇

系列文章目录 文章目录 系列文章目录一、#{}和${}的区别是什么?二、简述 Mybatis 的插件运行原理,如何编写一个插件一、#{}和${}的区别是什么? #{}是预编译处理、是占位符, KaTeX parse error: Expected EOF, got # at position 27: …接符。 Mybatis 在处理#̲{}时,会将…

Linux之Socket函数(详细篇)

本篇是基于Linux man手册的一些总结 socket作用&#xff1a; create an endpoint for communication 函数结构 c #include <sys/types.h> /* See NOTES */ #include <sys/socket.h> int socket(int domain, int type, int protocol); 描述 socket() …

LeetCode(力扣)45. 跳跃游戏 IIPython

LeetCode45. 跳跃游戏 II 题目链接代码 题目链接 https://leetcode.cn/problems/jump-game-ii/description/ 代码 class Solution:def jump(self, nums: List[int]) -> int:if len(nums) 1:return 0curdis 0nextdis 0step 0for i in range(len(nums)):nextdis max(…

华为云耀云服务器HECS安装Docker

先去购买服务器&#xff0c;这里就不多说了 1、进入自己买的服务器&#xff0c; 找到切换系统 2、选择centOs镜像 安装docker 卸载旧版本 较旧的 Docker 版本称为 docker 或 docker-engine 。如果已安装这些程序&#xff0c;请卸载它们以及相关的依赖项。 yum remove docker…

职业规划就问它!海量知识与智慧,AIGC助你冲破择业迷茫

数字化时代的兴起改变了我们的日常生活和职业工作方式。科技迅猛的发展&#xff0c;尤其是人工智能的崛起&#xff0c;将我们引入了一个崭新的智能化时代。在这个时代中&#xff0c;AI被认为是从"数字时代"向"数智时代"转变的关键元素&#xff0c;引领着这…

蛇形填数 rust解法

蛇形填数。 在nn方阵里填入1&#xff0c;2&#xff0c;…&#xff0c;nn&#xff0c;要求填成蛇形。例如&#xff0c;n&#xff1d;4时方阵为&#xff1a; 10 11 12 1 9 16 13 2 8 15 14 3 7 6 5 4 解法如下&#xff1a; use std::io;fn main() {let mut buf String::new();…