小白也能读懂的ConvLSTM!(开源pytorch代码)

ConvLSTM

    • 1. 算法简介与应用场景
    • 2. 算法原理
      • 2.1 LSTM基础
      • 2.2 ConvLSTM原理
        • 2.2.1 ConvLSTM的结构
        • 2.2.2 卷积操作的优点
      • 2.3 LSTM与ConvLSTM的对比分析
      • 2.4 ConvLSTM的应用
    • 3. PyTorch代码
    • 参考文献

仅需要网络源码的可以直接跳到末尾即可

1. 算法简介与应用场景

ConvLSTM(卷积长短期记忆网络)是一种结合了卷积神经网络(CNN)和长短期记忆网络(LSTM)优势的深度学习模型。它主要用于处理时空数据,特别适用于需要考虑空间特征和时间依赖关系的任务,如气象预测、视频分析、交通流量预测等。

在气象预测中,ConvLSTM可以根据过去的气象数据(如降水、温度等)预测未来的天气情况。在视频分析中,它可以帮助识别视频中的活动或事件,利用时间序列的连续性和空间信息进行更准确的分析。

2. 算法原理

2.1 LSTM基础

在介绍ConvLSTM之前,先让我们来回归一下什么是长短期记忆网络(LSTM)。LSTM是一种特殊的循环神经网络(RNN),它通过引入门控机制解决了传统RNN在长序列训练中面临的梯度消失和爆炸问题。LSTM单元主要包含三个门:输入门、遗忘门和输出门。这些门控制着信息在单元中的流动,从而有效地记住或遗忘信息。

LSTM的核心公式如下:

  • 遗忘门
    f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

  • 输入门
    i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)
    C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)

  • 单元状态更新
    C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t \ast C_{t-1} + i_t \ast \tilde{C}_t Ct=ftCt1+itC~t

  • 输出门
    o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)
    h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t \ast \tanh(C_t) ht=ottanh(Ct)

这里, C t C_t Ct 是当前的单元状态, h t h_t ht 是当前的隐藏状态, x t x_t xt 是当前的输入。

2.2 ConvLSTM原理

ConvLSTM在LSTM的基础上引入了卷积操作。传统的LSTM使用全连接层处理输入数据,而ConvLSTM则采用卷积层来处理空间数据。这样,ConvLSTM能够更好地捕捉输入数据中的空间特征。
在这里插入图片描述

2.2.1 ConvLSTM的结构

ConvLSTM的单元结构与LSTM非常相似,但是在每个门的计算中使用了卷积操作。具体来说,ConvLSTM的每个门的公式可以表示为:

i t = σ ( W x i ∗ X t + W h i ∗ H t − 1 + W c i ∘ C t − 1 + b i ) i_t = \sigma (W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i) it=σ(WxiXt+WhiHt1+WciCt1+bi)
f t = σ ( W x f ∗ X t + W h f ∗ H t − 1 + W c f ∘ C t − 1 + b f ) f_t = \sigma (W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f) ft=σ(WxfXt+WhfHt1+WcfCt1+bf)
C t = f t ∘ C t − 1 + i t ∘ t a n h ( W x c ∗ X t + W h c ∗ H t − 1 + b c ) C_t = f_t \circ C_{t-1} + i_t \circ tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c) Ct=ftCt1+ittanh(WxcXt+WhcHt1+bc)
o t = σ ( W x o ∗ X t + W h o ∗ H t − 1 + W c o ∘ C t + b o ) o_t = \sigma (W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o) ot=σ(WxoXt+WhoHt1+WcoCt+bo)
H t = o t ∘ t a n h ( C t ) H_t = o_t \circ tanh(C_t) Ht=ottanh(Ct)

这里的 所有 W W W都是是卷积权重, b b b是偏置项, σ \sigma σ 是 sigmoid 函数, tanh ⁡ \tanh tanh 是双曲正切函数。。
在这里插入图片描述

2.2.2 卷积操作的优点
  1. 空间特征提取:卷积操作能够有效提取输入数据中的空间特征。对于图像数据,卷积操作可以捕捉局部特征,例如边缘、纹理等,这在时间序列数据中同样适用。

  2. 参数共享:卷积操作通过使用相同的卷积核在不同位置计算特征,从而减少了模型参数的数量,降低了计算复杂度。

  3. 平移不变性:卷积网络对输入数据的平移具有不变性,即相同的特征在不同位置都会被检测到,这对于时空序列数据来说是非常重要的。

2.3 LSTM与ConvLSTM的对比分析

特性LSTMConvLSTM
输入类型一维序列三维数据(时序的图像数据)
处理方式全连接层卷积操作
空间特征捕捉较弱较强
应用场景自然语言处理、时间序列预测图像序列预测、视频分析

