时间序列模型在LSTM中的特征输入

这里写目录标题

  • 前言
  • LSTM的输入组成
    • 时间步
    • 例子
  • 实际代码解读
    • 特征提取
    • 处理成dataloader格式(用于输入到模型当中)
    • 对应到lstm的模型创建代码
  • 总结


前言

本文章将帮助理解如何将一个时间序列的各种特征(年月日的时间特征,滚动窗口滞后项等时变特征输入)输入到lstm模型(关键在于时间步)中,并给出使用pytorch进行特征输入的代码实例


LSTM的输入组成

包含三个部分

  • 样本数:输入样本数量
  • 时间步:一共使用多少个时间节点的数据进行预测(重点部分)
  • 特征数:某一个数据所具备的特征

时间步

时间步,具体指模型在一个样本中可以看到的时间序列数据的长度,即模型每次处理多少个连续的时间点
其实就是使用前(包括自己)n天的数据进行预测
由于时间步会使用前n天(类似于滞后的想法),所以使用n天时间步会导致样本缩小,所以实际输入(input_size)公式为
样本数量 − 时间步 + 1 样本数量 - 时间步 + 1 样本数量时间步+1 (要记得把自己算上哦)

时间步与滞后比较像,通过这样的对比帮助更好理解时间步:
滞后:取过去n期标志值作为当期特征
时间步:取算上自己的过去n期所有特征作为当期特征

例子

假设一个时间序列[1, 2, 3],在不考虑时间的情况下

  • 假设时间步为1,输入到lstm模型当中的_x应该为
    一个三维列表,其大小为
    (3, 1, 1),(样本数,时间步,特征值)

[
 [[1]],
 [[2]],
 [[3]]
]

  • 假设时间步为2,此时的输入大小则为(3, 2, 1)

[
 [[1], [2]],
 [[2], [3]]
]


实际代码解读

这里可以找gpt捏造一个数据,重点是理解过程中的数据处理
先简单对一个时间序列进行特征提取

特征提取

小贴士:LSTM对数据大小敏感,推荐优先进行

  • 归一化(数值大小关系)
  • 独热编码(没有大小和周期之分)
  • 正余弦编码(有周期不存在数值大小的特征)
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
from tensorflow.python.keras.backend import dtype
from torch import nn, optim
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, TensorDatasettorch.random.manual_seed(42)
pd.set_option('display.max_columns', 500)
data = pd.read_csv('real_temp.csv')  
# 找gpt捏造的数据,其由两列组成,一列Date,一列Sales,时间范围为2023-01-01到2023-12-31
# 其中设置前十一个月用于训练,十二月用于测试
data["Date"] = pd.to_datetime(data["Date"])  # 转换为时间列
dt = data["Date"].dt  # 设置时间接口
data["month"] = dt.month  
data["day"] = dt.day
data["weekday"] = dt.weekdaydef SinCosScale(name, round):"""正余弦化处理,公式为sin(2Π * 值 / 周期),cos(2Π * 值 / 周期)"""data[f"sin_{name}"] = np.sin(2 * np.pi * data[name] / round)  data[f"cos_{name}"] = np.cos(2 * np.pi * data[name] / round)
SinCosScale("weekday", 7)
SinCosScale("month", 12)values = ["Sales", "day", "shift", "windowmean", "windowstd"]  # 将具有大小之分的数据进行归一化
# 为什么日期也包括在里面呢,因为不同月份对应的day其实不同,这里不太好处理
data["shift"] = data["Sales"].shift(7)  # 滞后特征
rolling = data["Sales"].rolling(7)  # 滚动窗口
data["windowmean"] = rolling.mean()  # 窗口内的均值
data["windowstd"] = rolling.std()  # 窗口内的标准差
scaler = MinMaxScaler()
scalar_test = MinMaxScaler()  # 用于专门转换y值的归一化模型
scalar_test.fit(data["Sales"].values.reshape(-1, 1))  # 输入y值进行训练,方便后续inverse
data[values] = scaler.fit_transform(data[values])
data.dropna(inplace=True)
data.drop(["month", "weekday"], inplace=True, axis=1)

处理成dataloader格式(用于输入到模型当中)

通过TensorDataloader转换为张量数据,通过TensorDataset能够让模型训练的部分模板化,强烈推荐使用

tran_mask =  (data["Date"] <= "2023-12-01")  # 训练集的范围
test_mask = (data["Date"] >= "2023-12-01") & (data["Date"] <= "2025-12-31")  # 测试集范围
train_data = data[tran_mask].values[:, 1:]  # 取出训练集(不包括Date)
test_data = data[test_mask].values[:, 1:]  # 取出测试集(不包括Date)
def create_sequences(data, time_steps):"""取出时间步数据"""x, y = [], []  # 这里的x其实是所有带时间步的汇总for i in range(len(data) - time_steps):x.append(data[i:i + time_steps]) # 选择从 i 开始的 time_steps 个时间步,这里已经是一个二维列表了,其中每一个元素都对应过去的一个时间节点的所有数据y.append(data[i + time_steps][0])    # 选择第 i + time_steps 个数据作为目标return np.array(x).astype("float32"), np.array(y).astype("float32").reshape(-1, 1)
train_x, train_y = create_sequences(train_data, 7)
test_x, test_y = create_sequences(test_data, 7)
train_inputs = TensorDataset(torch.tensor(train_x, dtype=torch.float32), torch.tensor(train_y, dtype=torch.float32))  # 转换为张量数据,通过TensorDataset能够让模型训练的部分完全模板化,强烈推荐使用
test_inputs = TensorDataset(torch.tensor(test_x, dtype=torch.float32), torch.tensor(test_y, dtype=torch.float32))
train_dataloader = DataLoader(dataset=train_inputs, batch_size=32, shuffle=True)
test_dataloader = DataLoader(dataset=test_inputs, batch_size=32)

