【Python/Pytorch - 网络模型】-- 手把手搭建E3D LSTM网络

在这里插入图片描述
文章目录

文章目录

  • 00 写在前面
  • 01 基于Pytorch版本的E3D LSTM代码
  • 02 论文下载

00 写在前面

测试代码,比较重要,它可以大概判断tensor维度在网络传播过程中,各个维度的变化情况,方便改成适合自己的数据集。

需要github上的数据集以及可运行的代码,可以私聊!

01 基于Pytorch版本的E3D LSTM代码

# 库函数调用
from functools import reduce
from src.utils import nice_print, mem_report, cpu_stats
import copy
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F# E3DLSTM模型代码
class E3DLSTM(nn.Module):def __init__(self, input_shape, hidden_size, num_layers, kernel_size, tau):super().__init__()self._tau = tauself._cells = []input_shape = list(input_shape)for i in range(num_layers):cell = E3DLSTMCell(input_shape, hidden_size, kernel_size)# NOTE hidden state becomes input to the next cellinput_shape[0] = hidden_sizeself._cells.append(cell)# Hook to register submodulesetattr(self, "cell{}".format(i), cell)def forward(self, input):# NOTE (seq_len, batch, input_shape)batch_size = input.size(1)c_history_states = []h_states = []outputs = []for step, x in enumerate(input):for cell_idx, cell in enumerate(self._cells):if step == 0:c_history, m, h = self._cells[cell_idx].init_hidden(batch_size, self._tau, input.device)c_history_states.append(c_history)h_states.append(h)# NOTE c_history and h are coming from the previous time stamp, but we iterate over cellsc_history, m, h = cell(x, c_history_states[cell_idx], m, h_states[cell_idx])c_history_states[cell_idx] = c_historyh_states[cell_idx] = h# NOTE hidden state of previous LSTM is passed as input to the next onex = houtputs.append(h)# NOTE Concat along the channelsreturn torch.cat(outputs, dim=1)class E3DLSTMCell(nn.Module):def __init__(self, input_shape, hidden_size, kernel_size):super().__init__()in_channels = input_shape[0]self._input_shape = input_shapeself._hidden_size = hidden_size# memory gates: input, cell(input modulation), forgetself.weight_xi = ConvDeconv3d(in_channels, hidden_size, kernel_size)self.weight_hi = ConvDeconv3d(hidden_size, hidden_size, kernel_size, bias=False)self.weight_xg = copy.deepcopy(self.weight_xi)self.weight_hg = copy.deepcopy(self.weight_hi)self.weight_xr = copy.deepcopy(self.weight_xi)self.weight_hr = copy.deepcopy(self.weight_hi)memory_shape = list(input_shape)memory_shape[0] = hidden_size# self.layer_norm = nn.LayerNorm(memory_shape)self.group_norm = nn.GroupNorm(1, hidden_size) # wzj# for spatiotemporal memoryself.weight_xi_prime = copy.deepcopy(self.weight_xi)self.weight_mi_prime = copy.deepcopy(self.weight_hi)self.weight_xg_prime = copy.deepcopy(self.weight_xi)self.weight_mg_prime = copy.deepcopy(self.weight_hi)self.weight_xf_prime = copy.deepcopy(self.weight_xi)self.weight_mf_prime = copy.deepcopy(self.weight_hi)self.weight_xo = copy.deepcopy(self.weight_xi)self.weight_ho = copy.deepcopy(self.weight_hi)self.weight_co = copy.deepcopy(self.weight_hi)self.weight_mo = copy.deepcopy(self.weight_hi)self.weight_111 = nn.Conv3d(hidden_size + hidden_size, hidden_size, 1)def self_attention(self, r, c_history):batch_size = r.size(0)channels = r.size(1)r_flatten = r.view(batch_size, -1, channels)# BxtaoTHWxCc_history_flatten = c_history.view(batch_size, -1, channels)# Attention mechanism# BxTHWxC x BxtaoTHWxC' = B x THW x taoTHWscores = torch.einsum("bxc,byc->bxy", r_flatten, c_history_flatten)attention = F.softmax(scores, dim=2)return torch.einsum("bxy,byc->bxc", attention, c_history_flatten).view(*r.shape)def self_attention_fast(self, r, c_history):# Scaled Dot-Product but for tensors# instead of dot-product we do matrix contraction on twh dimensionsscaling_factor = 1 / (reduce(operator.mul, r.shape[-3:], 1) ** 0.5)scores = torch.einsum("bctwh,lbctwh->bl", r, c_history) * scaling_factorattention = F.softmax(scores, dim=0)return torch.einsum("bl,lbctwh->bctwh", attention, c_history)def forward(self, x, c_history, m, h):# Normalized shape for LayerNorm is CxT×H×Wnormalized_shape = list(h.shape[-3:])def LR(input):# return F.layer_norm(input, normalized_shape)return self.group_norm(input, normalized_shape) # wzj# R is CxT×H×Wr = torch.sigmoid(LR(self.weight_xr(x) + self.weight_hr(h)))i = torch.sigmoid(LR(self.weight_xi(x) + self.weight_hi(h)))g = torch.tanh(LR(self.weight_xg(x) + self.weight_hg(h)))recall = self.self_attention_fast(r, c_history)# nice_print(**locals())# mem_report()# cpu_stats()c = i * g + self.group_norm(c_history[-1] + recall) # wzji_prime = torch.sigmoid(LR(self.weight_xi_prime(x) + self.weight_mi_prime(m)))g_prime = torch.tanh(LR(self.weight_xg_prime(x) + self.weight_mg_prime(m)))f_prime = torch.sigmoid(LR(self.weight_xf_prime(x) + self.weight_mf_prime(m)))m = i_prime * g_prime + f_prime * mo = torch.sigmoid(LR(self.weight_xo(x)+ self.weight_ho(h)+ self.weight_co(c)+ self.weight_mo(m)))h = o * torch.tanh(self.weight_111(torch.cat([c, m], dim=1)))# TODO is it correct FIFO?c_history = torch.cat([c_history[1:], c[None, :]], dim=0)# nice_print(**locals())return (c_history, m, h)def init_hidden(self, batch_size, tau, device=None):memory_shape = list(self._input_shape)memory_shape[0] = self._hidden_sizec_history = torch.zeros(tau, batch_size, *memory_shape, device=device)m = torch.zeros(batch_size, *memory_shape, device=device)h = torch.zeros(batch_size, *memory_shape, device=device)return (c_history, m, h)class ConvDeconv3d(nn.Module):def __init__(self, in_channels, out_channels, *vargs, **kwargs):super().__init__()self.conv3d = nn.Conv3d(in_channels, out_channels, *vargs, **kwargs)# self.conv_transpose3d = nn.ConvTranspose3d(out_channels, out_channels, *vargs, **kwargs)def forward(self, input):# print(self.conv3d(input).shape, input.shape)# return self.conv_transpose3d(self.conv3d(input))return F.interpolate(self.conv3d(input), size=input.shape[-3:], mode="nearest")class Out(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride=1, padding=1)def forward(self, x):return self.conv(x)class E3DLSTM_NET(nn.Module):def __init__(self, input_shape, hidden_size, num_layers, kernel_size, tau, time_steps, output_shape):super().__init__()self.input_shape = input_shapeself.hidden_size = hidden_sizeself.num_layers = num_layersself.kernel_size = kernel_sizeself.tau = tauself.time_steps = time_stepsself.output_shape = output_shapeself.dtype = torch.float32self.encoder = E3DLSTM(input_shape, hidden_size, num_layers, kernel_size, tau).type(self.dtype)self.decoder = nn.Conv3d(hidden_size * time_steps, output_shape[0], kernel_size, padding=(0, 2, 2)).type(self.dtype)self.out = Out(4, 1)def forward(self, input_seq):return self.out(self.decoder(self.encoder(input_seq)))# 测试代码
if __name__ == '__main__':input_shape = (16, 4, 16, 16)output_shape = (16, 1, 16, 16)tau = 2hidden_size = 64kernel = (3, 5, 5)lstm_layers = 4time_steps = 29x = torch.ones([29, 2, 16, 4, 16, 16])model = E3DLSTM_NET(input_shape, hidden_size, lstm_layers, kernel, tau, time_steps, output_shape)print('finished!')f = model(x)print(f)

