边写代码边学习之LSTM

1.  什么是LSTM

长短期记忆网络 LSTM(long short-term memory)是 RNN 的一种变体,其核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。
 

在这里插入图片描述

 

2. 实验代码

2.1. 搭建一个只有一层RNN和Dense网络的模型。

2.2. 验证LSTM里的逻辑

 假设我的输入数据是x = [1,0], 

kernel = [[[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],

              [1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],]]

recurrent_kernel = [[1, 0, 0, 1, 2,1,0,1,2,0,1,0],

                              [1, 1, 0, 0, 2,1,0,1,2,2,0,0],

                              [1, 0, 1, 2, 0,1,0,1,1,0,1,0]]

biase = [3, 1, 0, 1, 1,0,0,1,0,2,0.0,0]

通过下面手算,h的结果是[0, 4,1], c 的结果是[0,4,1].  注意无激活函数。

代码验证上面的结果


def change_weight():# Create a simple Dense layerlstm_layer = LSTM(units=3, input_shape=(3, 2), activation=None, recurrent_activation=None, return_sequences=True,return_state= True)# Simulate input data (batch size of 1 for demonstration)input_data = np.array([[[1.0, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]])# Pass the input data through the layer to initialize the weights and biaseslstm_layer(input_data)kernel, recurrent_kernel, biases = lstm_layer.get_weights()# Print the initial weights and biasesprint("recurrent_kernel:", recurrent_kernel, recurrent_kernel.shape ) # (3,3)print('kernal:',kernel, kernel.shape) #(2,3)print('biase: ',biases , biases.shape) # (3)kernel = np.array([[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],[1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],])recurrent_kernel = np.array([[1, 0, 0, 1, 2,1,0,1,2,0,1,0],[1, 1, 0, 0, 2,1,0,1,2,2,0,0],[1, 0, 1, 2, 0,1,0,1,1,0,1,0]])biases = np.array([3, 1, 0, 1, 1,0,0,1,0,2,0.0,0])lstm_layer.set_weights([kernel, recurrent_kernel, biases])print(lstm_layer.get_weights())# test_data = np.array([#     [[1.0, 3], [1, 1], [2, 3]]# ])test_data = np.array([[[1,0.0]]])output, memory_state, carry_state  = lstm_layer(test_data)print(output)print(memory_state)print(carry_state)
if __name__ == '__main__':change_weight()

执行结果:

recurrent_kernel: [[-0.36744034 -0.11181469 -0.10642298  0.5450207  -0.30208975  0.54054320.09643812 -0.14983998  0.1859854   0.2336958  -0.16187981  0.11621032][ 0.07727922 -0.226477    0.1491096  -0.03933501  0.31236103 -0.129630920.10522162 -0.4815724  -0.2093935   0.34740582 -0.60979587 -0.15877807][ 0.15371156  0.01244636 -0.09840634 -0.32093546  0.06523462  0.189349320.38859126 -0.3261706  -0.05138849  0.42713478  0.49390993  0.37013963]] (3, 12)
kernal: [[-0.47606698 -0.43589187 -0.5371355  -0.07337284  0.30526626 -0.18241835-0.03675252  0.2873094   0.33218485  0.24838251  0.17765659  0.4312396 ][ 0.4007727   0.41280174  0.40750778 -0.6245315   0.6382301   0.428892250.11961156 -0.6021105  -0.43556038  0.39798307  0.6390712   0.16719025]] (2, 12)
biase:  [0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0.] (12,)
[array([[2., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0.],[1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0.]], dtype=float32), array([[1., 0., 0., 1., 2., 1., 0., 1., 2., 0., 1., 0.],[1., 1., 0., 0., 2., 1., 0., 1., 2., 2., 0., 0.],[1., 0., 1., 2., 0., 1., 0., 1., 1., 0., 1., 0.]], dtype=float32), array([3., 1., 0., 1., 1., 0., 0., 1., 0., 2., 0., 0.], dtype=float32)]
tf.Tensor([[[0. 4. 0.]]], shape=(1, 1, 3), dtype=float32)
tf.Tensor([[0. 4. 0.]], shape=(1, 3), dtype=float32)
tf.Tensor([[0. 4. 1.]], shape=(1, 3), dtype=float32)

