【时间序列预测】基于PyTorch实现CNN_BiLSTM算法

文章目录

  • 1. CNN与BiLSTM
  • 2. 完整代码实现
  • 3. 代码结构解读
    • 3.1 CNN Layer
    • 3.2 BiLSTM Layer
    • 3.3 Output Layer
    • 3.4 forward Layer
  • 4. 应用场景
  • 5. 总结

  本文将详细介绍如何使用Pytorch实现一个结合卷积神经网络(CNN)双向长短期记忆网络(BiLSTM)的混合模型—CNN_BiLSTM。这种网络架构结合了CNN在提取局部特征方面的优势和BiLSTM在建模序列数据时的长期依赖关系的能力,特别适用于时序数据的预测任务,如时间序列分析、风速预测、股票预测等。

1. CNN与BiLSTM

  • CNN主要通过卷积操作对输入数据进行特征提取,适合于处理局部结构化的特征(如图像数据、时间序列数据中的局部模式)。
  • BiLSTM则是基于LSTM的变种,它通过双向遍历序列,可以同时捕捉过去和未来的信息,使其在处理时间序列数据时非常有效。
  • 在本例中,CNN负责提取时间序列数据的局部特征,而BiLSTM则进一步捕捉数据中的时序依赖关系,最终通过全连接层输出预测结果。

2. 完整代码实现

"""
CNN_BiLSTM Network
"""
from torch import nnclass CNN_BiLSTM(nn.Module):r"""CNN_BiLSTMArgs:cnn_in_channels : CNN输入通道数, if in.shape=[64,7,18] value=7bilstm_input_size : bilstm输入大小, if in.shape=[64,7,18] value=18output_size :  期望网络输出大小cnn_out_channels:  CNN层输出通道数cnn_kernal_size :  CNN层卷积核大小maxpool_kernal_size:  MaxPool Layer kernal_sizebilstm_hidden_size: BiLSTM Layer hidden_dimbilstm_num_layers: BiLSTM Layer num_layersdropout:  dropout防止过拟合, 取值(0,1)bilstm_proj_size: BiLSTM Layer proj_sizeExample:>>> import torch>>> input = torch.randn([64,7,18])>>> model = CNN_BiLSTM(7, 18,18)>>> out = model(input)"""def __init__(self,cnn_in_channels,bilstm_input_size,output_size,cnn_out_channels=32,cnn_kernal_size=3,maxpool_kernal_size=3,bilstm_hidden_size=128,bilstm_num_layers=4,dropout = 0.05,bilstm_proj_size = 0):super().__init__()# CNN Layerself.conv1d = nn.Conv1d(in_channels=cnn_in_channels, out_channels=cnn_out_channels, kernel_size=cnn_kernal_size, padding="same")self.relu = nn.ReLU()self.maxpool = nn.MaxPool1d(kernel_size= maxpool_kernal_size)# BiLSTM Layerself.bilstm = nn.LSTM(input_size = int(bilstm_input_size/maxpool_kernal_size),hidden_size = bilstm_hidden_size,num_layers = bilstm_num_layers,batch_first = True,dropout = dropout,bidirectional = True,proj_size = bilstm_proj_size)# output Layerself.fc = nn.Linear(bilstm_hidden_size*2,output_size)def forward(self, x):x = self.conv1d(x)x = self.relu(x)x = self.maxpool(x)bilstm_out,_ = self.bilstm(x)x = self.fc(bilstm_out[:, -1, :])return x      

3. 代码结构解读

3.1 CNN Layer

卷积层(Conv1d)用于提取局部特征,通常用于处理时间序列数据中的局部模式。它的输入是具有多个特征(例如风速、气压、湿度等)的时序数据。

相关代码:

# CNN Layer
self.conv1d = nn.Conv1d(in_channels=cnn_in_channels, out_channels=cnn_out_channels, kernel_size=cnn_kernal_size, padding="same")
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool1d(kernel_size= maxpool_kernal_size)
  • cnn_in_channels : 表示输入通道数
  • cnn_out_channels : 表示卷积层的输出通道数
  • cnn_kernal_size : 为卷积核大小
  • padding : "same"表示特征输入大小和输出大小一致
  • maxpool_kernal_size : 为池化操作的核大小

3.2 BiLSTM Layer

双向长短期记忆网络(BiLSTM)用于捕捉时序数据中的长程依赖关系。

相关代码:

