Pytorch--3.使用CNN和LSTM对数据进行预测

这个系列前面的文章我们学会了使用全连接层来做简单的回归任务,但是在现实情况里,我们不仅需要做回归,可能还需要做预测工作。同时,我们的数据可能在时空上有着联系,但是简单的全连接层并不能满足我们的需求,所以我们在这篇文章里使用CNN和LSTM来对时间上有联系的数据来进行学习,同时来实现预测的功能。

1.数据集:使用的是kaggle上一个公开的气象数据集(CSV)

有需要的可以去kaggle下载,也可以在评论区留下mail,题主发送过去
在这里插入图片描述

2.导入我们所需要的库和完成前置工作

2.1导入相关的库

torch为人工智能的库,pandas用于数据读取,numpy为张量处理的库,matplotlib为画图库

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
import torch.nn as nn
import torch.optim as optim
import random

2.2设置相关配置

我们设置随机种子(方便代码的复现)和警告的忽律(防止出现太多警告看不到代码运行的效果)

warnings.filterwarnings('ignore')
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(99)
np.random.seed(99)
random.seed(99)
print ("随机种子")

2.3数据的读入

pd.read_csv里面的参数为相对位置,即代码和文件要在同一个文件夹下面。使用.head()函数来读一下数据的前几行,保证数据是存在的

train_data = pd.read_csv("LSTM-Multivariate_pollution.csv")
train_data.head()

请添加图片描述
我们来看一下各个值的前2048个数据分布情况(方便挑选数据进行代码测试)
代码里面的pollution可以换成dew,temp等值(也就是上图里面的值),用于观看分布情况。

train_use = train_data["pollution"].values
plt.plot([i for i in range(2048)], pollution[:2048])

pollution:
请添加图片描述
dew:
请添加图片描述
temp:
请添加图片描述
我们可以看到temp属性里面的数据整体呈现上升的趋势,所以我们使用属性为temp的值来进行学习和预测。
首先对数据进行归一化操作(因为值过大的话会导致神经网络损失不降低,同时神经网络难以达到收敛),我们使用minmax归一化后将其打印出来可以看到代码显示的效果

from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
train_use = scaler.fit_transform(train_use.reshape(-1, 1))
print ((train_use))                                                                     
print ("归一化处理")

可以看到归一化后的结果如下图所示:
在这里插入图片描述
我们将数据进行处理,默认使用30天的数据对第31天的数据进行预测,同时将数据进行升维处理,使得输入的训练数据为3维度,分别为batchsize,每次所需要的数据(30个数据),和数据的输入维度(1维度)

def split_data(data, time_step = 30):dataX = []dataY = []for i in range(len(data) - time_step):dataX.append(data[i:i + time_step])dataY.append(data[i + time_step])dataX = np.array(dataX).reshape(len(dataX), time_step, -1)dataY = np.array(dataY)return dataX, dataY

进行数据处理后,获得了可以训练的数据和标签

datax,datay = split_data(train_use, 30)
print ((datay))

结果如下:
请添加图片描述

紧接着我们划分训练集和测试集,默认为80%的数据用于做训练集,20%的数据用于做测试集,shuffle表示是否要将数据进行打乱,以此来测试训练效果

def train_test_split(dataX,datay,shuffle = True,percentage = 0.8):if shuffle:random_num = [i for i in range(len(dataX))]np.random.shuffle(random_num)dataX = dataX[random_num]datay = datay[random_num]split_num = int(len(dataX)*percentage)train_X = dataX[:split_num]train_y = datay[:split_num]testX = dataX[split_num:]testy = datay[split_num:]return train_X, train_y, testX, testy

获取我们的训练数据和测试数据,同时把源数据保存到X_train和y_train里面,方便以后对网络的性能进行评比。

train_X, train_y, testx,testy = train_test_split(datax,datay,False,0.8)
print (type(testx))
print("datax的形状为{},dataY的形状为{}".format(train_X.shape, train_y.shape))
X_train = train_X
y_train = train_y

定义我们的自定义网络

class CNN_LSTM(nn.Module):def __init__(self, conv_input, input_size, hidden_size, num_layers, output_size):super(CNN_LSTM, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.conv = nn.Conv1d(conv_input, conv_input, 1)self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first = True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):x = self.conv(x)h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)out, _= self.lstm(x,(h0,c0))out = self.fc(out[:,-1,:])return out

设置我们网络训练所需要的参数

test_X1 = torch.Tensor(testx)
test_y1 = torch.Tensor(testy)input_size = 1
conv_input = 30
hidden_size = 64
num_layers = 2output_size = 1model = CNN_LSTM(conv_input, input_size, hidden_size, num_layers,output_size)num_epoch = 1000
batch_size = 4optimizer = optim.Adam(model.parameters(), lr = 0.0001, betas=(0.5, 0.999))criterion = nn.MSELoss()
#print ((torch.Tensor(train_X[:batch_size])))

