ConvGRU原理与开源代码

ConvGRU

    • 1. 算法简介与应用场景
    • 2. 算法原理
      • 2.1 GRU基础
      • 2.2 ConvGRU原理
        • 2.2.1 ConvGRU的结构
        • 2.2.2 卷积操作的优点
      • 2.3 GRU与ConvGRU的对比分析
      • 2.4 ConvGRU的应用
    • 3. PyTorch代码

仅需要网络源码的可以直接跳到末尾即可
需要ConvLSTM的可以参考我的另外一篇博客:小白也能读懂的ConvLSTM!(开源pytorch代码)

1. 算法简介与应用场景

ConvGRU(卷积门控循环单元)是一种结合了卷积神经网络(CNN)和门控循环单元(GRU)的深度学习模型。与ConvLSTM类似,ConvGRU也主要用于处理时空数据,特别适用于需要考虑空间特征和时间依赖关系的任务,如视频分析、气象预测和交通流量预测等。

在视频分析中,ConvGRU可以帮助识别和预测视频中的动态行为,利用时间序列的连续性和空间信息进行更准确的分析。在气象预测中,ConvGRU能够根据过去的气象数据(如降水、云图等)预测未来的天气情况。

2. 算法原理

2.1 GRU基础

在介绍ConvGRU之前,首先让我们回顾一下什么是门控循环单元(GRU)。GRU是一种特殊的循环神经网络(RNN),它通过引入门控机制来解决传统RNN在长序列训练中面临的梯度消失和爆炸问题。GRU单元主要包含两个门:重置门和更新门。这些门控制着信息在单元中的流动,从而有效地记住或遗忘信息。

GRU的核心公式如下:

  • 重置门
    r t = σ ( W r ⋅ [ h t − 1 , x t ] + b r ) r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) rt=σ(Wr[ht1,xt]+br)

  • 更新门
    z t = σ ( W z ⋅ [ h t − 1 , x t ] + b z ) z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) zt=σ(Wz[ht1,xt]+bz)

  • 候选状态
    h ~ t = tanh ⁡ ( W h ⋅ [ r t ∗ h t − 1 , x t ] + b h ) \tilde{h}_t = \tanh(W_h \cdot [r_t * h_{t-1}, x_t] + b_h) h~t=tanh(Wh[rtht1,xt]+bh)

  • 最终状态
    h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_t = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_t ht=(1zt)ht1+zth~t

这里, h t h_t ht 是当前的隐藏状态, x t x_t xt 是当前的输入。

2.2 ConvGRU原理

ConvGRU在GRU的基础上引入了卷积操作。与ConvLSTM类似,ConvGRU使用卷积层来处理空间数据,从而能够更好地捕捉输入数据中的空间特征。

ConvGRU结构图

没找到ConvGRU的图,和LSTM道理一样的

2.2.1 ConvGRU的结构

ConvGRU的单元结构与GRU非常相似,但是在每个门的计算中使用了卷积操作。具体来说,ConvGRU的每个门的公式可以表示为:

z t = σ ( W z ∗ X t + U z ∗ H t − 1 + b z ) z_t = \sigma (W_{z} * X_t + U_{z} * H_{t-1} + b_z) zt=σ(WzXt+UzHt1+bz)
r t = σ ( W r ∗ X t + U r ∗ H t − 1 + b r ) r_t = \sigma (W_{r} * X_t + U_{r} * H_{t-1} + b_r) rt=σ(WrXt+UrHt1+br)
h ~ t = tanh ⁡ ( W h ∗ X t + U h ∗ ( r t ∗ H t − 1 ) + b h ) \tilde{h}_t = \tanh(W_{h} * X_t + U_{h} * (r_t * H_{t-1}) + b_h) h~t=tanh(WhXt+Uh(rtHt1)+bh)
h t = ( 1 − z t ) ∗ H t − 1 + z t ∗ h ~ t h_t = (1 - z_t) * H_{t-1} + z_t * \tilde{h}_t ht=(1zt)Ht1+zth~t

这里的所有 W W W U U U都是卷积权重, b b b是偏置项, σ \sigma σ 是 sigmoid 函数, tanh ⁡ \tanh tanh 是双曲正切函数。

ConvGRU结构图

2.2.2 卷积操作的优点
  1. 空间特征提取:卷积操作能够有效提取输入数据中的空间特征。对于图像数据,卷积操作可以捕捉局部特征,例如边缘、纹理等,这在时间序列数据中同样适用。

  2. 参数共享:卷积操作通过使用相同的卷积核在不同位置计算特征,从而减少了模型参数的数量,降低了计算复杂度。

  3. 平移不变性:卷积网络对输入数据的平移具有不变性,即相同的特征在不同位置都会被检测到,这对于时空序列数据来说是非常重要的。