可以看出h=[0,4,0], c=[0,4,1]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/25509.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pytorch深度学习-----神经网络之线性层用法

系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Co…

MySQL索引3——Explain关键字和索引使用规则(SQL提示、索引失效、最左前缀法则)

目录 Explain关键字 索引性能分析 Id ——select的查询序列号 Select_type——select查询的类型 Table——表名称 Type——select的连接类型 Possible_key ——显示可能应用在这张表的索引 Key——实际用到的索引 Key_len——实际索引使用到的字节数 Ref ——索引命…

【Linux】五、进程

一、冯诺依曼体系结构 存储器:指的是内存; 输入设备:键盘、摄像头、话筒,磁盘,网卡; 输出设备:显示器、音响、磁盘、网卡; 中央处理器(CPU):运算器…

【开源项目--稻草】Day04

【开源项目--稻草】Day04 1. 续 VUE1.1 完善VUEAJAX完成注册功能 Spring验证框架什么是Spring验证框架使用Spring-Validation 稻草问答-学生首页显示首页制作首页的流程开发标签列表标签列表显示原理 从业务逻辑层开始编写控制层代码开发问题列表开发业务逻辑层开发页面和JS代码…

HTML5 Canvas(画布)

<canvas>标签定义图形&#xff0c;比如图表和其他图像&#xff0c;你必须用脚本来绘制图形。 在画布上&#xff08; Canvas &#xff09;画一个共红色矩形&#xff0c;渐变矩形&#xff0c;彩色矩形&#xff0c;和一些彩色文字。 什么是 Canvas&#xff1f; HTML5<c…

机器学习深度学习——序列模型(NLP启动!)

&#x1f468;‍&#x1f393;作者简介&#xff1a;一位即将上大四&#xff0c;正专攻机器学习的保研er &#x1f30c;上期文章&#xff1a;机器学习&&深度学习——卷积神经网络&#xff08;LeNet&#xff09; &#x1f4da;订阅专栏&#xff1a;机器学习&&深度…

Mr. Cappuccino的第55杯咖啡——Mybatis一级缓存二级缓存

Mybatis一级缓存&二级缓存 概述一级缓存特点演示前准备效果演示在同一个SqlSession中在不同的SqlSession中 源代码怎么禁止使用一级缓存一级缓存在什么情况下会被清除 二级缓存特点演示前准备效果演示在不同的SqlSession中 源代码怎么关闭二级缓存 一级缓存&#xff08;Spr…

Docker 安装 Tomcat

目录 一、查看 tomcat 版本 二、拉取 Tomcat Docker 镜像 三、创建 Tomcat 容器 四、访问 Tomcat 五、停止和启动容器 一、查看 tomcat 版本 访问 tomcat 镜像库地址&#xff1a;https://hub.docker.com/_/tomcat&#xff0c;可以通过 Tags 查看其他版本的 tomcat; 二、拉…

Android Studio 的Gradle版本修改

使用Android Studio构建项目时&#xff0c;需要配置Gradle&#xff0c;与Gradle插件。 Gradle是一个构建工具&#xff0c;用于管理和自动化Android项目的构建过程。它使用Groovy或Kotlin作为脚本语言&#xff0c;并提供了强大的配置能力来定义项目的依赖关系、编译选项、打包方…

HCIA---OSI/RM--开放式系统互联参考模型

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 一.OSI--开放式系统互联参考模型简介 OSI开放式系统互联参考模型是一种用于计算机网络通信…

解密Redis:应对面试中的缓存相关问题2

