边写代码边学习之LSTM

1.  什么是LSTM

长短期记忆网络 LSTM(long short-term memory)是 RNN 的一种变体,其核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。
 

在这里插入图片描述

 

2. 实验代码

2.1. 搭建一个只有一层RNN和Dense网络的模型。

2.2. 验证LSTM里的逻辑

 假设我的输入数据是x = [1,0], 

kernel = [[[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],

              [1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],]]

recurrent_kernel = [[1, 0, 0, 1, 2,1,0,1,2,0,1,0],

                              [1, 1, 0, 0, 2,1,0,1,2,2,0,0],

                              [1, 0, 1, 2, 0,1,0,1,1,0,1,0]]

biase = [3, 1, 0, 1, 1,0,0,1,0,2,0.0,0]

通过下面手算,h的结果是[0, 4,1], c 的结果是[0,4,1].  注意无激活函数。

代码验证上面的结果


def change_weight():# Create a simple Dense layerlstm_layer = LSTM(units=3, input_shape=(3, 2), activation=None, recurrent_activation=None, return_sequences=True,return_state= True)# Simulate input data (batch size of 1 for demonstration)input_data = np.array([[[1.0, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]])# Pass the input data through the layer to initialize the weights and biaseslstm_layer(input_data)kernel, recurrent_kernel, biases = lstm_layer.get_weights()# Print the initial weights and biasesprint("recurrent_kernel:", recurrent_kernel, recurrent_kernel.shape ) # (3,3)print('kernal:',kernel, kernel.shape) #(2,3)print('biase: ',biases , biases.shape) # (3)kernel = np.array([[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],[1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],])recurrent_kernel = np.array([[1, 0, 0, 1, 2,1,0,1,2,0,1,0],[1, 1, 0, 0, 2,1,0,1,2,2,0,0],[1, 0, 1, 2, 0,1,0,1,1,0,1,0]])biases = np.array([3, 1, 0, 1, 1,0,0,1,0,2,0.0,0])lstm_layer.set_weights([kernel, recurrent_kernel, biases])print(lstm_layer.get_weights())# test_data = np.array([#     [[1.0, 3], [1, 1], [2, 3]]# ])test_data = np.array([[[1,0.0]]])output, memory_state, carry_state  = lstm_layer(test_data)print(output)print(memory_state)print(carry_state)
if __name__ == '__main__':change_weight()

执行结果:

recurrent_kernel: [[-0.36744034 -0.11181469 -0.10642298  0.5450207  -0.30208975  0.54054320.09643812 -0.14983998  0.1859854   0.2336958  -0.16187981  0.11621032][ 0.07727922 -0.226477    0.1491096  -0.03933501  0.31236103 -0.129630920.10522162 -0.4815724  -0.2093935   0.34740582 -0.60979587 -0.15877807][ 0.15371156  0.01244636 -0.09840634 -0.32093546  0.06523462  0.189349320.38859126 -0.3261706  -0.05138849  0.42713478  0.49390993  0.37013963]] (3, 12)
kernal: [[-0.47606698 -0.43589187 -0.5371355  -0.07337284  0.30526626 -0.18241835-0.03675252  0.2873094   0.33218485  0.24838251  0.17765659  0.4312396 ][ 0.4007727   0.41280174  0.40750778 -0.6245315   0.6382301   0.428892250.11961156 -0.6021105  -0.43556038  0.39798307  0.6390712   0.16719025]] (2, 12)
biase:  [0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0.] (12,)
[array([[2., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0.],[1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0.]], dtype=float32), array([[1., 0., 0., 1., 2., 1., 0., 1., 2., 0., 1., 0.],[1., 1., 0., 0., 2., 1., 0., 1., 2., 2., 0., 0.],[1., 0., 1., 2., 0., 1., 0., 1., 1., 0., 1., 0.]], dtype=float32), array([3., 1., 0., 1., 1., 0., 0., 1., 0., 2., 0., 0.], dtype=float32)]
tf.Tensor([[[0. 4. 0.]]], shape=(1, 1, 3), dtype=float32)
tf.Tensor([[0. 4. 0.]], shape=(1, 3), dtype=float32)
tf.Tensor([[0. 4. 1.]], shape=(1, 3), dtype=float32)