02 论文下载

Eidetic 3D LSTM: A Model for Video Prediction and Beyond
Eidetic 3D LSTM: A Model for Video Prediction and Beyond
Github链接:e3d_lstm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/855216.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

这些数据可被Modbus采集,你还不知道???

为什么要用Modbus采集模块 Modbus采集模块之所以被广泛使用,是因为它提供了标准化的通信协议,确保了不同设备间的兼容性。它支持多种通信方式,易于实现,并且能够适应不同的网络环境。Modbus模块能够收集和传输各种工业数据&#x…

061、Python 包:模块管理

包(Package)是一种用于组织模块的层次结构。包实际上就是一个包含了__init__.py文件的目录,该文件可以为空或包含包的初始化代码。通过使用包,可以更好地组织和管理大型项目中的模块,避免命名冲突,并提高代…

kettle从入门到精通 第七十一课 ETL之kettle 再谈http post,轻松掌握body中传递json参数

场景: kettle中http post步骤如何发送http请求且传递body参数? 解决方案: http post步骤中直接设置Request entity field字段即可。 1、手边没有现成的post接口,索性用python搭建一个简单的接口,关键代码如下&#…

深度学习模型的生命周期与推理系统架构

目录 深度学习模型的生命周期 ​编辑 深度学习模型的生命周期 推理相比训练的新特点与挑战 推理系统架构 推理系统 vs 推理引擎 顶层:API接口和模型转换 中层:运行时(计算引擎) 底层:硬件级优化 边缘设备计算 主要问题 边缘部署和推理方式 方式1:边缘设备计…

