【动手学深度学习】--长短期记忆网络LSTM

文章目录

  • 长短期记忆网络LSTM
    • 1.门控记忆元
      • 1.1输入门、忘记门、输出门
      • 1.2候选记忆元
      • 1.3记忆元
      • 1.4隐状态
    • 2.从零实现
      • 2.1加载数据集
      • 2.2初始化模型参数
      • 2.3定义模型
      • 2.4 训练与预测
    • 3.简洁实现

长短期记忆网络LSTM

学习视频:长短期记忆网络(LSTM)【动手学深度学习v2】

官方笔记:长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题,解决这一问题的最早方法之一是长短期存储器(LSTM),它有许多与GRU一样的属性,有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年。

1.门控记忆元

可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。 有些文献认为记忆元是隐状态的一种特殊类型, 它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。 为了控制记忆元,我们需要许多门。 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。

1.1输入门、忘记门、输出门

  • 忘记门:将值朝0减少
  • 输入门:决定不是忽略掉输入数据
  • 输出门:决定是不是使用隐状态

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中,如下图所示, 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在(0,1)的范围内。

image-20230909153058460

1.2候选记忆元

image-20230909153136156

1.3记忆元

image-20230909153206499

1.4隐状态

image-20230909153227773

2.从零实现

2.1加载数据集

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

2.2初始化模型参数

接下来,我们需要定义和初始化模型参数。 如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

2.3定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元 C t C_t Ct不直接参与输出计算。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

2.4 训练与预测

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

image-20230909153543972

3.简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

image-20230909153613692

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/77457.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

白鲸开源 X SelectDB 金融大数据联合解决方案公布!从源头解决大数据开发挑战

业务挑战与痛点 随着互联网技术的发展、云计算技术的成熟、人工智能技术的兴起和数字化经济的崛起,数据已成为企业的核心资产。在金融行业中,数字化已成为了支撑各类业务场景的核心力量,包括个人理财、企业融资、股票交易、保险理赔、贷款服…

TensorFlow 02(张量)

一、张量 张量Tensor 张量是一个多维数组。与NumPy ndarray对象类似,tf.Tensor对象也具有数据类型和形状。如下图所示: 此外,tf.Tensors可以保留在GPU中。TensorFlow提供了丰富的操作库 (tf.add,tf.matmul,tf.linalg.inv等),它们…

文字点选验证码识别(下)-训练一个孪生神经网络模型

声明 本文以教学为基准、本文提供的可操作性不得用于任何商业用途和违法违规场景。 本人对任何原因在使用本人中提供的代码和策略时可能对用户自己或他人造成的任何形式的损失和伤害不承担责任。 如有侵权,请联系我进行删除。 文章中没有代码,只有过程思路,请大家谨慎订阅。…

HarmonyOS Codelab 优秀样例——溪村小镇(ArkTS)

一、介绍 溪村小镇是一款展示溪流背坡村园区风貌的应用,包括园区内的导航功能,小火车行车状态查看,以及各区域的风景展览介绍,主要用于展示HarmonyOS的ArkUI能力和动画效果。具体包括如下功能: 打开应用时进入启动页&a…

哪种IP更适合你的数据抓取需求?

程序员大佬们好!今天我要和大家分享一个关于数据抓取的话题,那就是Socks5爬虫ip和动态IP之间的比较。在进行数据抓取时,选择适合自己需求的工具和技术是非常重要的。Socks5爬虫ip和动态IP都是常见的网络工具,它们在数据抓取方面都…

Spring Boot 中的参数验证和自定义响应处理,使用 @Valid 注解

😊 作者: 一恍过去 💖 主页: https://blog.csdn.net/zhuocailing3390 🎊 社区: Java技术栈交流 🎉 主题: Spring Boot 中的参数验证和自定义响应处理,使用 Valid 注解…

【sgCreateAPI】自定义小工具:敏捷开发→自动化生成API接口脚本(接口代码生成工具)

<template><div :class"$options.name"><div class"sg-head">接口代码生成工具</div><div class"sg-container"><div class"sg-start "><div style"margin-bottom: 10px;">接口地…

做期权卖方一般会怎么选择合约?

我们知道期权有多种获利方式&#xff0c;其中靠时间能赚钱的是做期权卖方策略&#xff0c;虽然赚得慢&#xff0c;但可以稳稳地收入权利金&#xff0c;适合某些稳健风格的投资者&#xff0c;胜率对比买方也是高了很多&#xff0c;那么做期权卖方一般会怎么选择合约&#xff1f;…