可以看出h=[0,4,0], c=[0,4,1]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/25509.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pytorch深度学习-----神经网络之线性层用法

系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Co…

MySQL索引3——Explain关键字和索引使用规则(SQL提示、索引失效、最左前缀法则)

目录 Explain关键字 索引性能分析 Id ——select的查询序列号 Select_type——select查询的类型 Table——表名称 Type——select的连接类型 Possible_key ——显示可能应用在这张表的索引 Key——实际用到的索引 Key_len——实际索引使用到的字节数 Ref ——索引命…

【Linux】五、进程

一、冯诺依曼体系结构 存储器:指的是内存; 输入设备:键盘、摄像头、话筒,磁盘,网卡; 输出设备:显示器、音响、磁盘、网卡; 中央处理器(CPU):运算器…

【开源项目--稻草】Day04

【开源项目--稻草】Day04 1. 续 VUE1.1 完善VUEAJAX完成注册功能 Spring验证框架什么是Spring验证框架使用Spring-Validation 稻草问答-学生首页显示首页制作首页的流程开发标签列表标签列表显示原理 从业务逻辑层开始编写控制层代码开发问题列表开发业务逻辑层开发页面和JS代码…

HTML5 Canvas(画布)

<canvas>标签定义图形&#xff0c;比如图表和其他图像&#xff0c;你必须用脚本来绘制图形。 在画布上&#xff08; Canvas &#xff09;画一个共红色矩形&#xff0c;渐变矩形&#xff0c;彩色矩形&#xff0c;和一些彩色文字。 什么是 Canvas&#xff1f; HTML5<c…

机器学习深度学习——序列模型(NLP启动!)

&#x1f468;‍&#x1f393;作者简介&#xff1a;一位即将上大四&#xff0c;正专攻机器学习的保研er &#x1f30c;上期文章&#xff1a;机器学习&&深度学习——卷积神经网络&#xff08;LeNet&#xff09; &#x1f4da;订阅专栏&#xff1a;机器学习&&深度…

vue3+vite项目配置ESlint、pritter插件

配置ESlint、pritter插件 在 Vue 3 Vite 项目中&#xff0c;你可以通过以下步骤配置 ESLint 和 Prettier 插件&#xff1a; 安装插件&#xff1a; 在项目根目录下&#xff0c;打开终端并执行以下命令安装 ESLint 和 Prettier 插件&#xff1a; npm install eslint prettier e…

Mr. Cappuccino的第55杯咖啡——Mybatis一级缓存二级缓存

Mybatis一级缓存&二级缓存 概述一级缓存特点演示前准备效果演示在同一个SqlSession中在不同的SqlSession中 源代码怎么禁止使用一级缓存一级缓存在什么情况下会被清除 二级缓存特点演示前准备效果演示在不同的SqlSession中 源代码怎么关闭二级缓存 一级缓存&#xff08;Spr…

ubuntu20.4 sgx环境配置

一、driver安装 1.在该下载地址将3个.bin文件下载下来&#xff0c;下载地址&#xff1a;https://download.01.org/intel-sgx/latest/linux-latest/distro/ubuntu20.04-server/ 2.到下载文件夹下输入下面命令&#xff0c;以赋予.bin文件的执行权限 sudo chmod 777 sgx_linux_x64…

HTTP 常用状态码 301 302 304 403

HTTP 常用状态码 301 302 304 403 301 永久重定向&#xff0c;浏览器会把重定向后的地址缓存起来&#xff0c;将来用户再次访问原始地址时&#xff0c;直接引导用户访问新地址 302 临时重定向&#xff0c;浏览器会引导用户进入新地址&#xff0c;但不会缓存原始地址&#xff0c…

