【项目实践】基于LSTM的一维数据扩展与预测

基于LSTM的一维数据拟合扩展

一、引(fei)言(hua)

我在做Sri Lanka生态系统服务价值计算时,中间遇到了一点小问题。从世界粮农组织(FAO)上获得Sri Lanka主要农作物产量和价格数据时,其中的主要作物Sorghum仅有2001-2006年的数据,而Millet只有2001-2005,2020-2021这样的间断数据。虽然说可以直接剔除这种过分缺失的数据,但这无疑会对生态因子的计算造成重大影响。所以我想要不要整个函数把他拟合一下,刚好Maize和Rice有2001-2021的完备数据,于是,这个文档就这样诞生了。


二、数据

数据来自FAO,考虑到可能有同学想要跟着尝试一下,这里给出用到的数据。

作物产量

作物价格

2.1 数据探查

我们读取数据,并进行简单的统计量查看。如果要进一步深入研究数据分布及可视化,可以看看我的这篇文章

import pandas as pdpath=r"YourPath"yield_=pd.read_csv(path+r"\yield.csv")
pp_=pd.read_csv(path+r"\Producer Prices.csv")
yield_.head()

在这里插入图片描述

需要用到的属性只有Item,Year,Unit,Value

所以我们做这样的处理:

yield_=yield_[["Item","Year","Unit","Value"]]

可以看到有些数据是从1961年开始的,太旧了就不用了,我们从2001年开始。

yield_=yield_[yield_["Year"]>2000]

同样,我们来看看pp_的情况:

pp_.head()

在这里插入图片描述

pp_=pp_[["Item","Year","Value","Element"]]
pp_=pp_[pp_["Year"]>2000]

实际上,在这个数据里,产量已经没有问题了。我们只需要做一个简单的处理:

yield_.groupby("Item").mean()["Value"]/10 #转为千克

在这里插入图片描述

便可拿到每种作物近二十年的平均产量。

好了现在大问题出现在价值上,我们从下往上看就知道了:

pp_.tail(10)

在这里插入图片描述

高粱只有2006年的,那有没有办法利用现成的数据将其扩展呢?

实际上,这类拟合问题有很多种解决方案,但是本问题涉及到时间,之前时间段的因子,以及可能的周期性,都会增加拟合的复杂性。所以,在这里我们采用LSTM来填充数据。


三、模型构建

在本小节,我们将比较传统一维CNN与RNN在结果上的异同。

一般做一维RNN时,可以指定一个时间窗口,比如用2006,2007,2008年的数据,推理2009年的数据,用2007,2008,2009年推理2010年。

我们现在要用之前处理好的pp_c数据中的玉米产量,来预测高粱产量。所以第一步就是将其转化为torch接受的格式。

别忘记导入模块:

import torch
import torch.nn as nn
from torch.nn import functional as F
x=pp_c[pp_c['Item']=="Maize (corn)"]['Value']
x=torch.FloatTensor(x)

之前写数据迭代器的时候,除了可以继承自torch.utils.data.DataLoader,也可以是任意的可迭代对象。这里我们可以简单的设置一个类:

# 设置迭代器
class MyDataSet(object):def __init__(self,seq,ws=6):# ws是滑动窗口大小self.ori=[i for i in seq[:ws]]self.label=[i for i in seq[ws:]]self.reset()self.ws=wsdef set(self,dpi):# 添加数据self.x.append(dpi)def reset(self):# 初始化self.x=self.ori[:]def get(self,idx):return self.x[idx:idx+self.ws],self.label[idx]def __len__(self):return len(self.x)

哦这边提一下,有两种方式,一种是用原始数据做预测,一种是用预测数据做预测,可能有点抽象,下面举个例子。

假设 A = [ a 1 , a 2 , a 3 , a 4 , a 5 , a 6 ] A=[a1,a2,a3,a4,a5,a6] A=[a1,a2,a3,a4,a5,a6],时间窗口大小为3。

用原始数据做预测,那么输入值为: a 1 , a 2 , a 3 a1,a2,a3 a1,a2,a3,得到的结果将与 a 4 a4 a4做比较。下一轮输入为 a 2 , a 3 , a 4 a2,a3,a4 a2,a3,a4,得到的结果将与 a 5 a5 a5做比较。