可提供实习证明/实习鉴定报告,企业项目试岗实训开营啦

在数字化转型的浪潮中,大数据和人工智能等前沿技术已成为推动经济发展和科技进步的关键动力。当前,全球各行各业都在积极推进数字化转型,不仅为经济增长注入新活力,也对人才市场结构产生了深刻影响,尤其是对数字化人才…

在 KubeSphere 上快速安装和使用 KDP 云原生数据平台

作者简介:金津,智领云高级研发经理,华中科技大学计算机系硕士。加入智领云 8 余年,长期从事云原生、容器化编排领域研发工作,主导了智领云自研的 BDOS 应用云平台、云原生大数据平台 KDP 等产品的开发,并在…

联邦学习周记|第四周

论文:Active Federated Learning 链接 将主动学习引入FL,每次随机抽几个Client拿来train,把置信值低的Client概率调大,就能少跑几次。 论文:Active learning based federated learning for waste and natural disast…

“Git之道:掌握常用命令,轻松管理代码“

目录 1. 初始化和配置 2. 提交和更新 3. 分支和合并 4. 查看和比较 5. 远程仓库 6. 文件操作命令 1. 初始化和配置 git init:在当前目录初始化一个新的Git仓库git config:配置Git的全局或局部选项git clone:从远程仓库克隆一个本地副本…

vue3第四十节(pinia的用法注意事项解构store)

pinia 主要包括以下五部分,经常用到的是 store、state、getters、actions 以下使用说明,注意事项,仅限于 vue3 setup 语法糖中使用,若使用选项式 API 请直接查看官方文档: 一、前言: pinia 是为了探索 vu…

动手学深度学习(Pytorch版)代码实践 -深度学习基础-11暂退法Dropout