# BiLSTM Layer
self.bilstm = nn.LSTM(input_size = int(bilstm_input_size/maxpool_kernal_size),hidden_size = bilstm_hidden_size,num_layers = bilstm_num_layers,batch_first = True,dropout = dropout,bidirectional = True,proj_size = bilstm_proj_size
)
  • bilstm_input_size : 表示输入的特征维度
  • bilstm_hidden_size : 表示LSTM隐藏状态的维度
  • bilstm_num_layers : 是LSTM的层数
  • dropout : 用于防止过拟合
  • bilstm_proj_size : 是LSTM的投影层大小(如果需要)

3.3 Output Layer

全连接层(fc)将BiLSTM的输出映射到最终的预测结果。输出的维度为output_size,通常是我们需要预测的目标维度(例如未来的功率值)。

相关代码:

# output Layer
self.fc = nn.Linear(bilstm_hidden_size*2, output_size)

output_size : 输出维度大小

3.4 forward Layer

相关代码:

def forward(self, x):x = self.conv1d(x)x = self.relu(x)x = self.maxpool(x)bilstm_out,_ = self.bilstm(x)x = self.fc(bilstm_out[:, -1, :])return x      
  • 输入: 是一个三维张量,形状为 [batch_size, input_channels, seq_len],其中input_channels是输入数据的特征数(例如风速、湿度等),seq_len是时间步数(即输入序列的长度)。
  • CNN部分:首先通过卷积层提取局部特征,然后应用ReLU激活函数引入非线性,最后通过最大池化(MaxPool1d)对特征进行降维,减少计算量。
  • BiLSTM部分: 接着,将经过CNN处理后的特征传递给BiLSTM,捕捉时间序列中的长期依赖关系。BiLSTM的双向性使得模型能够同时考虑过去和未来的上下文信息。
  • 输出: 最终,模型通过全连接层(fc)将BiLSTM的最后一个时间步的输出映射为期望的输出大小。

4. 应用场景

  这个模型适合用于处理时间序列数据的预测任务,特别是在风力发电预测、气象预测、股市预测等领域。CNN用于从输入数据中提取局部特征,而BiLSTM则能够捕捉输入数据的长期时序依赖关系。因此,模型既能有效地处理局部特征,又能关注到长时间范围内的依赖关系,从而提高预测的准确性。

5. 总结

  本文详细介绍了如何使用Pytorch实现一个基于CNN和BiLSTM的混合模型(CNN_BiLSTM)。该模型结合了CNN在局部特征提取上的优势和BiLSTM在序列建模上的长程依赖能力,适用于时序数据的预测任务。在实际应用中,可以根据任务的不同调整CNN和LSTM的层数、通道数和隐藏状态维度等超参数,以提高模型的预测精度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/889051.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

筑起厂区安全--叉车安全防护装置全解析

在繁忙的工业生产领域中,叉车作为搬运工,穿梭于仓储与生产线之间。然而,叉车的高效运作背后,也隐藏着诸多安全风险,尤其是在那些空间狭小、物流繁忙的环境中。为了降低这些潜在的危险,叉车安全防护装置便成…

MS SQL SERVER服务自动停止解决

报错原因 电脑异常重启,导致 MS SQL server 服务启动后自动停止,在【计算机管理】-【事件查看器】-【windows日志】中进行查看系统错误日志,在【应用程序】下发现可能的错误信息: 传递给数据库 ‘model’ 中的日志扫描操作的日志…

网络安全信息收集(总结)更新

目录 重点: 前言: 又学到了,就是我们什么时候要子域名收集,什么时候收集域名,重点应该放前面 思考: 信息收集分为哪几类,什么是主域名,为什么要收集主域名,为什么要收…

使用 Postman 上传二进制类型的图片到后端接口写法

我们有的时候会有需求,就是通过 postman 传递二进制图片到后端接口,如下图: 那我们的 Java 接口需要怎么写呢? Spring Boot 接收这些数据的方式需要使用 RequestBody 注解来处理原始的二进制数据(byte[])。…

Java Web 7 请求响应(Postman)

前言(SpringBoot程序请求响应流程) 以上一章的程序为例,一个基于SpringBoot的方式开发一个web应用,浏览器发起请求 /hello 后 ,给浏览器返回字符串 “Hello World ~”。 而我们在开发web程序时呢,定义了一…

pytorch多GPU训练教程

pytorch多GPU训练教程 文章目录 pytorch多GPU训练教程1. Torch 的两种并行化模型封装1.1 DataParallel1.2 DistributedDataParallel 2. 多GPU训练的三种架构组织方式2.2 数据不拆分,模型拆分(Model Parallelism)2.3 数据拆分,模型…

【opencv入门教程】9.视频加载

文章选自: 一、VideoCapture类 用于从视频文件、图像序列或摄像头捕获视频的类。函数:CV_WRAP VideoCapture();brief 默认构造函数CV_WRAP explicit VideoCapture(const String& filename, int apiPreference CAP_ANY);brief 使用 API 首选项打开…