而用预测的数据做预测,第一轮输入值为 a 1 , a 2 , a 3 a1,a2,a3 a1,a2,a3,得到的结果是 b 4 b4 b4,在与 a 4 a4 a4做比较后,下一轮的输入为 a 2 , a 3 , b 4 a2,a3,b4 a2,a3,b4,会出现如下情况:

输入数据为 b 4 , b 5 , b 6 b4,b5,b6 b4,b5,b6

我们现在举的例子是用预测的数据做预测。当然,最后也会给出一个用原始数据做预测的版本,那个版本相对简单。

ws=6 # 全局时间窗口
train_data=MyDataSet(x,ws)

网络的架构如下:

   
class Net3(nn.Module):def __init__(self,in_features=54,n_hidden1=128,n_hidden2=256,n_hidden3=512,out_features=7):super(Net3, self).__init__()self.flatten=nn.Flatten()self.hidden1=nn.Sequential(nn.Linear(in_features,n_hidden1,False),nn.ReLU())self.hidden2=nn.Sequential(nn.Linear(n_hidden1,n_hidden2),nn.ReLU())self.hidden3=nn.Sequential(nn.Linear(n_hidden2,n_hidden3),nn.ReLU())self.out=nn.Sequential(nn.Linear(n_hidden3,out_features))def forward(self,x):x=self.flatten(x)x=self.hidden2(self.hidden1(x))x=self.hidden3(x)return self.out(x)class CNN(nn.Module):def __init__(self, output_dim=1,ws=6):super(CNN, self).__init__()self.relu = nn.ReLU(inplace=True)self.conv1 = nn.Conv1d(ws, 64, 1)self.lr = nn.LeakyReLU(inplace=True)self.conv2 = nn.Conv1d(64, 128, 1)self.bn1, self.bn2 = nn.BatchNorm1d(64), nn.BatchNorm1d(128)self.bn3, self.bn4 = nn.BatchNorm1d(1024), nn.BatchNorm1d(128)self.flatten = nn.Flatten()self.lstm1 = nn.LSTM(128, 1024)self.lstm2 = nn.LSTM(1024, 256)self.lstm3=nn.LSTM(256,512)self.fc = nn.Linear(512, 512)self.fc4=nn.Linear(512,256)self.fc1 = nn.Linear(256, 64)self.fc3 = nn.Linear(64, output_dim)@staticmethoddef reS(x):return x.reshape(-1, x.shape[-1], x.shape[-2])def forward(self, x):x = self.reS(x)x = self.conv1(x) x = self.lr(x)x = self.conv2(x) x = self.lr(x)x = self.flatten(x)# LSTM部分x, h = self.lstm1(x)x, h = self.lstm2(x)x,h=self.lstm3(x)x, _ = hx = self.fc(x.reshape(-1, ))x = self.relu(x)x = self.fc4(x)x = self.relu(x)x = self.fc1(x)x = self.relu(x)x = self.fc3(x)return x

Net3主要是一维卷积,CNN加入了LSTM结构。至于名字,是随便取的…跟内容并无关系。


def Train(model,train_data,seed=1):device="cuda" if torch.cuda.is_available() else "cpu"model=model.to(device)Mloss=100000path=r"YourPath\%s.pth"%seed# 设置损失函数,这里使用的是均方误差损失criterion = nn.MSELoss()# 设置优化函数和学习率lroptimizer=torch.optim.Adam(model.parameters(),lr=1e-5,betas=(0.9,0.99),eps=1e-07,weight_decay=0)# 设置训练周期epochs =3000criterion=criterion.to(device)model.train()for epoch in range(epochs):total_loss=0for i in range(len(x)-ws):# 每次更新参数前都梯度归零和初始化seq,y_train=train_data.get(i) # 从我们的数据集中拿出数据seq,y_train=torch.FloatTensor(seq),torch.FloatTensor([y_train])seq=seq.unsqueeze(dim=0)seq,y_train=seq.to(device),y_train.to(device)optimizer.zero_grad()# 注意这里要对样本进行reshape,# 转换成conv1d的input size(batch size, channel, series length)y_pred = model(seq)loss = criterion(y_pred, y_train)loss.backward()train_data.set(y_pred.to("cpu").item()) # 再放入预测数据optimizer.step()total_loss+=losstrain_data.reset()if total_loss.tolist()<Mloss:Mloss=total_loss.tolist()torch.save(model.state_dict(),path)print("Saving")print(f'Epoch: {epoch+1:2} Mean Loss: {total_loss.tolist()/len(train_data):10.8f}')return model