2.4 ConvLSTM的应用

ConvLSTM在多个领域中表现出色,特别适合处理具有时空特征的数据。以下是一些主要的应用场景:

  • 气象预测:利用历史气象数据(如温度、湿度、降水等)来预测未来的天气情况。
  • 视频分析:对视频中的动态场景进行建模,识别和预测视频中的活动。
  • 交通流量预测:基于历史交通数据预测未来的交通流量,帮助城市交通管理。
  • 医学影像分析:分析医学影像序列(如CT、MRI)中的变化,辅助疾病诊断。

3. PyTorch代码

以下是ConvLSTM的完整代码,可以直接拿来用:

import torch.nn as nn
import torchclass ConvLSTMCell(nn.Module):def __init__(self, input_dim, hidden_dim, kernel_size, bias):"""初始化卷积 LSTM 单元。参数:----------input_dim: int输入张量的通道数。hidden_dim: int隐藏状态的通道数。kernel_size: (int, int)卷积核的大小。bias: bool是否添加偏置项。"""super(ConvLSTMCell, self).__init__()self.input_dim = input_dimself.hidden_dim = hidden_dimself.kernel_size = kernel_size# 计算填充大小以保持输入和输出尺寸一致self.padding = kernel_size[0] // 2, kernel_size[1] // 2self.bias = bias# 定义卷积层,输入是输入维度加上隐藏维度,输出是4倍的隐藏维度(对应i, f, o, g)self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,out_channels=4 * self.hidden_dim,kernel_size=self.kernel_size,padding=self.padding,bias=self.bias)def forward(self, input_tensor, cur_state):h_cur, c_cur = cur_state# 沿着通道轴进行拼接combined = torch.cat([input_tensor, h_cur], dim=1)combined_conv = self.conv(combined)# 将输出分割成四个部分,分别对应输入门、遗忘门、输出门和候选单元状态cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)i = torch.sigmoid(cc_i)f = torch.sigmoid(cc_f)o = torch.sigmoid(cc_o)g = torch.tanh(cc_g)# 更新单元状态c_next = f * c_cur + i * g# 更新隐藏状态h_next = o * torch.tanh(c_next)return h_next, c_nextdef init_hidden(self, batch_size, image_size):height, width = image_size# 初始化隐藏状态和单元状态为零return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))class ConvLSTM(nn.Module):"""卷积 LSTM 层。参数:----------input_dim: 输入通道数hidden_dim: 隐藏通道数kernel_size: 卷积核大小num_layers: LSTM 层的数量batch_first: 批次是否在第一维bias: 卷积中是否有偏置项return_all_layers: 是否返回所有层的计算结果输入:------一个形状为 B, T, C, H, W 或者 T, B, C, H, W 的张量输出:------元组包含两个列表(长度为 num_layers 或者长度为 1 如果 return_all_layers 为 False):0 - layer_output_list 是长度为 T 的每个输出的列表1 - last_state_list 是最后的状态列表,其中每个元素是一个 (h, c) 对应隐藏状态和记忆状态示例:>>> x = torch.rand((32, 10, 64, 128, 128))>>> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)>>> _, last_states = convlstm(x)>>> h = last_states[0][0]  # 0 表示层索引,0 表示 h 索引"""def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,batch_first=False, bias=True, return_all_layers=False):super(ConvLSTM, self).__init__()# 检查 kernel_size 的一致性self._check_kernel_size_consistency(kernel_size)# 确保 kernel_size 和 hidden_dim 的长度与层数一致kernel_size = self._extend_for_multilayer(kernel_size, num_layers)hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)if not len(kernel_size) == len(hidden_dim) == num_layers:raise ValueError('不一致的列表长度。')self.input_dim = input_dimself.hidden_dim = hidden_dimself.kernel_size = kernel_sizeself.num_layers = num_layersself.batch_first = batch_firstself.bias = biasself.return_all_layers = return_all_layers# 创建 ConvLSTMCell 列表cell_list = []for i in range(0, self.num_layers):cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,hidden_dim=self.hidden_dim[i],kernel_size=self.kernel_size[i],bias=self.bias))self.cell_list = nn.ModuleList(cell_list)def forward(self, input_tensor, hidden_state=None):"""前向传播函数。参数:----------input_tensor: 输入张量,形状为 (t, b, c, h, w) 或者 (b, t, c, h, w)hidden_state: 初始隐藏状态,默认为 None返回:-------last_state_list, layer_output"""if not self.batch_first:# 改变输入张量的顺序,如果 batch_first 为 Falseinput_tensor = input_tensor.permute(1, 0, 2, 3, 4)b, _, _, h, w = input_tensor.size()# 实现状态化的 ConvLSTMif hidden_state is not None:raise NotImplementedError()else:# 初始化隐藏状态hidden_state = self._init_hidden(batch_size=b,image_size=(h, w))layer_output_list = []last_state_list = []seq_len = input_tensor.size(1)cur_layer_input = input_tensorfor layer_idx in range(self.num_layers):h, c = hidden_state[layer_idx]output_inner = []for t in range(seq_len):# 在每个时间步上更新状态h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],cur_state=[h, c])output_inner.append(h)# 将输出堆叠起来layer_output = torch.stack(output_inner, dim=1)cur_layer_input = layer_outputlayer_output_list.append(layer_output)last_state_list.append([h, c])if not self.return_all_layers:# 如果不需要返回所有层,则只返回最后一层的输出和状态layer_output_list = layer_output_list[-1:]last_state_list = last_state_list[-1:]return layer_output_list, last_state_listdef _init_hidden(self, batch_size, image_size):init_states = []for i in range(self.num_layers):# 初始化每一层的隐藏状态init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))return init_states@staticmethoddef _check_kernel_size_consistency(kernel_size):if not (isinstance(kernel_size, tuple) or(isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):raise ValueError('`kernel_size` 必须是 tuple 或者 list of tuples')@staticmethoddef _extend_for_multilayer(param, num_layers):if not isinstance(param, list):param = [param] * num_layersreturn param