开始运行代码:

train_losses = []
test_losses = []
for epoch in range(num_epoch):random_num = [i for i in range(len(train_X))]np.random.shuffle(random_num)train_X = train_X[random_num]train_y = train_y[random_num]train_x1 = torch.Tensor(train_X[:batch_size])train_y1 = torch.Tensor(train_y[:batch_size])model.train()optimizer.zero_grad()output = model(train_x1)train_loss = criterion(output, train_y1)train_loss.backward()optimizer.step()if epoch%50 == 0 :model.eval()with torch.no_grad():output = model(test_X1)test_loss = criterion(output, test_y1)train_losses.append(train_loss)test_losses.append(test_loss)print("epoch{},train_loss:{},test_loss:{}".format(epoch, train_loss, test_loss))

在这里插入图片描述

自己手写一个mse计算函数(直接调库也可以),什么是mse?(均方误差,均方误差越小说明模型拟合的越好)

def mse(pred_y, true_y):return np.mean((pred_y - true_y) **2)

然后我们对模型进行测试,观察mse的值

train_X1 = torch.Tensor(X_train)
train_pred = model(train_X1).detach().numpy()
test_pred = model(test_X1).detach().numpy()pred_y = np.concatenate((train_pred, test_pred))
pred_y = scaler.inverse_transform(pred_y).T[0]true_y = np.concatenate((y_train, testy))
#print (true_y)
true_y = scaler.inverse_transform(true_y).T[0]
#print (true_y)
print (f"mse(pred_y, true_y):{mse(pred_y, true_y)}")
##print (pred_y)

在这里插入图片描述

我们取前2048个值来看我们的预测的情况(因为数据有几万条,为了避免图形太过密集难以看出效果,所以我们只采用前2048个值来进行展示)

plt.title("CNN_LSTM")
x = [i for i in range(2048)]
plt.plot(x, pred_y[:2048], marker = "o", markersize =1, label="pred_y",color=(1, 0, 0))
plt.plot(x, true_y[:2048], marker = "x", markersize=1, label="true_y",color=(0, 0, 1))
plt.legend()
plt.show()

可以看出来,已经学习到了基本的上升趋势的
在这里插入图片描述
我们将两个图拆开来看,看到前8192个点的值,可以看到已经获得到了相对应的趋势。
请添加图片描述
在这里插入图片描述

码字不易,写代码不易,点个赞再走把

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/117528.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Docker】Docker学习之一:离线安装Docker步骤

前言:基于Ubuntu Jammy 22.04 (LTS)版本安装和测试 1、Docker安装 1.1、离线安装 步骤一:官网下载 docker 安装包 wget https://download.docker.com/linux/static/stable/x86_64/docker-24.0.6.tgz步骤二:解压安装包; tar -zxvf docker…

谈谈Net-SNMP软件

Net-SNMP是一个开源的SNMP软件套件,它提供了SNMP代理(snmpd)和SNMP工具(如snmpget、snmpwalk等),可以用于监控和管理网络设备。 Net-SNMP最初是从UC Davis的SNMP软件衍生而来,现在已经成为广泛…

小程序设计基本微信小程序的校园生活助手系统

项目介绍 通篇文章的撰写基础是实际的应用需要,然后在架构系统之前全面复习大学所修习的相关知识以及网络提供的技术应用教程,以校园生活助手系统的实际应用需要出发,架构系统来改善现校园生活助手系统工作流程繁琐等问题。不仅如此以操作者…

纺织工厂数字孪生3D可视化管理平台,推动纺织产业数字化转型

近年来,我国加快数字化发展战略部署,全面推进制造业数字化转型,促进数字经济与实体经济深度融合。以数字孪生、物联网、云计算、人工智能为代表的数字技术发挥重要作用。聚焦数字孪生智能工厂可视化平台,推动纺织制造业数字化转型…

【Java集合类面试十八】、ConcurrentHashMap是怎么分段分组的?

文章底部有个人公众号:热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享? 踩过的坑没必要让别人在再踩,自己复盘也能加深记忆。利己利人、所谓双赢。 面试官:ConcurrentHashMap是怎么…

手把手教你在项目中引入Excel报表组件

摘要:本文由葡萄城技术团队原创并首发。转载请注明出处:葡萄城官网,葡萄城为开发者提供专业的开发工具、解决方案和服务,赋能开发者。 前言 GrapeCity Documents for Excel(以下简称GcExcel)是葡萄城公司的…

【TES641】基于VU13P FPGA的4路FMC接口基带信号处理平台