海外的bug-hunters,不一样的403bypass

一种绕过403的新技术,跟大家分享一下。研究HTTP协议已经有一段时间了。发现HTTP协议的1.0版本可以绕过403。于是开始对lyncdiscover.microsoft.com域做FUZZ并且发现了几个403Forbidden的文件。 (访问fsip.svc为403) 在经过尝试后&#xff0…

阿里内部正式开源“Spring Cloud Alibaba (全彩小册)”

年轻的毕业生们满怀希望与忐忑,去寻找、竞争一个工作机会。已经在职的开发同学,也想通过社会招聘或者内推的时机争取到更好的待遇、更大的平台。 然而,面试人群众多,技术市场却相对冷淡,面试的同学们不得不面临着 1 个…

乘上 SpringBoot 东风,广场舞团掀起律动热潮

2 系统开发环境 2.1 Java技术 Java是由Sun公司推出的一门跨平台的面向对象的程序设计语言。因为Java 技术具有卓越的通用性、高效性、健壮的安全性和平台移植性的特点,而且Java是开源的,拥有全世界最大的开发者专业社群,所以Java的发展迅速。…

【PlantUML系列】状态图(六)

一、状态图的组成部分 状态:对象在其生命周期内可能处于的条件或情形,使用 state "State Name" as Statename 表示。初始状态:表示对象生命周期的开始,使用 [*] 表示。最终状态:表示对象生命周期的结束&…

创建 React Native 项目

创建 React Native 项目 npx react-nativelatest init YourProject切换依赖源 切换好源之后,你需要进入 android 目录,然后运行 gradlew build 命令。 Android 依赖安装是使用 gradlew 进行管理的。 $ cd android $ ./gradlew build --refresh-depend…

Linx下自动化之路:Redis安装包一键安装脚本实现无网极速部署并注册成服务

目录 简介 安装包下载 安装脚本 服务常用命令 简介 通过一键安装脚本实现 Redis 安装包的无网极速部署,并将其成功注册为系统服务,开机自启。 安装包下载 redis-7.0.8.tar.gzhttp://download.redis.io/releases/redis-7.0.8.tar.gz 安装脚本 修…

Meta Llama 3.3 70B:性能卓越且成本效益的新选择

Meta Llama 3.3 70B:性能卓越且成本效益的新选择 引言 在人工智能领域,大型语言模型一直是研究和应用的热点。Meta公司最近发布了其最新的Llama系列模型——Llama 3.3 70B,这是一个具有70亿参数的生成式AI模型,它在性能上与4050…

idea_maven详解

秒懂Maven maven简介maven安装和配置maven本地配置maven工程的GAVP创建maven工程项目结构说明项目构建说明 Maven依赖管理核心信息配置依赖管理配置依赖信息查询依赖范围设置依赖属性配置依赖下载失败错误解决Build构建配置依赖传递依赖冲突 maven工程继承继承作用应用场景继承…

Linux Ubuntu 安装配置RabbitMQ,springboot使用RabbitMQ

rabbit-Ubuntu 一篇文章学会RabbitMQ 在Ubuntu上查看RabbitMQ状态可以通过多种方式进行,包括使用命令行工具和Web管理界面。以下是一些常用的方法: 1-使用systemctl命令: sudo systemctl start rabbitmq-server sudo systemctl status ra…

LeetCode—189. 轮转数组(中等)

题目描述: 给定一个整数数组 nums,将数组中的元素向右轮转 k 个位置,其中 k 是非负数。 示例1: 输入: nums [1,2,3,4,5,6,7], k 3输出:[5,6,7,1,2,3,4] 解释: 向右轮转 1 步:[7,1,2,3,4,5,6] 向右轮转 2 步:[6,7,1,2,3,4,5] 向…

微信小程序报错:http://159.75.169.224:7300不在以下 request 合法域名列表中,请参考文档

要解决此问题,需打开微信小程序开发者工具进行设置,打开详情-本地设置重新运行,该报错就没有啦

vrrp主备备份

VRRP(Virtual Router Redundancy Protocol,虚拟路由冗余协议)是一种用于实现路由器冗余以提高网络可靠性的协议。以下是对VRRP的详细介绍: 基本概念 VRRP路由器:运行VRRP协议的路由器称为VRRP路由器。虚拟路由器&#…

Selenium:强大的 Web 自动化测试工具

Selenium:强大的 Web 自动化测试工具 在当今的软件开发和测试领域,自动化工具的重要性日益凸显。Selenium 就是一款备受欢迎的 Web 自动化测试工具,它为开发者和测试人员提供了强大的功能和便利。本文将详细介绍 Selenium 是什么&#xff0c…