LeetCode(力扣)45. 跳跃游戏 IIPython

LeetCode45. 跳跃游戏 II 题目链接代码 题目链接 https://leetcode.cn/problems/jump-game-ii/description/ 代码 class Solution:def jump(self, nums: List[int]) -> int:if len(nums) 1:return 0curdis 0nextdis 0step 0for i in range(len(nums)):nextdis max(…

华为云耀云服务器HECS安装Docker

先去购买服务器&#xff0c;这里就不多说了 1、进入自己买的服务器&#xff0c; 找到切换系统 2、选择centOs镜像 安装docker 卸载旧版本 较旧的 Docker 版本称为 docker 或 docker-engine 。如果已安装这些程序&#xff0c;请卸载它们以及相关的依赖项。 yum remove docker…

职业规划就问它!海量知识与智慧,AIGC助你冲破择业迷茫

数字化时代的兴起改变了我们的日常生活和职业工作方式。科技迅猛的发展&#xff0c;尤其是人工智能的崛起&#xff0c;将我们引入了一个崭新的智能化时代。在这个时代中&#xff0c;AI被认为是从"数字时代"向"数智时代"转变的关键元素&#xff0c;引领着这…

2023年Gartner新技术与AI成熟度曲线

1. Gartner 将生成式 AI 置于 2023 年新技术成熟度曲线的顶峰&#xff0c;新兴人工智能将对商业和社会产生深远影响 根据 Gartner, Inc. 2023 年新兴技术成熟度曲线&#xff0c;生成式人工智能 (AI) 处于成熟度曲线期望的顶峰&#xff0c;预计将在两到五年内实现转型效益。生成…

htaccess绕过上传实验

实验目的 利用上传htaccess文件解析漏洞绕过验证进行上传PHP脚本木马 实验工具 火狐&#xff1a;Mozilla Firefox&#xff0c;中文俗称“火狐”&#xff08;正式缩写为Fx或fx&#xff0c;非正式缩写为FF&#xff09;&#xff0c;是一个自由及开放源代码网页浏览器&#xff0…

Linux OpenGauss 数据库远程连接

目录 前言 1. Linux 安装 openGauss 2. Linux 安装cpolar 3. 创建openGauss主节点端口号公网地址 4. 远程连接openGauss 5. 固定连接TCP公网地址 6. 固定地址连接测试 前言 openGauss是一款开源关系型数据库管理系统&#xff0c;采用木兰宽松许可证v2发行。openGauss内…

playwright自动化上传附件

需求 自动设置上传头像 过程 1. 首先保存本地一个文件&#xff0c;例如 aaa.php file_path files/aaa.png 2. 获取输入类型为 "file" 的按钮 file_input_element page.locator(input[typefile]) 3. 将本地保存的图片路径赋值 file_input_element.set_input_…

单例模式-饿汉模式、懒汉模式

单例模式&#xff0c;是设计模式的一种。 在计算机这个圈子中&#xff0c;大佬们针对一些典型的场景&#xff0c;给出了一些典型的解决方案。 目录 单例模式 饿汉模式 懒汉模式 线程安全 单例模式 单例模式又可以理解为是单个实例&#xff08;对象&#xff09; 在有些场…

深圳唯创知音电子将参加IOTE 2023第二十届国际物联网展•深圳站

​ 2023年9月20~22日&#xff0c;深圳唯创知音电子将在 深圳宝安国际会展中心&#xff08;9号馆9B1&#xff09;为您全面展示最新的芯片产品及应用方案&#xff0c;助力传感器行业的发展。 作为全球领先的芯片供应商之一&#xff0c;深圳唯创知音电子一直致力于为提供高质量、…

【Web】vue开发环境搭建教程(详细)

系列文章 C#底层库–记录日志帮助类 本文链接&#xff1a;https://blog.csdn.net/youcheng_ge/article/details/124187709 文章目录 系列文章前言一、安装准备1.1 node.js1.2 国内镜像站1.3 Vue脚手架1.4 element ui1.5 Visual Studio Code 二、安装步骤2.1 下载msi安装包2.2 …

骨传导耳机对人体有危险吗?会损害听力吗?

如果在使用骨传导耳机的时候控制好时间和音量&#xff0c;是不会对人体带来危险和造成伤害的。 下面跟大家解释一下为什么骨传导耳机对人体没有危害&#xff0c;最大的原因就是骨传导耳机不需要空气传导&#xff0c;而是通过颅骨传到听觉中枢&#xff0c;传输过程中几乎没有噪…

使用java连接Libvirtd

基于springboot web 一、依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId>&l…