正常训练就OK

d=CNN(ws=ws)
Train(d,train_data,4)

在这里插入图片描述

平均损失在10点左右,还有很大优化空间。当然我们这里只是举个非常简单的例子,就是个baseline

checkpoint=torch.load(r"YourPath\4.pth")
d.load_state_dict(checkpoint) # 加载最佳参数
d.to("cpu")

四、结果可视化

我们这里用到Pyechart进行可视化。

from pyecharts.charts import *
from pyecharts import options as opts
from pyecharts.globals import CurrentConfig
pre,ppre=[i.item() for i in x[:ws]],[]
# pre 是用原始数据做预测
# ppre 用预测数据做预测
for i in range(len(x)-ws+1):ppre.append(d(torch.FloatTensor(x[i:i+ws]).unsqueeze(dim=0)))pre.append(d(torch.FloatTensor(pre[-ws:]).unsqueeze(dim=0)).item())
l=Line()
l.add_xaxis([i for i in range(len(x))])
l.add_yaxis("Original Data",x.tolist())
l.add_yaxis("Pred Data(Using Raw Datas)",x[:ws].tolist()+[i.item() for i in ppre])
l.add_yaxis("Pred Data(Using Pred Datas)",pre)
l.set_series_opts(label_opts=opts.LabelOpts(is_show=False))
l.set_global_opts(title_opts=opts.TitleOpts(title='LSTM CNN'))l.render_notebook()

根据时间窗口的不同,可以得到不同的结果。

ws=4

在这里插入图片描述

ws=5

在这里插入图片描述

ws=6

在这里插入图片描述

从结果上来看,时间窗口越大越好。但是这里我们只能到六了,再大就不礼貌了。(高粱只有六个节点的数据)。

至于验证,我们可以选Rice做验证:

x=torch.FloatTensor(pp_c[pp_c['Item']=="Rice"]['Value'].tolist())
pre,ppre=[i.item() for i in x[:ws]],[]
for i in range(len(x)-ws+1):ppre.append(d(torch.FloatTensor(x[i:i+ws]).unsqueeze(dim=0)))pre.append(d(torch.FloatTensor(pre[-ws:]).unsqueeze(dim=0)).item())
l=Line()
l.add_xaxis([i for i in range(len(x))])
l.add_yaxis("Original Data",x.tolist())
l.add_yaxis("Pred Data(Using Raw Datas)",x[:ws].tolist()+[i.item() for i in ppre])
l.add_yaxis("Pred Data(Using Pred Datas)",pre)
l.set_series_opts(label_opts=opts.LabelOpts(is_show=False))
l.set_global_opts(title_opts=opts.TitleOpts(title='LSTM CNN'))l.render_notebook()

在这里插入图片描述

可以发现,用预测做预测的结果,基本上不会差太多,那也就意味着,我们可以对高粱进行预测啦!不过在这之前,我们可以看看用原始数据做训练的结果:

在这里插入图片描述

时间窗口一样为6,可以看到在黑线贴合的非常好,但是面对大量缺失的数据,精度就远不如用预测数据做预测的结果了。

此外,这是用CNN做的结果

在这里插入图片描述

我们可以发现LSTM的波动要比CNN好,CNN后面死水一潭,应该是梯度消失导致的,前面信息没有了,后面信息又是自个构造的,这就导致了到后面变成了线性情况。

那么最后的最后,就是预测高粱产量了:

pre_data=pp_c[pp_c['Item']=='Sorghum']['Value'].tolist()
l=pre_data[:]
for i in range(len(x)-ws+1):l.append(d(torch.FloatTensor(l[-ws:]).unsqueeze(dim=0)).item())
L=Line()
L.add_xaxis([i for i in range(len(x))])
L.add_yaxis("Pred",l)
L.set_series_opts(label_opts=opts.LabelOpts(is_show=False))
L.set_global_opts(title_opts=opts.TitleOpts(title='sorghum production forecasts'))L.render_notebook()
l.to_csv("path")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/46698.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法通关村第4关【黄金】| 表达式问题

1. 计算器问题 思路&#xff1a;此题不考虑括号和负数情况&#xff0c;单纯使用栈即可解决。注意的是数字可能是多位数需要保留完整的num&#xff0c; 保留数字的前缀符号&#xff0c;当碰到加号&#xff0c;存进去&#xff1b;当碰到减号&#xff0c;存相反数进去&#xff1b;…

Apinto 网关进阶教程,插件开发入门指南

Apinto 是基于Go语言&#xff0c;由 Eolink 自主研发的一款高性能、可扩展、易维护的云原生 API 网关。Apinto 能够帮助用户简单、快速、低成本、低风险地实现&#xff1a;系统微服务化、系统集成、向合作伙伴、开发者开放功能和数据。 通过 Apinto&#xff0c;企业能够专注于…

【LeetCode-中等题】15. 三数之和

题目 题解一&#xff1a;双指针法 图解参考链接&#xff1a;画解算法&#xff1a;15. 三数之和 详解参考代码随想录讲的非常好 梦破碎的地方&#xff01;| LeetCode&#xff1a;15.三数之和 代码&#xff1a; class Solution {public List<List<Integer>> thre…

Codeforces Round 893 (Div. 2) A ~ C

比赛链接 A. Buttons 博弈、最优策略一定是先去按都能按的按钮&#xff0c;按完之后再按自己的。 #include<bits/stdc.h> #define IOS ios::sync_with_stdio(0);cin.tie(0);cout.tie(0); #define endl \nusing namespace std;typedef pair<int, int> PII; typede…

jstack(Stack Trace for Java)Java堆栈跟踪工具

jstack&#xff08;Stack Trace for Java&#xff09;Java堆栈跟踪工具 jstack&#xff08;Stack Trace for Java&#xff09;命令用于生成虚拟机当前时刻的线程快照&#xff08;一般称为threaddump或者javacore文件&#xff09;。 线程快照就是当前虚拟机内每一条线程正在执…

动手学深度学习-pytorch版本(二):线性神经网络

参考引用 动手学深度学习 1. 线性神经网络 神经网络的整个训练过程&#xff0c;包括: 定义简单的神经网络架构、数据处理、指定损失函数和如何训练模型。经典统计学习技术中的线性回归和 softmax 回归可以视为线性神经网络 1.1 线性回归 回归 (regression) 是能为一个或多个…

Linux系统的目录结构

file system hierarchy standard文件系统层级标准&#xff0c;定义了在类Unix系统中的目录结构和目录内容。 即让用户了解到已安装软件通常放置于哪个目录下。 Linux目录结构的特点 使用树形目录结构来组织和管理文件。 整个系统只有一个根目录&#xff08;树根&#xff09;&a…

记录几个Hudi Flink使用问题及解决方法

前言 如题&#xff0c;记录几个Hudi Flink使用问题&#xff0c;学习和使用Hudi Flink有一段时间&#xff0c;虽然目前用的还不够深入&#xff0c;但是目前也遇到了几个问题&#xff0c;现在将遇到的这几个问题以及解决方式记录一下 版本 Flink 1.15.4Hudi 0.13.0 流写 流写…

Flink之时间语义

Flink之时间语义 简介 Flink中时间语义可以说是最重要的一个概念了,这里就说一下关于时间语义的机制,我们下看一下下面的表格,简单了解一下 时间定义processing time处理时间,也就是现实世界的时间,或者说代码执行时,服务器的时间event time事件时间,就是事件数据中所带的时…

nginx代理webSocket链接响应403

一、场景 使用nginx代理webSocket链接&#xff0c;nginx响应403 1、nginx访问日志响应403 [18/Aug/2023:09:56:36 0800] "GET /FS_WEB_ASS/webim_api/socket/message HTTP/1.1" 403 5 "-" "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit…

【数据结构】循环队列