对应到lstm的模型创建代码

注意lstm中的输入层,其大小应该为特征的数量
如果只设置了值为特征,那么input_size就为1
如果设置了8个特征,那么算上自己的值之后input_size就为9

class LSTM(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(LSTM, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)  # lstm的返回有两个值,所以不能直接用sequentialself.fc = nn.Linear(hidden_size, output_size)  # 线性def forward(self, x):# (样本大小, 时间步,特征数量)x, _ = self.lstm(x)  # x代表输出的形状,_是用于下一步处理,简单lstm就不用这个了x = x[:, -1, :]  # 只取最后一个时间步x = self.fc(x)  # 将输入的x转换为output_size尺寸return x
# 设置接口
input_size, hidden_size = 9, 8  # 输入的为特征数量
lstm = LSTM(input_size, hidden_size, 1)  # input_size为特征数量

其他代码这里就不放了,后面都是模板,问gpt,看其他文章都差不多,不浪费时间:

  • 设置优化器损失函数,数据移入device
  • for epoch 个训练,每个epoch中for data in dataloader:(这也是我觉得dataloader好的地方,也可以帮助我们将数据分批次与打乱,调整dataloader的batch_size即可)
  • 前向传播,计算损失,清空梯度,反向传播,更新参数
  • 测试集验证模型精度

总结

特征正常提取,和机器学习中的特征学习一样,假设n个特征
接下来对特征处理成时间步的形状,假设m个时间步,则每一个样本输入到模型中的大小
(1, m, n):一个样本算上自己的前m个时间节点,包含每个时间节点的n个特征

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/888280.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp 自定义导航栏增加首页按钮,仿微信小程序操作胶囊

实现效果如图 抽成组件navbar.vue&#xff0c;放入分包 <template><view class"header-nav-box":style"{height:Props.imgShow?:statusBarHeightpx,background:Props.imgShow?:Props.bgColor||#ffffff;}"><!-- 是否使用图片背景 false…

node.js基础学习-express框架-静态资源中间件express.static(十一)

前言 在 Node.js 应用中&#xff0c;静态资源是指那些不需要服务器动态处理&#xff0c;直接发送给客户端的文件。常见的静态资源包括 HTML 文件、CSS 样式表、JavaScript 脚本、图片&#xff08;如 JPEG、PNG 等&#xff09;、字体文件和音频、视频文件等。这些文件在服务器端…

下载maven 3.6.3并校验文件做md5或SHA512校验

一、下载Apache Maven 3.6.3 Apache Maven 3.6.3 官方下载链接&#xff1a; 二进制压缩包&#xff08;推荐&#xff09;: ZIP格式: https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.zipTAR.GZ格式: https://archive.apache.org/dist/…

单片机知识总结(完整)

1、单片机概述 1.1. 单片机的定义与分类 定义&#xff1a; 单片机&#xff08;Microcontroller Unit&#xff0c;简称MCU&#xff09;是一种将微处理器、存储器&#xff08;包括程序存储器和数据存储器&#xff09;、输入/输出接口和其他必要的功能模块集成在单个芯片上的微型…

C# winform非常好用的图表开源控件Scottplot

wifnorm自带的chart控件功能和性能都不太行&#xff0c;所以在网上搜索到了Scottplot开源图表控件。根据自己需要&#xff0c;将已经试验使用过的用法记录在这里 winform建议使用版本 Scottplot包版本&#xff1a;4.1.71 这个版本在winform中可以以控件形式直接拖拉到窗体中使…

Python 3 教程第33篇(MySQL - mysql-connector 驱动)

Python MySQL - mysql-connector 驱动 MySQL 是最流行的关系型数据库管理系统&#xff0c;如果你不熟悉 MySQL&#xff0c;可以阅读我们的 MySQL 教程。 本章节我们为大家介绍使用 mysql-connector 来连接使用 MySQL&#xff0c; mysql-connector 是 MySQL 官方提供的驱动器。…

Gooxi Eagle Stream 2U双路通用服务器:性能强劲 灵活扩展 稳定易用

人工智能的高速发展开启了飞轮效应&#xff0c;实施数字化变革成为了企业的一道“抢答题”和“必答题”&#xff0c;而数据已成为现代企业的命脉。以HPC和AI为代表的新业务就像节节攀高的树梢&#xff0c;象征着业务创新和企业成长。但在树梢之下&#xff0c;真正让企业保持成长…

UICollectionView在xcode16编译闪退问题

使用xcode15运行工程&#xff0c;控制台会出现如下提示&#xff1a; Expected dequeued view to be returned to the collection view in preparation for display. When the collection views data source is asked to provide a view for a given index path, ensure that a …

Gentoo Linux部署LNMP

一、安装nginx 1.gentoo-chxf ~ # emerge -av nginx 提示配置文件需更新 2.gentoo-chxf ~ # etc-update 3.gentoo-chxf ~ # emerge -av nginx 4.查看并启动nginx gentoo-chxf ~ # systemctl status nginx gentoo-chxf ~ # systemctl start nginx gentoo-chxf ~ # syst…

【k8s深入理解之 Scheme 补充-2】理解 register.go 暴露的 AddToScheme 函数

AddToScheme 函数 AddToScheme 就是为了对外暴露&#xff0c;方便别人调用&#xff0c;将当前Group组的信息注册到其 Scheme 中&#xff0c;以便了解该 Group 组的数据结构&#xff0c;用于后续处理 项目版本用途使用场景k8s.io/apiV1注册资源某一外部版本数据结构&#xff0…

CQ 社区版 2024.11 | 新增“审批人组”概念、可通过SQL模式自定义审计图表……

CloudQuery 社区 11 月新版本来啦&#xff01;本月版本依旧是 CUG&#xff08;CloudQuery 用户组&#xff09;尝鲜版的更新。 针对审计模块增加了 SQL 模式自定义审计图表&#xff1b;在流程模块引入了“审批人组”概念。此外&#xff0c;在 SQL 编辑器、连接管理等模块都涉及…

做异端中的异端 -- Emacs裸奔之路4: 你不需要IDE

确切地说&#xff0c;你不需要在IDE里面编写或者阅读代码。 IDE用于Render资源文件比较合适&#xff0c;但处理文本&#xff0c;并不划算。 这的文本文件&#xff0c;包括源代码&#xff0c;配置文件&#xff0c;文档等非二进制文件。 先说说IDE带的便利: 函数或者变量的自动…

RDIFramework.NET CS敏捷开发框架 SOA服务三种访问(直连、WCF、WebAPI)方式

1、介绍 在软件开发领域&#xff0c;尤其是企业级应用开发中&#xff0c;灵活性、开放性、可扩展性往往是项目成功的关键因素。对于C/S项目&#xff0c;如何高效地与后端数据库进行交互&#xff0c;以及如何提供多样化的服务访问方式&#xff0c;是开发者需要深入考虑的问题。…

GitLab CVE-2024-8114 漏洞解决方案

漏洞 ID 标题严重等级CVE ID通过 LFS 令牌提升权限高CVE-2024-8114 GitLab 升级指南GitLab 升级路径查看版本漏洞查询 漏洞解读 此漏洞允许攻击者使用受害者的个人访问令牌&#xff08;PAT&#xff09;进行权限提升。影响从 8.12 开始到 17.4.5 之前的所有版本、从 17.5 开…

基于Pyside6开发一个通用的在线升级工具

UI main.ui <?xml version"1.0" encoding"UTF-8"?> <ui version"4.0"><class>MainWindow</class><widget class"QMainWindow" name"MainWindow"><property name"geometry"&…

Redis(配置文件属性解析)

一、tcp-backlog深度解析 tcp-backlog是一个TCP连接的队列&#xff0c;主要用于解决高并发场景下客户端慢连接问题。配置文件中的“511”就是队列的长度&#xff0c;对联与TCP的三次握手有关&#xff0c;不同的linux内核&#xff0c;backlog队列中存放的元素&#xff08;客户端…

24.12.02 Element

import { createApp } from vue // 引入elementPlus js库 css库 import ElementPlus from element-plus import element-plus/dist/index.css //中文语言包 import zhCn from element-plus/es/locale/lang/zh-cn //图标库 import * as ElementPlusIconsVue from element-plus/i…

mybatis-plus 对于属性为null字段不更新

MyBatis-Plus 默认情况下会根据字段的值是否为 null 来决定是否生成对应的 UPDATE 语句。这是由 更新策略 决定的&#xff0c;默认的行为是 忽略 null 值&#xff0c;即如果字段值为 null&#xff0c;该字段将不会出现在 UPDATE 语句中。 默认行为分析 MyBatis-Plus 默认的 Fi…

C++小问题

怎么分辨const修饰的是谁 是限定谁不能被改变的&#xff1f; 在C中&#xff0c;const关键字的用途和位置非常关键&#xff0c;它决定了谁不能被修改。const可以修饰变量、指针、引用等不同的对象&#xff0c;并且具体的作用取决于const的修饰位置。理解const的规则能够帮助我们…

在线家具商城基于 SpringBoot:设计模式与实现方法探究

第3章 系统分析 用户的需求以及与本系统相似的在市场上存在的其它系统可以作为系统分析中参考的资料&#xff0c;分析人员可以根据这些信息确定出本系统具备的功能&#xff0c;分析出本系统具备的性能等内容。 3.1可行性分析 尽管系统是根据用户的要求进行制作&#xff0c;但是…