11暂退法Dropout #Dropout 是一种正则化技术,主要用于防止过拟合, #通过在训练过程中随机丢弃神经元来提高模型的泛化能力。 import torch from torch import nn from d2l import torch as d2l import liliPytorch as lpdef dropout_layer(X, dropout):…

大数据—“西游记“全集文本数据挖掘分析实战教程

项目背景介绍 四大名著,又称四大小说,是汉语文学中经典作品。这四部著作历久不衰,其中的故事、场景,已经深深地影响了国人的思想观念、价值取向。四部著作都有很高的艺术水平,细致的刻画和所蕴含的思想都为历代读者所…

0元体验苹果macOS系统,最简单的虚拟机部署macOS教程

前言 最近发现小伙伴热衷于在VMware上安装体验macOS系统,所以就有了今天的帖子。 正文开始 首先,鉴于小伙伴们热衷macOS,所以小白搜罗了一圈macOS系统,并开启了分享通道。 本次更新的系统版本是: macOS 10.13.6 ma…

【靶场搭建】-01- 在kali上搭建DVWA靶机

1.DVWA靶机 DVWA(Damn Vulnerable Web Application)是使用PHPMysql编写的web安全测试框架,主要用于安全人员在一个合法的环境中测试技能和工具。 2.下载DVWA 从GitHub上将DVWA的源码clone到kali上 git clone https://github.com/digininj…

温湿度采集与OLED显示

目录 一、什么是软件I2C 二、什么是硬件I2C 三、STM32CubeMX配置 1、RCC配置 2、SYS配置 3、I2C1配置 3、I2C2配置 4、USART1配置 5、TIM1配置 6、时钟树配置 7、工程配置 四、设备链接 1、OLED连接 2、串口连接 3、温湿度传感器连接 五、每隔2秒钟采集一次温湿…

第二十三篇——香农第二定律(二):到底要不要扁平化管理?

目录 一、背景介绍二、思路&方案三、过程1.思维导图2.文章中经典的句子理解3.学习之后对于投资市场的理解4.通过这篇文章结合我知道的东西我能想到什么? 四、总结五、升华 一、背景介绍 对于企业的理解,扁平化的管理,如果从香农第二定律…

Qt 实战(5)布局管理器 | 5.2、深入解析Qt布局管理器

文章目录 一、深入解析Qt布局管理器1、为什么要使用布局管理器?2、布局管理器类型3、布局管理器用法详解3.1、QBoxLayout(垂直与水平布局)3.2、QGridLayout(网格布局)3.3、QFormLayout(表单布局&#xff09…

特斯拉、路特斯、中国一汽、毕博、博世等企业将出席中国汽车供应链降碳和可持续国际峰会

由ECV International 举办的2024中国汽车供应链脱碳与可持续国际峰会将于2024年9月23-24日在上海召开。 在本次峰会上,来自全球各地的行业领袖、政策制定者、研究人员和利益相关者将齐聚一堂,商讨对于减少碳排放和促进整个汽车供应链可持续实践至关重要…

教学资源共享平台的设计

管理员账户功能包括:系统首页,个人中心,管理员管理,老师管理,用户管理,成绩管理,教学资源管理,作业管理 老师账户功能包括:系统首页,个人中心,用…

什么是拷贝?我:Ctrl + C ...

前言 当谈及拷贝,你的第一印象会不会和我一样,ctrl c ctrl v ... ;虽然效果和拷贝是一样的,但是你知道拷贝的原理以及它的实现方法吗?今天就让我们一起探究一下拷贝中深藏的知识点吧。 拷贝 首先来看下面一段代码…

MySQL数据库回顾(1)

数据库相关概念 关系型数据库 概念: 建立在关系模型基础上,由多张相互连接的二维表组成的数据库。 特点: 1.使用表存储数据,格式统一,便于维护 2.使用SQL语言操作,标准统一,使用方便 SOL SQL通用语法 …