板卡概述 TES641是一款基于Virtex UltraScale系列FPGA的高性能4路FMC接口基带信号处理平台,该平台采用1片Xilinx的Virtex UltraScale系列FPGA XCVU13P作为信号实时处理单元,该板卡具有4个FMC子卡接口(其中有2个为FMC接口)&#x…

Vue3.3指北(二)

Vue3.3指北 Vue31、组件基础1.1、全局组件1.2、局部组件1.3、组件的命名1.4、组件的数据存放1.5、组件标签化 2、父组件向子组件传递数据2.1、props2.2、动态props2.3、props传数组2.4、props传对象2.4.1、默认值和必传值 3、子组件向父组件传递数据4、父子组件互相访问4.1、父…

03初始Docker

一、初始Docker 1.什么是Docker 问题 ①大型项目组件复杂,运行环境复杂,部署时依赖复杂,出现兼容性问题。 ②开发,测试,生产环境有差异。不同的环境操作系统不同 解决 ①Docket将应用、依赖、函数库、配置一起打…

聚观早报 | vivo Y100官宣;极氪001 FR将上市

【聚观365】10月25日消息 vivo Y100官宣 一极氪001 FR将上市 特斯拉加速扩张 苹果扩大招聘力度 小米澎湃OS实现历史性跨越 vivo Y100官宣 vivo Y系列是vivo存在比较久的入门系列,主打千元价位的线下市场,在消费者中有着不错的口碑。而不久前一款型…

代码随想录算法训练营第二十九天 | 回溯算法总结

​ 代码随想录算法训练营第二十九天 | 回溯算法总结 1. 组合问题 1.1 组合问题 在77. 组合中,我们开始用回溯法解决第一道题目:组合问题。 回溯算法跟k层for循环同样是暴力解法,为什么用回溯呢?回溯法的魅力,用递…

ubuntu tools

1 cloc calculate lines of your code sudo apt-get install cloccloc ./file

【MySQL架构篇】SQL执行流程与缓冲池

文章目录 1. SQL执行流程2. 数据库缓冲池(Buffer Pool)2.1 缓冲池概述2.2 缓冲池如何读取数据2.3 查看和设置缓冲池的大小2.4 多个Buffer Pool实例2.5 引申问题 1. SQL执行流程 查询缓存:因为查询效率往往不高,所以在MySQL8.0之后就抛弃了这个功能解析器…

lvs+keepalived: 高可用集群

lvskeepalived: 高可用集群 keepalived为lvs应运而生的高可用服务。lvs的调度器无法做高可用,于是keepalived软件。实现的是调度器的高可用。 但是:keepalived不是专门为集群服务的,也可以做其他服务器的高可用。 lvs的高可用集群&#xf…

力扣第1005题 K 次取反后最大化的数组和 c++ 贪心 双思维

题目 1005. K 次取反后最大化的数组和 简单 相关标签 贪心 数组 排序 给你一个整数数组 nums 和一个整数 k ,按以下方法修改该数组: 选择某个下标 i 并将 nums[i] 替换为 -nums[i] 。 重复这个过程恰好 k 次。可以多次选择同一个下标 i 。 以…

asp.net乡村旅游管理系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net乡村旅游管理系统是一套完善的web设计管理系统系统,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为vs2010,数据库为sqlserver2008,使用c# 语言开发 asp.net乡村旅游管理系统 二、…

【在英伟达nvidia的jetson-orin-nx和PC电脑ubuntu20.04上-装配ESP32开发调试环境-基础测试】

【在英伟达nvidia的jetson-orin-nx和PC电脑ubuntu20.04上-装配ESP32开发调试环境-基础测试】 1、概述2、实验环境3、 物品说明4、参考资料与自我总结5、实验过程1、创建目录2、克隆下载文件3、 拉取子目录安装和交叉编译工具链等其他工具4、添加环境变量6、将样例文件拷贝到桌面…

【分布式技术专题】「分布式技术架构」MySQL数据同步到Elasticsearch之N种方案解析,实现高效数据同步

MySQL数据同步到Elasticsearch之N种方案解析,实现高效数据同步 前提介绍MySQL和ElasticSearch的同步双写优点缺点针对于缺点补充优化方案 MySQL和ElasticSearch的异步双写优点缺点 定时延时写入ElasticSearch数据库机制优点缺点 开源和成熟的数据迁移工具选型Logsta…

软考系列(系统架构师)- 2016年系统架构师软考案例分析考点

试题一 软件架构(质量属性、架构风格对比、根据描述填空) 试题二 系统开发(用例图参与者、用例关系、类图关系) 学生、教师、管理员、时间、打印机【问题2】(7分) 用例是对系统行为的动态描述,用…

ant框架下 a-input-number组件的宽度问题

如图所示,在使用a-input-number组件时虽然设置了宽度但是没有生效,加上了一个!important就好了: