1-4.时间序列数据建模流程范例

文章最前: 我是Octopus,这个名字来源于我的中文名–章鱼;我热爱编程、热爱算法、热爱开源。所有源码在我的个人github
;这博客是记录我学习的点点滴滴,如果您对 Python、Java、AI、算法有兴趣,可以关注我的动态,一起学习,共同进步。

2020年发生的新冠肺炎疫情灾难给各国人民的生活造成了诸多方面的影响。

有的同学是收入上的,有的同学是感情上的,有的同学是心理上的,还有的同学是体重上的。

本文基于中国2020年3月之前的疫情数据,建立时间序列RNN模型,对中国的新冠肺炎疫情结束时间进行预测。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

import torch 
print("torch.__version__ = ", torch.__version__)
torch.__version__ =  2.0.1

公众号 算法美食屋 回复关键词:pytorch, 获取本项目源码和所用数据集百度云盘下载链接。

import os#mac系统上pytorch和matplotlib在jupyter中同时跑需要更改环境变量
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 

一,准备数据

本文的数据集取自tushare,获取该数据集的方法参考了以下文章。

《https://zhuanlan.zhihu.com/p/109556102》

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'svg'df = pd.read_csv("./eat_pytorch_datasets/covid-19.csv",sep = "\t")
df.plot(x = "date",y = ["confirmed_num","cured_num","dead_num"],figsize=(10,6))
plt.xticks(rotation=60);

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

dfdata = df.set_index("date")
dfdiff = dfdata.diff(periods=1).dropna()
dfdiff = dfdiff.reset_index("date")dfdiff.plot(x = "date",y = ["confirmed_num","cured_num","dead_num"],figsize=(10,6))
plt.xticks(rotation=60)
dfdiff = dfdiff.drop("date",axis = 1).astype("float32")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

dfdiff.head()
confirmed_numcured_numdead_num
0457.04.016.0
1688.011.015.0
2769.02.024.0
31771.09.026.0
41459.043.026.0

下面我们通过继承torch.utils.data.Dataset实现自定义时间序列数据集。

torch.utils.data.Dataset是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的两个方法即可:

  • __len__:实现len(dataset)返回整个数据集的大小。
  • __getitem__:用来获取一些索引的数据,使dataset[i]返回数据集中第i个样本。

不覆写这两个方法会直接返回错误。

import torch 
from torch import nn 
from torch.utils.data import Dataset,DataLoader,TensorDataset#用某日前8天窗口数据作为输入预测该日数据
WINDOW_SIZE = 8class Covid19Dataset(Dataset):def __len__(self):return len(dfdiff) - WINDOW_SIZEdef __getitem__(self,i):x = dfdiff.loc[i:i+WINDOW_SIZE-1,:]feature = torch.tensor(x.values)y = dfdiff.loc[i+WINDOW_SIZE,:]label = torch.tensor(y.values)return (feature,label)ds_train = Covid19Dataset()#数据较小,可以将全部训练数据放入到一个batch中,提升性能
dl_train = DataLoader(ds_train,batch_size = 38)for features,labels in dl_train:break #dl_train同时作为验证集
dl_val = dl_train

二,定义模型

使用Pytorch通常有三种方式构建模型:使用nn.Sequential按层顺序构建模型,继承nn.Module基类构建自定义模型,继承nn.Module基类构建模型并辅助应用模型容器进行封装。

此处选择第二种方式构建模型。

import torch
from torch import nn 
import importlib 
import torchkeras torch.random.seed()class Block(nn.Module):def __init__(self):super(Block,self).__init__()def forward(self,x,x_input):x_out = torch.max((1+x)*x_input[:,-1,:],torch.tensor(0.0))return x_outclass Net(nn.Module):def __init__(self):super(Net, self).__init__()# 3层lstmself.lstm = nn.LSTM(input_size = 3,hidden_size = 3,num_layers = 5,batch_first = True)self.linear = nn.Linear(3,3)self.block = Block()def forward(self,x_input):x = self.lstm(x_input)[0][:,-1,:]x = self.linear(x)y = self.block(x,x_input)return ynet = Net()
print(net)
Net((lstm): LSTM(3, 3, num_layers=5, batch_first=True)(linear): Linear(in_features=3, out_features=3, bias=True)(block): Block()
)
Net((lstm): LSTM(3, 3, num_layers=5, batch_first=True)(linear): Linear(in_features=3, out_features=3, bias=True)(block): Block()
)
from torchkeras import summary
summary(net,input_data=features);
--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
==========================================================================
LSTM-1                                    [-1, 8, 3]                  480
Linear-2                                     [-1, 3]                   12
Block-3                                      [-1, 3]                    0
==========================================================================
Total params: 492
Trainable params: 492
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 0.000069
Forward/backward pass size (MB): 0.000229
Params size (MB): 0.001877
Estimated Total Size (MB): 0.002174
--------------------------------------------------------------------------

三,训练模型

训练Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异。

有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环。

此处我们通过引入torchkeras库中的KerasModel工具来训练模型,无需编写自定义循环。

torchkeras详情: https://github.com/lyhue1991/torchkeras

注:循环神经网络调试较为困难,需要设置多个不同的学习率多次尝试,以取得较好的效果。

from torchmetrics.regression import MeanAbsolutePercentageErrordef mspe(y_pred,y_true):err_percent = (y_true - y_pred)**2/(torch.max(y_true**2,torch.tensor(1e-7)))return torch.mean(err_percent)net = Net() 
loss_fn = mspe
metric_dict = {"mape":MeanAbsolutePercentageError()}optimizer = torch.optim.Adam(net.parameters(), lr=0.03)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.0001)
from torchkeras import KerasModel 
model = KerasModel(net,loss_fn = loss_fn,metrics_dict= metric_dict,optimizer = optimizer,lr_scheduler = lr_scheduler) 
dfhistory = model.fit(train_data=dl_train,val_data=dl_val,epochs=100,ckpt_path='checkpoint',patience=10,monitor='val_loss',mode='min',callbacks=None,plot=True,cpu=True)
[0;31m<<<<<< 🐌 cpu is used >>>>>>[0m

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

18.00% [18/100] [00:02<00:10]
████████████████████100.00% [1/1] [val_loss=0.4363, val_mape=0.5570]
[0;31m<<<<<< val_loss without improvement in 10 epoch,early stopping >>>>>> 
[0m

四,评估模型

评估模型一般要设置验证集或者测试集,由于此例数据较少,我们仅仅可视化损失函数在训练集上的迭代情况。

model.evaluate(dl_val)
100%|█████████████████████████████████| 1/1 [00:00<00:00, 63.91it/s, val_loss=0.384, val_mape=0.505]{'val_loss': 0.38373321294784546, 'val_mape': 0.5048269033432007}

五,使用模型

此处我们使用模型预测疫情结束时间,即 新增确诊病例为0 的时间。

#使用dfresult记录现有数据以及此后预测的疫情数据
dfresult = dfdiff[["confirmed_num","cured_num","dead_num"]].copy()
dfresult.tail()
confirmed_numcured_numdead_num
41143.01681.030.0
4299.01678.028.0
4344.01661.027.0
4440.01535.022.0
4519.01297.017.0
#预测此后1000天的新增走势,将其结果添加到dfresult中
for i in range(1000):arr_input = torch.unsqueeze(torch.from_numpy(dfresult.values[-38:,:]),axis=0)arr_predict = model.forward(arr_input)dfpredict = pd.DataFrame(torch.floor(arr_predict).data.numpy(),columns = dfresult.columns)dfresult = pd.concat([dfresult,dfpredict],ignore_index=True)
dfresult.query("confirmed_num==0").head()# 第50天开始新增确诊降为0,第45天对应3月10日,也就是5天后,即预计3月15日新增确诊降为0
# 注:该预测偏乐观
confirmed_numcured_numdead_num
500.0999.00.0
510.0948.00.0
520.0900.00.0
530.0854.00.0
540.0810.00.0

dfresult.query("cured_num==0").head()
# 第137天开始新增治愈降为0,第45天对应3月10日,也就是大概3个月后,即6月12日左右全部治愈。
# 注: 该预测偏悲观,并且存在问题,如果将每天新增治愈人数加起来,将超过累计确诊人数。
confirmed_numcured_numdead_num
1370.00.00.0
1380.00.00.0
1390.00.00.0
1400.00.00.0
1410.00.00.0

六,保存模型

模型权重保存在了model.ckpt_path路径。

print(model.ckpt_path)
checkpoint
model.load_ckpt('checkpoint') #可以加载权重

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/38857.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

信息学奥赛初赛天天练-41-CSP-J2021基础题-n个数取最大、树的边数、递归、递推、深度优先搜索应用

PDF文档公众号回复关键字:20240701 2021 CSP-J 选择题 单项选择题&#xff08;共15题&#xff0c;每题2分&#xff0c;共计30分&#xff1a;每题有且仅有一个正确选项&#xff09; 4.以比较作为基本运算&#xff0c;在N个数中找出最大数&#xff0c;最坏情况下所需要的最少比…

我在中东做MCN,月赚10万美金

图片&#xff5c;Photo by Ben Koorengevel on Unsplash ©自象限原创 作者丨程心 在迪拜购物中心和世界最高建筑哈利法塔旁的主街上&#xff0c;徐晋已经“蹲”了三个小时&#xff0c;每当遇到穿着时髦的年轻男女&#xff0c;他都会上前询问&#xff0c;有没有意愿成为…

C语言部分复习笔记

1. 指针和数组 数组指针 和 指针数组 int* p1[10]; // 指针数组int (*p2)[10]; // 数组指针 因为 [] 的优先级比 * 高&#xff0c;p先和 [] 结合说明p是一个数组&#xff0c;p先和*结合说明p是一个指针 括号保证p先和*结合&#xff0c;说明p是一个指针变量&#xff0c;然后指…

Web2Code :网页理解和代码生成能力的评估框架

多模态大型语言模型&#xff08;MLLMs&#xff09;在过去几年中取得了爆炸性的增长。利用大型语言模型&#xff08;LLMs&#xff09;中丰富的常识知识&#xff0c;MLLMs在处理和推理各种模态&#xff08;如图像、视频和音频&#xff09;方面表现出色&#xff0c;涵盖了识别、推…

VuePress介绍

从本文开始&#xff0c;动手搭建自己的博客&#xff01;希望读者能跟着一起动手&#xff0c;这样才能真正掌握。 ‍ VuePress 是什么 VuePress 是由 Vue 作者带领团队开发的&#xff0c;非常火&#xff0c;使用的人很多&#xff1b;Vue 框架官网也是用了 VuePress 搭建的。即…

4PCS点云配准算法实现

4PCS点云配准算法的C实现如下&#xff1a; #include <iostream> #include <pcl/io/pcd_io.h> #include <pcl/point_types.h> #include <pcl/common/common.h> #include <pcl/common/distances.h> #include <pcl/common/transforms.h> #in…

php 通过vendor文件 生成还原最新的composer.json

起因&#xff1a;因为历史原因&#xff0c;在本项目中composer.json基本算废了&#xff0c;没法直接使用composer管理扩展&#xff0c;今天尝试修复一下composer.json。 历史文件&#xff0c;可以看出来已经很久没有维护了&#xff0c;我们主要是恢复require的信息 {"na…

基于CNN的股票预测方法【卷积神经网络】

基于机器学习方法的股票预测系列文章目录 一、基于强化学习DQN的股票预测【股票交易】 二、基于CNN的股票预测方法【卷积神经网络】 文章目录 基于机器学习方法的股票预测系列文章目录一、CNN建模原理二、模型搭建三、模型参数的选择&#xff08;1&#xff09;探究window_size…

下代iPhone或回归可拆卸电池,苹果这操作把我看傻了

刚度过一个愉快的周末&#xff0c;苹果又双叒叕摊上事儿了。 iPhone13 系列被曝扎堆电池鼓包了。 早在去年&#xff0c;就有 iPhone13 和 iPhone14 用户反馈过类似的问题&#xff0c;表示在手机仅仅使用了一年多的时间就出现了电池鼓包的情况&#xff0c;而且还把屏幕给撑起来了…

舞会无领导:一种树形动态规划的视角

没有上司的舞会 Ural 大学有 &#x1d441; 名职员&#xff0c;编号为1∼&#x1d441;。 他们的关系就像一棵以校长为根的树&#xff0c;父节点就是子节点的直接上司。 每个职员有一个快乐指数&#xff0c;用整数 &#x1d43b;&#x1d456; 给出&#xff0c;其中1≤&…

校园卡手机卡怎么注销?

校园手机卡的注销流程可以根据不同的运营商和具体情况有所不同&#xff0c;但一般来说&#xff0c;以下是注销校园手机卡的几种常见方式&#xff0c;我将以分点的方式详细解释&#xff1a; 一、线上注销&#xff08;通过手机APP或官方网站&#xff09; 下载并打开对应运营商的…

当年很多跑到美加澳写代码的人现在又移回香港?什么原因?

当年很多跑到美加澳写代码的人现在又移回香港&#xff1f;什么原因&#xff1f; 近年来&#xff0c;确实有部分曾经移民到美国、加拿大、澳大利亚等地的香港居民选择移回香港。这一现象与多种因素相关&#xff0c;主要可以归结为以下几点&#xff1a; 疫情后的环境变化&#…

【STM32】温湿度采集与OLED显示

一、任务要求 1. 学习I2C总线通信协议&#xff0c;使用STM32F103完成基于I2C协议的AHT20温湿度传感器的数据采集&#xff0c;并将采集的温度-湿度值通过串口输出。 任务要求&#xff1a; 1&#xff09;解释什么是“软件I2C”和“硬件I2C”&#xff1f;&#xff08;阅读野火配…

2025第13届常州国际工业装备博览会招商全面启动

常州智造 装备中国|2025第13届常州国际工业装备博览会招商全面启动 2025第13届常州国际工业装备博览会将于2025年4月11-13日在常州西太湖国际博览中心盛大举行&#xff01;目前&#xff0c;各项筹备工作正稳步推进。 60000平米的超大规模、800多家国内外工业装备制造名企将云集…

最细最有条理解析:事件循环(消息循环)是什么?进程与线程的定义、关系与差异

目录 事件循环&#xff1a;引入 一、浏览器的进程模型 1.1、什么是进程&#xff08;Process&#xff09; 1.2、什么是线程&#xff08;Thread&#xff09; 1.3、进程与线程之间的关系联系与区别 二、浏览器有哪些进程和线程 2.1、浏览器的主要进程 ①浏览器进程 ②网络…

ctfshow sqli-libs web561--web568

web561 ?id-1 or 1--?id-1 union select 1,2,3--?id-1 union select 1,(select group_concat(column_name) from information_schema.columns where table_nameflags),3-- Your Username is : id,flag4s?id-1 union select 1,(select group_concat(flag4s) from ctfshow.f…

扩展学习|风险评估和风险管理:回顾其基础上的最新进展

文献来源&#xff1a;[1]Aven, T. (2016). Risk assessment and risk management: Review of recent advances on their foundation. European journal of operational research, 253(1), 1-13. 文章简介&#xff1a;大约30-40年前&#xff0c;风险评估和管理被确立为一个科学领…

数据结构 - C/C++ - 链表

目录 结构特性 内存布局 结构样式 结构拓展 单链表 结构定义 节点关联 插入节点 删除节点 常见操作 双链表 环链表 结构容器 结构设计 结构特性 线性结构的存储方式 顺序存储 - 数组 链式存储 - 链表 线性结构的链式存储是通过任意的存储单元来存储线性…

技术分享:分布式数据库DNS服务器的架构思路

DNS是企业数字化转型的基石。伴随微服务或单元化部署的推广&#xff0c;许多用户也开始采用分布式数据库将原来的单体数据库集群服务架构拆分为大量分布式子服务集群&#xff0c;对应不同的微服务或服务单元。本文将从分布式数据库DNS服务器的架构需求、架构分析两方面入手&…

湖北大学2024年成人高考函授报名专升本市场营销专业介绍

在璀璨的学术殿堂中&#xff0c;湖北大学如同一颗璀璨的明珠&#xff0c;熠熠生辉。为了满足广大社会人士对于继续深造、提升自我、实现职业梦想的渴望&#xff0c;湖北大学特别开设了成人高等继续教育项目&#xff0c;为广大有志之士敞开了一扇通往知识殿堂的大门。 而今&…