面试官&#xff1a;Redis集群有哪些方案&#xff0c;知道嘛&#xff1f; 候选人&#xff1a;嗯~~&#xff0c;在Redis中提供的集群方案总共有三种&#xff1a;主从复制、哨兵模式、Redis分片集群。 面试官&#xff1a;那你来介绍一下主从同步。 候选人&#xff1a;嗯&#xff…

Flutter iOS 集成使用 fluter boost

在 Flutter项目中集成完 flutter boost&#xff0c;并且已经使用了 flutter boost进行了路由管理&#xff0c;这时如果需要和iOS混合开发&#xff0c;这时就要到 原生端进行集成。 注意&#xff1a;之前建的项目必须是 Flutter module项目&#xff0c;并且原生项目和flutter m…

Baumer工业相机堡盟工业相机如何通过BGAPISDK获取相机接口数据吞吐量(C++)

Baumer工业相机堡盟工业相机如何通过BGAPISDK里函数来获取相机当前数据吞吐量&#xff08;C&#xff09; Baumer工业相机Baumer工业相机的数据吞吐量的技术背景CameraExplorer如何查看相机吞吐量信息在BGAPI SDK里通过函数获取相机接口吞吐量 Baumer工业相机通过BGAPI SDK获取数…

【微信小程序】van-uploader实现文件上传

使用van-uploader和wx.uploadFile实现文件上传&#xff0c;后端使用ThinkPHP。 1、前端代码 json&#xff1a;引入van-uploader {"usingComponents": {"van-uploader": "vant/weapp/uploader/index"} }wxml&#xff1a;deletedFile是删除文件函…

Xilinx FPGA电源设计与注意事项

1 引言 随着半导体和芯片技术的飞速发展&#xff0c;现在的FPGA集成了越来越多的可配置逻辑资源、各种各样的外部总线接口以及丰富的内部RAM资源&#xff0c;使其在国防、医疗、消费电子等领域得到了越来越广泛的应用。当采用FPGA进行设计电路时&#xff0c;大多数FPGA对上电的…

【计算机网络】12、frp 内网穿透

文章目录 一、服务端设置二、客户端设置 frp &#xff1a;A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet。是一个专注于内网穿透的高性能的反向代理应用&#xff0c;支持 TCP、UDP、HTTP、HTTPS 等多种协议&#xff0c;且…

VUE框架:vue2转vue3全面细节总结(5)过渡动效

大家好&#xff0c;我是csdn的博主&#xff1a;lqj_本人 这是我的个人博客主页&#xff1a; lqj_本人_python人工智能视觉&#xff08;opencv&#xff09;从入门到实战,前端,微信小程序-CSDN博客 最新的uniapp毕业设计专栏也放在下方了&#xff1a; https://blog.csdn.net/lbcy…

ES6 数组的用法

1. forEach() 用来循环遍历的 for 数组名.forEach(function (item,index,arr) {})item:数组每一项 , index : 数组索引 , arr:原数组作用: 用来遍历数组 let arr [1, 2, 3, 4]; console.log(arr); let arr1 arr.forEach((item, index, arr) > {console.log(item, index…

HTTP——八、确认访问用户身份的认证

HTTP 一、何为认证二、BASIC认证BASIC认证的认证步骤 三、DIGEST认证DIGEST认证的认证步骤 四、SSL客户端认证1、SSL 客户端认证的认证步骤2、SSL 客户端认证采用双因素认证3、SSL 客户端认证必要的费用 五、基于表单认证1、认证多半为基于表单认证2、Session 管理及 Cookie 应…

Android network — iptables四表五链

Android network — iptables四表五链 1. iptables简介2. iptables的四表五链2.1 iptables流程图2.2 四表2.3 五链2.4 iptables的常见情况 3. NAT工作原理3.1 BNAT3.2 NAPT 4. iptables配置 本文主要介绍了iptables的基本工作原理和四表五链等基本概念以及NAT的工作原理。 1. i…