2.3 GRU与ConvGRU的对比分析

特性GRUConvGRU
输入类型一维序列三维数据(时序的图像数据)
处理方式全连接层卷积操作
空间特征捕捉较弱较强
应用场景自然语言处理、时间序列预测图像序列预测、视频分析

2.4 ConvGRU的应用

ConvGRU在多个领域中表现出色,特别适合处理具有时空特征的数据。以下是一些主要的应用场景:

  • 气象预测:利用历史气象数据(如温度、湿度、降水等)来预测未来的天气情况。
  • 视频分析:对视频中的动态场景进行建模,识别和预测视频中的活动。
  • 交通流量预测:基于历史交通数据预测未来的交通流量,帮助城市交通管理。
  • 医学影像分析:分析医学影像序列(如CT、MRI)中的变化,辅助疾病诊断。

3. PyTorch代码

以下是一个简单的ConvGRU的网络完整代码:

import os
import torch
from torch import nn
from torch.autograd import Variableclass ConvGRUCell(nn.Module):def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias, dtype):"""初始化卷积 GRU 单元。:param input_size: (int, int)输入张量的高度和宽度作为 (height, width)。:param input_dim: int输入张量的通道数。:param hidden_dim: int隐藏状态的通道数。:param kernel_size: (int, int)卷积核的大小。:param bias: bool是否添加偏置项。:param dtype: torch.cuda.FloatTensor 或 torch.FloatTensor是否使用 CUDA。"""super(ConvGRUCell, self).__init__()self.height, self.width = input_sizeself.padding = kernel_size[0] // 2, kernel_size[1] // 2self.hidden_dim = hidden_dimself.bias = biasself.dtype = dtype# 定义用于计算更新门和重置门的卷积层self.conv_gates = nn.Conv2d(in_channels=input_dim + hidden_dim,out_channels=2 * self.hidden_dim,  # 用于更新门和重置门kernel_size=kernel_size,padding=self.padding,bias=self.bias)# 定义用于计算候选神经记忆的卷积层self.conv_can = nn.Conv2d(in_channels=input_dim + hidden_dim,out_channels=self.hidden_dim,  # 用于候选神经记忆kernel_size=kernel_size,padding=self.padding,bias=self.bias)def init_hidden(self, batch_size):"""初始化隐藏状态。:param batch_size: int批次大小。:return: Variable隐藏状态。"""return Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).type(self.dtype)def forward(self, input_tensor, h_cur):"""前向传播函数。:param input_tensor: (b, c, h, w)输入张量实际上是目标模型。:param h_cur: (b, c_hidden, h, w)当前的隐藏状态。:return: h_next下一个隐藏状态。"""combined = torch.cat([input_tensor, h_cur], dim=1)combined_conv = self.conv_gates(combined)# 分割卷积输出以获取更新门和重置门gamma, beta = torch.split(combined_conv, self.hidden_dim, dim=1)reset_gate = torch.sigmoid(gamma)update_gate = torch.sigmoid(beta)# 使用重置门乘以当前隐藏状态combined = torch.cat([input_tensor, reset_gate * h_cur], dim=1)cc_cnm = self.conv_can(combined)cnm = torch.tanh(cc_cnm)# 更新隐藏状态h_next = (1 - update_gate) * h_cur + update_gate * cnmreturn h_nextclass ConvGRU(nn.Module):def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,dtype, batch_first=False, bias=True, return_all_layers=False):"""初始化卷积 GRU 模型。:param input_size: (int, int)输入张量的高度和宽度作为 (height, width)。:param input_dim: int输入张量的通道数。:param hidden_dim: int隐藏状态的通道数。:param kernel_size: (int, int)卷积核的大小。:param num_layers: int卷积 GRU 层的数量。:param dtype: torch.cuda.FloatTensor 或 torch.FloatTensor是否使用 CUDA。:param batch_first: bool如果数组的第一个位置是批次。:param bias: bool是否添加偏置项。:param return_all_layers: bool是否返回所有层的隐藏状态。"""super(ConvGRU, self).__init__()# 确保 kernel_size 和 hidden_dim 的长度与层数一致kernel_size = self._extend_for_multilayer(kernel_size, num_layers)hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)if not len(kernel_size) == len(hidden_dim) == num_layers:raise ValueError('不一致的列表长度。')self.height, self.width = input_sizeself.input_dim = input_dimself.hidden_dim = hidden_dimself.kernel_size = kernel_sizeself.dtype = dtypeself.num_layers = num_layersself.batch_first = batch_firstself.bias = biasself.return_all_layers = return_all_layerscell_list = []for i in range(0, self.num_layers):# 确定当前层的输入维度cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1]# 创建并添加卷积 GRU 单元到列表cell_list.append(ConvGRUCell(input_size=(self.height, self.width),input_dim=cur_input_dim,hidden_dim=self.hidden_dim[i],kernel_size=self.kernel_size[i],bias=self.bias,dtype=self.dtype))# 将 Python 列表转换为 PyTorch 模块self.cell_list = nn.ModuleList(cell_list)def forward(self, input_tensor, hidden_state=None):"""前向传播函数。:param input_tensor: (b, t, c, h, w) 或 (t, b, c, h, w)从 AlexNet 提取的特征。:param hidden_state:初始隐藏状态。:return: layer_output_list, last_state_list各个层的输出列表以及最后一个状态列表。"""if not self.batch_first:# 如果不是按批次优先,则重新排列维度input_tensor = input_tensor.permute(1, 0, 2, 3, 4)# 实现状态化的卷积 GRUif hidden_state is not None:raise NotImplementedError()else:# 初始化隐藏状态hidden_state = self._init_hidden(batch_size=input_tensor.size(0))layer_output_list = []last_state_list = []seq_len = input_tensor.size(1)cur_layer_input = input_tensorfor layer_idx in range(self.num_layers):h = hidden_state[layer_idx]output_inner = []for t in range(seq_len):# 计算当前层的下一个隐藏状态h = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],h_cur=h)output_inner.append(h)# 将序列内的隐藏状态堆叠起来layer_output = torch.stack(output_inner, dim=1)cur_layer_input = layer_outputlayer_output_list.append(layer_output)last_state_list.append([h])if not self.return_all_layers:# 如果不需要返回所有层,则只返回最后一层的输出和状态layer_output_list = layer_output_list[-1:]last_state_list = last_state_list[-1:]return layer_output_list, last_state_listdef _init_hidden(self, batch_size):"""初始化隐藏状态。:param batch_size: int批次大小。:return: list每一层的初始化隐藏状态列表。"""init_states = []for i in range(self.num_layers):init_states.append(self.cell_list[i].init_hidden(batch_size))return init_states@staticmethoddef _check_kernel_size_consistency(kernel_size):"""检查 kernel_size 的一致性。:param kernel_size: tuple 或 list of tuples卷积核大小。"""if not (isinstance(kernel_size, tuple) or(isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):raise ValueError('`kernel_size` 必须是 tuple 或 list of tuples')@staticmethoddef _extend_for_multilayer(param, num_layers):"""扩展参数以适应多层结构。:param param: int 或 list参数。:param num_layers: int层数。:return: list扩展后的参数列表。"""if not isinstance(param, list):param = [param] * num_layersreturn param

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/51500.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