参考文献

[1]Shi, X., Chen, Z., Wang, H., Yeung, D. Y., Wong, W. K., & Woo, W. (2015). Convolutional LSTM Network: A Machine Learning [2]Approach for Precipitation Nowcasting. Advances in Neural Information Processing Systems, 28.
[3]Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735-1780.
Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/51678.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SAPUI5基础知识22 - 图标(Icons)

1. 背景 SAPUI5 提供了一套丰富的图标库,可以用于增强应用程序的视觉吸引力和用户体验。这些图标是矢量图形,可以在任何分辨率下保持清晰,并且可以自定义颜色和大小。 2. 示例 在 SAPUI5 中,图标可以通过 sap.ui.core.Icon 控件…

Redis快速入门基础

Redis入门 Redis是一个基于内存的 key-value 结构数据库。mysql是二维表的接口数据库 优点: 基于内存存储,读写性能高 适合存储热点数据(热点商品、资讯、新闻) 企业应用广泛 官网:https://redis.io 中文网:https://www.redis.net.cn/ Redis下载与…

The Llama 3 Herd of Models 第6部分推理部分全文

第1,2,3部分 介绍,概览和预训练 第4部分 后训练 第5部分 结果 6 Inference 推理 我们研究了两种主要技术来提高Llama 3405b模型的推理效率:(1)管道并行化和(2)FP8量化。我们已经公开发布了FP8量化的实现。 6.1 Pipeline Parallelism 管道并行 当使用BF16数字表示模型参数时…

家具购物小程序的设计

管理员账户功能包括:系统首页,个人中心,用户管理,家具分类管理,家具新品管理,订单管理,系统管理 微信端账号功能包括:系统首页,家具新品,家具公告&#xff0…

Linux网络——深入理解传入层协议TCP

目录 一、前导知识 1.1 TCP协议段格式 1.2 TCP全双工本质 二、三次握手 2.1 标记位 2.2 三次握手 2.3 捎带应答 2.4 标记位 RST 三、四次挥手 3.1 标记位 FIN 四、确认应答(ACK)机制 五、超时重传机制 六 TCP 流量控制 6.1 16位窗口大小 6.2 标记位 PSH 6.3 标记…

YOLOv5改进 | 卷积模块 | 无卷积步长用于低分辨率图像和小物体的新 CNN 模块SPD-Conv

秋招面试专栏推荐 :深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转 💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡 专栏目录: 《YOLOv5入门 改…

[ WARN:0@0.014] global loadsave.cpp:248 cv::findDecoder imread_

[ WARN:00.014] global loadsave.cpp:248 cv::findDecoder imread_ 目录 [ WARN:00.014] global loadsave.cpp:248 cv::findDecoder imread_ 【常见模块错误】 【解决方案】 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英杰…

20240729 每日AI必读资讯

Meta科学家最新采访,揭秘Llama 3.1是如何炼成的 - Llama 3.1都使用了哪些数据?其中有多少合成数据?为什么不使用MoE架构?后训练与RLHF流程是如何进行的?模型评估是如何进行的? - 受访者Thomas Scialom现任…

Go语言教程(一看就会)

全篇文章 7000 字左右, 建议阅读时长 1h 以上。 Go语言是一门开源的编程语言,目的在于降低构建简单、可靠、高效软件的门槛。Go平衡了底层系统语言的能力,以及在现代语言中所见到的高级特性。它是快速的、静态类型编译语言。 第一个GO程序…

嵌入式人工智能(32-基于树莓派4B的旋转编码器-EnCoder11)

1、旋转编码器 旋转编码器是一种输入设备,通常用于测量和控制旋转运动。它由一个旋转轴和一系列编码器组成。旋转编码器可以根据旋转轴的位置和方向来测量旋转角度,并将其转化为电子信号输出。 旋转编码器通常分为两种类型:绝对值编码器和增…

嵌入式学习Day13---C语言提升

目录 一、二级指针 1.1.什么是二级指针 2.2.使用情况 2.3.二级指针与数组指针 二、指针函数 2.1.含义 2.2.格式 2.3.注意 2.4.练习 三、函数指针 3.1.含义 3.2.格式 3.3.存储 3.4.练习 ​编辑 四、void*指针 4.1.void缺省类型 4.2.void* 4.3.格式 4.4.注…

H3CNE(OSPF动态路由)

目录 7.1 静态路由的缺点与动态路由分类 7.1.1 静态路由的缺点 7.1.2 动态路由的分类 7.2 OSPF基础 7.2.1 OSPF的区域 ​编辑 7.2.2 Router-id 7.2.3 开销-Cost or Metric 7.2.4 路由转发 7.3 OSPF邻居表建立过程 7.3.1 五种包 7.3.2 建立邻居表的第一步 7.3.3 邻居建立…

模拟实现短信登录功能 (session 和 Redis 两种代码实例) 带前端演示

目录 整体流程 发送验证码 短信验证码登录、注册 校验登录状态 基于 session 实现登录 实现发送短信验证码功能 1. 前端发送请求 2. 后端处理请求 3. 演示 实现登录功能 1. 前端发送请求 2. 后端处理请求 校验登录状态 1. 登录拦截器 2. 注册拦截器 3. 登录完整…

RocketMQ事务消息机制原理

RocketMQ工作流程 在RocketMQ当中,当消息的生产者将消息生产完成之后,并不会直接将生产好的消息直接投递给消费者,而是先将消息投递个中间的服务,通过这个服务来协调RocketMQ中生产者与消费者之间的消费速度。 那么生产者是如何…

昇思25天学习打卡营第19天|DCGAN生成漫画头像

DCGAN生成漫画头像总结 实验概述 本实验旨在利用深度卷积生成对抗网络(DCGAN)生成动漫头像,通过设置网络、优化器以及损失函数,使用MindSpore进行实现。 实验目的 学习和掌握DCGAN的基本原理和应用。熟悉使用MindSpore进行图像…

网络协议一 : 搭建tomacat,intellij IDEA Ultimate 的下载,安装,配置,启动, 访问

需要搭建的环境 1.客户端--服务器开发环境 客户端:浏览器(HTMLCSSJS) 服务器:JAVA 1.安装JDK,配置JAVA_HOME 和 PATH 2.安装Tomcat 3.安装IDE--intellij IDEA Ultimate 是旗舰版的意思。 2.TOMCAT 的下载和解…

文件操作相关的精讲

目录: 思维导图 一. 文件定义 二. 文件的打开和关闭 三. 文件的顺序读写操作 四. 文件的随机读写操作 五. 文本文件和二进制文件 六. 文件读取结束的判断 七.文件缓冲区 思维导图: 一. 文件定义 1.文件定义 C语言中,文件是指一组相…

Java中的二叉搜索树(如果想知道Java中有关二叉搜索树的知识点,那么只看这一篇就足够了!)

前言:Java 提供了丰富的数据结构来处理和管理数据,其中 TreeSet 和 TreeMap 是基于红黑树实现的集合和映射接口。它们有序地存储数据,提供高效的搜索、插入和删除操作。 ✨✨✨这里是秋刀鱼不做梦的BLOG ✨✨✨想要了解更多内容可以访问我的主…

web基础,http协议,apache概念及nginx

一、web相关概念 Web,全称World Wide Web,通常简称为WWW、Web或万维网,是一个基于超文本和HTTP(超文本传输协议)的、全球性的、动态交互的、跨平台的分布式图形信息系统。它起源于1989年,由英国科学家蒂姆…

文本编辑三剑客(grep)

目录 正则表达式 元字符 grep 案例 我在编写脚本的时候发现,三个文本编辑的命令(grep、sed、awk,被称为文本编辑三剑客,我习惯叫它三巨头)用的还挺多的,说实话我一开始学的时候也有些懵,主要…