Python模块—Pytest模块

文章目录 PyTest1. args参数2. pytest-ordering3. fixture&#xff08;前置函数&#xff09;4. parametrize&#xff08;参数化&#xff09;5. fixture 与 parametrize 结合6. pyyaml&#xff08;数据源&#xff09;7. pytest-xdist&#xff08;分布式测试&#xff09;8. allur…

LA@行列式性质

文章目录 行列式性质&#x1f388;转置不变性质交换性质多重交换移动(抽出插入)&#x1f47a; 因子提取性质拆和性质倍加性质 手算行列式的主要方法原理:任何行列式都可以化为三角行列式 行列式性质&#x1f388; 设行列式 ∣ A ∣ d e t ( a i j ) |A|\mathrm{det}(a_{ij}) …

vue 关于axios的使用方法

axios定义&#xff1a; axios 前端 ajax请求工具 1. 在浏览器与nodejs可以使用 2. 可以拦截请求与相应 3. 扩展与封装自定义方法 4. 不依赖dom节点 安装 npm i axios -S 先在vue全局中挂载 import axios from ‘axios’ Vue.prototype.$h…

Docker 安装 Tomcat

目录 一、查看 tomcat 版本 二、拉取 Tomcat Docker 镜像 三、创建 Tomcat 容器 四、访问 Tomcat 五、停止和启动容器 一、查看 tomcat 版本 访问 tomcat 镜像库地址&#xff1a;https://hub.docker.com/_/tomcat&#xff0c;可以通过 Tags 查看其他版本的 tomcat; 二、拉…

Elasticsearch8.8.0 SpringBoot实战操作各种案例(索引操作、聚合、复杂查询、嵌套等)

Elasticsearch8.8.0 全网最新版教程 从入门到精通 通俗易懂 配置项目 引入依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.8.16</version></dependency><dependency>&l…

Android Studio 的Gradle版本修改

使用Android Studio构建项目时&#xff0c;需要配置Gradle&#xff0c;与Gradle插件。 Gradle是一个构建工具&#xff0c;用于管理和自动化Android项目的构建过程。它使用Groovy或Kotlin作为脚本语言&#xff0c;并提供了强大的配置能力来定义项目的依赖关系、编译选项、打包方…

Jtti:linux如何配置dns域名解析服务器

要配置Linux上的DNS域名解析服务器&#xff0c;您可以按照以下步骤进行操作&#xff1a; 1. 安装BIND软件包&#xff1a;BIND是Linux上最常用的DNS服务器软件&#xff0c;您可以使用以下命令安装它&#xff1a; sudo apt-get install bind9 2. 配置BIND&#xff1a;BIND的配置…

Spring Cloud常见问题处理和代码分析

目录 1. 问题&#xff1a;如何在 Spring Cloud 中实现服务注册和发现&#xff1f;2. 问题&#xff1a;如何在 Spring Cloud 中实现分布式配置&#xff1f;3. 问题&#xff1a;如何在 Spring Cloud 中实现服务间的调用&#xff1f;4. 问题&#xff1a;如何在 Spring Cloud 中实现…

HCIA---OSI/RM--开放式系统互联参考模型

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 一.OSI--开放式系统互联参考模型简介 OSI开放式系统互联参考模型是一种用于计算机网络通信…

解密Redis:应对面试中的缓存相关问题2

面试官&#xff1a;Redis集群有哪些方案&#xff0c;知道嘛&#xff1f; 候选人&#xff1a;嗯~~&#xff0c;在Redis中提供的集群方案总共有三种&#xff1a;主从复制、哨兵模式、Redis分片集群。 面试官&#xff1a;那你来介绍一下主从同步。 候选人&#xff1a;嗯&#xff…