初识HTML文件,创建自己的第一个网页!

本文旨在初步介绍HTML(超文本标记语言),帮助读者理解HTML中的相关术语及概念,并使读者在完成本文的阅读后可以快速上手编写一个属于自己的简易网页。 一、HTML介绍 HTML(全称HyperText Markup Language,超文本标记语言…

【C++】位图 + 布隆过滤器

目录 1. 位图1.1. 概念1.2. 实现1.3. 应用 2. 布隆过滤器2.1. 背景2.2. 概念2.3. 实现2.4. 优点2.5. 缺点 3. 海量数据面试题3.1. 哈希切割3.2. 位图应用3.3. 布隆过滤器3.4. 总结 1. 位图 1.1. 概念 位图是一种用于高效地存储和操作集合的数据结构。它的基本思想是使用一个二…

高并发内存池(四)Page Cache的框架及内存申请实现

目录 一、Page Cache的框架梳理 二、Page Cache的实现 2.1PageCache.h 2.2VirtualAlloc 2.3std::unordered_map _idSpanMap,> 2.4Page Cache.cpp 一、Page Cache的框架梳理 申请内存: 1. 当central cache向page cache申请内存时,page cache先检…

Intel 13/14代不稳定 微星率先发声:密切监视、8月中旬更新微码

不久前,Intel针对14/14代酷睿i9 K系列不稳定的问题发布了最新声明,确认问题源于微代码算法缺陷与电压过高,并承诺将在8月中旬完成新版BIOS的验证,随后发放。现在,微星在各家主板厂商中第一个站出来,表明了态…

Java 使用 POI 导出Excel,实现单元格输入内容提示功能

在使用Apache POI的库生成Excel导入模板的时候,有时候需要对单元格能够输入的内容进行一个提示,该如何实现这个特性呢?下面是一个示例代码,演示如何实现单元格输入内容提示功能。 代码 import org.apache.poi.ss.usermodel.*; im…

Frienda 4 件套幽灵狩猎猫球运动发光猫球 LED 运动激活猫球运动点亮猫狗互动玩具宠物发光迷你跑步健身球

来自 美国亚马逊:商品评论: Frienda 4 件套幽灵狩猎猫球运动发光猫球 LED 运动激活猫球运动点亮猫狗互动玩具宠物发光迷你跑步健身球玩具(亮色) (amazon.com) Kim 1.0 颗星,最多 5 颗星 Battery does not last/ cant replace 2024年5月29日 在美国审核…

lora微调Qwen模型全流程

LoRA 微调 Qwen 模型的技术原理概述 LoRA(Low-Rank Adaptation)是一种用于大模型高效微调的方法。通过对模型参数进行低秩分解和特定层的微调,LoRA 能在保持模型性能的前提下显著减少训练所需的参数量和计算资源。接下来是对 LoRA 微调 Qwen…

鸿蒙开发—黑马云音乐之首页导航栏

目录 1.底部导航 2.点击导航栏的时候点亮 3.新建tabbar对应的页面并加载 1.底部导航 Entry Component struct Index {State message: string 首页BuildertabBuilder(text:string,img:Resource) {// 未选中状态样式处理Column({ space: 5 }) {Image(img).width(25).border…

[C++进阶]抽象类

一、抽象类 1.抽象类的概念 在虚函数的后面写上 0 ,则这个函数为纯虚函数。包含纯虚函数的类叫做抽象类(也叫接口类),抽象类不能实例化出对象。派生类继承后也不能实例化出对象,只有重写纯虚函数,派生类才…

unity3d:TabView,UGUI多标签页组件,TreeView树状展开菜单

概述 1.最外层DataForm为空壳编辑数据用。可以有多个DataForm,例如福利DataForm,抽奖DataForm 2.Menu层为左边栏层,每个DataForm可以使用不同样式的MenuForm预制体 3.DataForm中使用ReorderList,可排列配置 4.有定位功能&#xf…

Clickhouse 生产集群部署(Centos 环境)

文章目录 机器环境配置安装 JDK 8安装 zookeeperClickhouse 集群安装rpm 包离线安装修改全局配置zookeeper配置Shard和Replica设置image.png添加macros配置启动 clickhouse启动 10.82.46.135 clickhouse server启动 10.82.46.163 clickhouse server启动 10.82.46.218 clickhous…

《InheriBT行为树》For Unity

InheriBT: Unity Editor中的行为树编辑框架 行为树(Behavior Tree)是一种广泛应用于人工智能(AI)领域的决策模型,特别是在游戏开发中。行为树通过分层结构和节点的组合,实现了复杂行为的简洁表达。然而&am…

CPU350% JVM GC频繁并GC不掉EXCEL导出

背景: 有个Excel导出的需求,测试的时候,只要连续导出大量的数据就会导致FAT机器反请求反应迟钝,甚至卡死,无法恢复。 排查: 1 跳板机跳到机器上,查看 项目 ipd 执行ps -ef | grep 项目名称.j…

虚拟机Ubuntu20.04 利用串口调试机械臂

虚拟机Ubuntu20.04 利用串口调试机械臂 串口库问题 由于机械臂使用的是串口进行驱动控制,在python中相关的串口库为serial和pyserial两个,这里我曾踩过雷同时安装了serial与pyserial两个库,导致报错如下所示: AttributeError: m…

数据结构:(1)线性表

一、基本概念 概念:零个或多个数据元素的有限序列 元素之间是有顺序了。如果存在多个元素,第一个元素无前驱,最后一个没有后继,其他的元素只有一个前驱和一个后继。 当线性表元素的个数n(n>0&am…

使用Spring Boot与Spire.Doc实现Word文档的多样化操作

​ 博客主页: 南来_北往 系列专栏:Spring Boot实战 前言 使用Spring Boot与Spire.Doc实现Word文档的多样化操作具有以下优势: 强大的功能组合:Spring Boot提供了快速构建独立和生产级的Spring应用程序的能力,而Spire.Doc则…

OSError: You are trying to access a gated repo.解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

『 Linux 』用户态与内核态的转换机制及信号检测时机

文章目录 用户态与内核态进程地址空间操作系统的本质 信号的处理时机 用户态与内核态 进程在执行代码的过程中代码必定涉及用户代码,库函数代码及操作系统内核代码; 以简单的printf()函数为例,该函数必定为先执行用户的代码即知道需要调用printf()函数,再执行库(如libc)中的代码…

Java线程同步与通信:wait(), notify(), notifyAll(), sleep()

Java线程同步与通信:wait(), notify(), notifyAll(), sleep() 1. wait()2. notify()3. notifyAll()4. sleep()4、总结 💖The Begin💖点点关注&…