&#x1f490; &#x1f338; &#x1f337; &#x1f340; &#x1f339; &#x1f33b; &#x1f33a; &#x1f341; &#x1f343; &#x1f342; &#x1f33f; &#x1f344;&#x1f35d; &#x1f35b; &#x1f364; &#x1f4c3;个人主页 &#xff1a;阿然成长日记 …

浏览器 - 事件循环机制详解

目录 1&#xff0c;浏览器进程模型进程线程浏览器的进程和线程1&#xff0c;浏览器进程2&#xff0c;网络进程3&#xff0c;渲染进程 2&#xff0c;渲染主线程事件循环异步同步 JS 为什么会阻塞渲染任务优先级 3&#xff0c;常见面试题1&#xff0c;如何理解 js 的异步2&#x…

❤ Vue工作常用的一些动态数据和方法处理

❤ Vue工作常用的一些动态数据和方法处理 &#xff08;1&#xff09;动态拼接相对路径结尾的svg 错误写法一 ❌ 正确写法 &#x1f646; <img :src"require(/assets//amazon/svg/homemenu${index}.svg)" style"height: 20px;display: block;margin: 0 au…

关于视频监控平台EasyCVR视频汇聚平台建设“明厨亮灶”具体实施方案以及应用

一、方案背景 近几年来&#xff0c;餐饮行业的食品安全、食品卫生等新闻频频发生&#xff0c;比如某火锅店、某网红奶茶&#xff0c;食材以次充好、后厨卫生被爆堪忧&#xff0c;种种问题引起大众关注和热议。这些负面新闻不仅让餐饮门店的品牌口碑暴跌&#xff0c;附带的连锁…

[JavaWeb]【二】Vue Ajax Elemnet Vue路由打包部署

目录 一 什么是Vue 1.1 Vue快速入门 1.2 常用指令 1.2.1 v-bind && v-model 1.2.2 v-on 1.2.3 v-if && v-show 1.2.4 v-for 1.2.5 案例 1.3 生命周期 二 Ajax 2.1 Ajax介绍 2.2 同步与异步 2.3 原生Ajax&#xff08;繁琐&#xff0c;过时了&#xff09…

手机技巧:分享五个非常实用的生活类APP

目录 1、我的桌面iScreen-桌面美化神器 2.Just Rain-创意听雨声APP 3.得言-美文句子神器 4、微手帐 5、暗盒-隐私保护神器 今天给大家整理5个非常实用的实用APP软件&#xff0c;感兴趣的朋友可以下载试试&#xff01; 1、我的桌面iScreen-桌面美化神器 我的桌面iScreen是一…

[uni-app] uview封装Popup组件,处理props及v-model的传值问题

文章目录 需求及效果遇到的问题解决的办法偷懒的写法 需求及效果 uView(1.x版本)中, 有Pop弹出层的组件, 现在有个需求是,进行简单封装,有些通用的设置不想每次都写(比如 :mask-custom-style"{background: rgba(0, 0, 0, 0.7)}"这种) 然后内部内容交给插槽去自己随…

系统架构设计专业技能 · 系统工程与系统性能

系列文章目录 系统架构设计专业技能 网络技术&#xff08;三&#xff09; 系统架构设计专业技能 系统安全分析与设计&#xff08;四&#xff09;【系统架构设计师】 系统架构设计高级技能 软件架构设计&#xff08;一&#xff09;【系统架构设计师】 系统架构设计高级技能 …

2023年上半年软件设计师下午真题及答案解析

试题一(15分) 随着农业领域科学种植的发展&#xff0c;需要对农业基地及农事进行信息化管理&#xff0c;为租户和农户等人员提供种植相关服务&#xff0c;现欲开发农事管理服务平台&#xff0c;其主要功能是&#xff1a; (1)人员管理&#xff1a;平台管理员管理租户&#xff…

​Redis概述

目录 Redis - 概述 使用场景 如何安装 Window 下安装 Linux 下安装 docker直接进行安装 下载Redis镜像 Redis启动检查常用命令 Redis - 概述 redis是一款高性能的开源NOSQL系列的非关系型数据库,Redis是用C语言开发的一个开源的高键值对(key value)数据库,官方提供测试…