1 时间序列模型入门: LSTM

0 前言

        循环神经网络(Recurrent Neural Network,RNN)是一种用于处理序列数据的神经网络。相比一般的神经网络来说,他能够处理序列变化的数据。比如某个单词的意思会因为上文提到的内容不同而有不同的含义,RNN就能够很好地解决这类问题。本质是一个全连接网络,但是因为当前时刻受历史时刻的影响。

      传统的RNN结构可以看做是多个重复的神经元构成的“回路”,每个神经元都接受输入信息并产生输出,然后将输出再次作为下一个神经元的输入,依次传递下去。这种结构能够在序列数据上学习短时依赖关系,但是由于梯度消失和梯度爆炸问题(梯度反向求导链式法则导致),RNN在处理长序列时难以达到很好的性能。而LSTM通过引入记忆细胞、输入门、输出门和遗忘门的概念(加号的引入),能够有效地解决长序列问题。记忆细胞负责保存重要信息,输入门决定要不要将当前输入信息写入记忆细胞,遗忘门决定要不要遗忘记忆细胞中的信息,输出门决定要不要将记忆细胞的信息作为当前的输出。这些门的控制能够有效地捕捉序列中重要的长时间依赖性,并且能够解决梯度问题。

备注:

输入门:输入的数据有多大程度进入模型;

输出门:控制当前时刻的内部状态 c​有多少信息需要输出给外部状态; 

遗忘门:控制上一个时刻的内部状态, ct−1​需要遗忘多少信息;

 1 模块介绍

数据: 2023“SEED”第四届江苏大数据开发与应用大赛--新能源赛道的数据

MARS开发者生态社区

解题思路: 总共500个充电站状, 关联地理位置,然后提取18个特征;把这18个特征作为时步不长(记得是某个比赛的思路)然后特征长度为1 (类比词向量的size).

y = LSTM(x,h0,c0)

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
#import tushare as ts
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
from tqdm import tqdmimport matplotlib.pyplot as plt
import tqdm
import sys
import os
import gc
import argparse
import warningswarnings.filterwarnings('ignore')# 读取数据
train_power_forecast_history = pd.read_csv('../data/data1/train/power_forecast_history.csv')
train_power = pd.read_csv('../data/data1/train/power.csv')
train_stub_info = pd.read_csv('../data/data1/train/stub_info.csv')test_power_forecast_history = pd.read_csv('../data/data1/test/power_forecast_history.csv')
test_stub_info = pd.read_csv('../data/data1/test/stub_info.csv')# 聚合数据
train_df = train_power_forecast_history.groupby(['id_encode','ds']).head(1)
del train_df['hour']test_df = test_power_forecast_history.groupby(['id_encode','ds']).head(1)
del test_df['hour']tmp_df = train_power.groupby(['id_encode','ds'])['power'].sum()
tmp_df.columns = ['id_encode','ds','power']# 合并充电量数据
train_df = train_df.merge(tmp_df, on=['id_encode','ds'], how='left')### 合并数据
train_df = train_df.merge(train_stub_info, on='id_encode', how='left')
test_df = test_df.merge(test_stub_info, on='id_encode', how='left')h3_code = pd.read_csv("../data/h3_lon_lat.csv")
train_df = train_df.merge(h3_code,on='h3')
test_df = test_df.merge(h3_code,on='h3')def kalman_filter(data, q=0.0001, r=0.01):# 后验初始值x0 = data[0]                              # 令第一个估计值,为当前值p0 = 1.0# 存结果的列表x = [x0]for z in data[1:]:                        # kalman 滤波实时计算,只要知道当前值z就能计算出估计值(后验值)x0# 先验值x1_minus = x0                         # X(k|k-1) = AX(k-1|k-1) + BU(k) + W(k), A=1,BU(k) = 0p1_minus = p0 + q                     # P(k|k-1) = AP(k-1|k-1)A' + Q(k), A=1# 更新K和后验值k1 = p1_minus / (p1_minus + r)        # Kg(k)=P(k|k-1)H'/[HP(k|k-1)H' + R], H=1x0 = x1_minus + k1 * (z - x1_minus)   # X(k|k) = X(k|k-1) + Kg(k)[Z(k) - HX(k|k-1)], H=1p0 = (1 - k1) * p1_minus              # P(k|k) = (1 - Kg(k)H)P(k|k-1), H=1x.append(x0)                          # 由输入的当前值z 得到估计值x0存入列表中,并开始循环到下一个值return x#kalman_filter()
train_df['new_label'] = 0
for i in range(500):#print(i)label = i#train_df[train_df['id_encode']==labe]['power'].valuestrain_df.loc[train_df['id_encode']==label, 'new_label'] = kalman_filter(data=train_df[train_df['id_encode']==label]['power'].values)### 数据预处理
train_df['flag'] = train_df['flag'].map({'A':0,'B':1})
test_df['flag'] = test_df['flag'].map({'A':0,'B':1})def get_time_feature(df, col):df_copy = df.copy()prefix = col + "_"df_copy['new_'+col] = df_copy[col].astype(str)col = 'new_'+coldf_copy[col] = pd.to_datetime(df_copy[col], format='%Y%m%d')#df_copy[prefix + 'year'] = df_copy[col].dt.yeardf_copy[prefix + 'month'] = df_copy[col].dt.monthdf_copy[prefix + 'day'] = df_copy[col].dt.day# df_copy[prefix + 'weekofyear'] = df_copy[col].dt.weekofyeardf_copy[prefix + 'dayofweek'] = df_copy[col].dt.dayofweek# df_copy[prefix + 'is_wknd'] = df_copy[col].dt.dayofweek // 6df_copy[prefix + 'quarter'] = df_copy[col].dt.quarter# df_copy[prefix + 'is_month_start'] = df_copy[col].dt.is_month_start.astype(int)# df_copy[prefix + 'is_month_end'] = df_copy[col].dt.is_month_end.astype(int)del df_copy[col]return df_copytrain_df = get_time_feature(train_df, 'ds')
test_df = get_time_feature(test_df, 'ds')train_df = train_df.fillna(999)
test_df = test_df.fillna(999)cols = [f for f in train_df.columns if f not in ['ds','power','h3','new_label']]# 是否进行归一化
scaler = MinMaxScaler(feature_range=(0,1))
scalar_falg = False
if scalar_falg == True:df_for_training_scaled = scaler.fit_transform(train_df[cols])df_for_testing_scaled= scaler.transform(test_df[cols])
else:df_for_training_scaled = train_df[cols]df_for_testing_scaled = test_df[cols]
#df_for_training_scaled
# scaler_label = MinMaxScaler(feature_range=(0,1))
# label_for_training_scaled = scaler_label.fit_transform(train_df['new_label']..values)
# label_for_testing_scaled= scaler_label.transform(train_df['new_label'].values)
# #df_for_training_scaledclass Config():data_path = '../data/data1/train/power.csv'timestep = 18  # 时间步长,就是利用多少时间窗口batch_size = 32  # 批次大小feature_size = 1  # 每个步长对应的特征数量,这里只使用1维,每天的风速hidden_size = 256  # 隐层大小output_size = 1  # 由于是单输出任务,最终输出层大小为1,预测未来1天风速num_layers = 2  # lstm的层数epochs = 10 # 迭代轮数best_loss = 0 # 记录损失learning_rate = 0.00003 # 学习率model_name = 'lstm' # 模型名称save_path = './{}.pth'.format(model_name) # 最优模型保存路径config = Config()
x_train, x_test, y_train, y_test = train_test_split(df_for_training_scaled.values, train_df['new_label'].values,shuffle=False, test_size=0.2)# 将数据转为tensor
x_train_tensor = torch.from_numpy(x_train.reshape(-1,config.timestep,1)).to(torch.float32)
y_train_tensor = torch.from_numpy(y_train.reshape(-1,1)).to(torch.float32)
x_test_tensor = torch.from_numpy(x_test.reshape(-1,config.timestep,1)).to(torch.float32)
y_test_tensor = torch.from_numpy(y_test.reshape(-1,1)).to(torch.float32)# 5.形成训练数据集
train_data = TensorDataset(x_train_tensor, y_train_tensor)
test_data = TensorDataset(x_test_tensor, y_test_tensor)# 6.将数据加载成迭代器
train_loader = torch.utils.data.DataLoader(train_data,config.batch_size,False)test_loader = torch.utils.data.DataLoader(test_data,config.batch_size,False)#train_df[cols]
# 7.定义LSTM网络
class LSTM(nn.Module):def __init__(self, feature_size, hidden_size, num_layers, output_size):super(LSTM, self).__init__()self.hidden_size = hidden_size  # 隐层大小self.num_layers = num_layers  # lstm层数# feature_size为特征维度,就是每个时间点对应的特征数量,这里为1self.lstm = nn.LSTM(feature_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, hidden=None):#print(x.shape)batch_size = x.shape[0] # 获取批次大小 batch, time_stamp , feat_size# 初始化隐层状态if hidden is None:h_0 = x.data.new(self.num_layers, batch_size, self.hidden_size).fill_(0).float()c_0 = x.data.new(self.num_layers, batch_size, self.hidden_size).fill_(0).float()else:h_0, c_0 = hidden# LSTM运算output, (h_0, c_0) = self.lstm(x, (h_0, c_0))# 全连接层output = self.fc(output)  # 形状为batch_size * timestep, 1# 我们只需要返回最后一个时间片的数据即可return output[:, -1, :]
model = LSTM(config.feature_size, config.hidden_size, config.num_layers, config.output_size)  # 定义LSTM网络loss_function = nn.L1Loss()  # 定义损失函数
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)  # 定义优化器# 8.模型训练
for epoch in range(config.epochs):model.train()running_loss = 0train_bar = tqdm(train_loader)  # 形成进度条for data in train_bar:x_train, y_train = data  # 解包迭代器中的X和Yoptimizer.zero_grad()y_train_pred = model(x_train)loss = loss_function(y_train_pred, y_train.reshape(-1, 1))loss.backward()optimizer.step()running_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,config.epochs,loss)# 模型验证model.eval()test_loss = 0with torch.no_grad():test_bar = tqdm(test_loader)for data in test_bar:x_test, y_test = datay_test_pred = model(x_test)test_loss = loss_function(y_test_pred, y_test.reshape(-1, 1))if test_loss < config.best_loss:config.best_loss = test_losstorch.save(model.state_dict(), save_path)print('Finished Training')
train epoch[1/10] loss:293.638: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3727/3727 [01:41<00:00, 36.84it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 932/932 [00:12<00:00, 73.96it/s]
train epoch[2/10] loss:272.386: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3727/3727 [01:49<00:00, 33.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 932/932 [00:12<00:00, 77.56it/s]
train epoch[3/10] loss:252.972: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3727/3727 [01:43<00:00, 35.91it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 932/932 [00:13<00:00, 70.65it/s]
train epoch[4/10] loss:235.282: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3727/3727 [01:45<00:00, 35.27it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 932/932 [00:09<00:00, 93.71it/s]
train epoch[5/10] loss:219.069: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3727/3727 [01:34<00:00, 39.57it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 932/932 [00:08<00:00, 103.92it/s]
train epoch[6/10] loss:203.969: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3727/3727 [01:30<00:00, 41.22it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 932/932 [00:09<00:00, 95.62it/s]
train epoch[7/10] loss:189.877: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3727/3727 [01:35<00:00, 39.19it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 932/932 [00:12<00:00, 77.61it/s]
train epoch[8/10] loss:176.701: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3727/3727 [01:52<00:00, 33.12it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 932/932 [00:11<00:00, 81.34it/s]
train epoch[9/10] loss:164.382: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3727/3727 [01:41<00:00, 36.64it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 932/932 [00:08<00:00, 112.27it/s]
train epoch[10/10] loss:152.841: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3727/3727 [01:25<00:00, 43.40it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 932/932 [00:09<00:00, 94.54it/s]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/170995.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2023-3年CSDN创作纪念日

机缘 今天开开心心出门去上班&#xff0c;就收到了一个csdn私信&#xff0c;打开一看说是给我惊喜来着&#xff0c;我心想csdn还能给惊喜&#xff1f;以为是有什么奖品或者周边之类的&#xff0c;结果什么也没有&#xff0c;打开就是一份信&#x1f602;。 也挺不错的&#xf…

1.6 C语言之数组概述

1.6 C语言之数组概述 一、数组二、练习 一、数组 所谓数组&#xff0c;就是内存中一片连续的空间&#xff0c;可以用来存储一组同类型的数据 数组有下标&#xff0c;从0开始&#xff0c;可以理解为是给数组中的元素编号&#xff0c;便于后续寻址访问 我们来编写一个程序&…

SparkSQL之Optimized LogicalPlan生成过程

经过Analyzer的处理&#xff0c;Unresolved LogicalPlan已经解析成为Analyzed LogicalPlan。Analyzed LogicalPlan中自底向上节点分别对应Relation、Subquery、Filter和Project算子。   Analyzed LogicalPlan基本上是根据Unresolved LogicalPlan一对一转换过来的&#xff0c;…

针对哈希冲突的解决方法

了解哈希表和哈希冲突是什么 哈希表&#xff1a;是一种实现关联数组抽象数据类型的数据结构&#xff0c;这种结构可以将关键码映射到给定值。简单来说哈希表&#xff08;key-value&#xff09;之间存在一个映射关系&#xff0c;是键值对的关系&#xff0c;一个键对应一个值。 …

foobar2000 突然无法正常输出DSD信号

之前一直在用foobar2000加外置dac听音乐&#xff0c;有一天突然发现听dsd的时候&#xff0c;dac面板显示输出的是PCM格式信号&#xff0c;而不是DSD信号&#xff0c;这让我觉得很奇怪&#xff0c;反复折腾了几次&#xff0c;卸载安装驱动什么的&#xff0c;依然如此&#xff0c…

java协同过滤算法 springboot+vue游戏推荐系统

随着人们生活质量的不断提高以及个人电脑和网络的普及&#xff0c;人们的业余生活质量要求也在不断提高&#xff0c;选择一款好玩&#xff0c;精美&#xff0c;画面和音质&#xff0c;品质优良的休闲游戏已经成为一种流行的休闲方式。可以说在人们的日常生活中&#xff0c;除了…

k8s集群资源监控工具metrics-server安装

1、下载镜像 docker pull swr.cn-east-2.myhuaweicloud.com/kuboard-dependency/metrics-server:v0.6.22、在任一一个主节点上创建角色&#xff0c;执行下面语句 kubectl create clusterrolebinding kube-proxy-cluster-admin --clusterrolecluster-admin --usersystem:kube-…

Int8量化算子在移动端CPU的性能优化

本文介绍了Depthwise Convolution 的Int8算子在移动端CPU上的性能优化方案。ARM架构的升级和相应指令集的更新不断提高移动端各算子的性能上限&#xff0c;结合数据重排和Sdot指令能给DepthwiseConv量化算子的性能带来较大提升。 背景 MNN对ConvolutionDepthwise Int8量化算子在…

Shell脚本:Linux Shell脚本学习指南(第三部分Shell高级)二

七、Shell Here String&#xff08;内嵌字符串&#xff0c;嵌入式字符串&#xff09; Here String 是《六、Shell Here Document&#xff08;内嵌文档/立即文档&#xff09;》的一个变种&#xff0c;它的用法如下&#xff1a; command <<< string command 是 Shell 命…

JavaScript如何实现钟表效果,时分秒针指向当前时间,并显示当前年月日,及2024春节倒计时,源码奉上

本篇有运用jQuery&#xff0c;记得引入jQuery库&#xff0c;否则不会执行的喔~ <!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title></title> <meta name"chenc" content"Runoob"> <met…

element-ui表格无法横向拖动问题

是不是用到了fixed // 因为我只有在小屏显示不下的时候才会出现这个问题所以我在这里做了适配(建议把样式放在全局) media screen and (max-width: 1800px) {// 由于使用了fixed导致横向条无法拖动出现bug.Table-page .el-table__fixed {height: auto !important;bottom: 2px …

计算机编程基础教程,中文编程工具下载,编程构件组合按钮

计算机编程基础教程&#xff0c;中文编程工具下载&#xff0c;编程构件组合按钮 给大家分享一款中文编程工具&#xff0c;零基础轻松学编程&#xff0c;不需英语基础&#xff0c;编程工具可下载。 这款工具不但可以连接部分硬件&#xff0c;而且可以开发大型的软件&#xff0c…

开卷翻到毒蘑菇?浅谈大模型检索增强(RAG)的鲁棒性

©PaperWeekly 原创 作者 | 陈思硕 单位 | 北京大学 研究方向 | 自然语言处理 很久没有写论文 notes 了&#xff0c;近期因为参与对检索增强生成&#xff08;Retrieval-Augmented Generation, RAG&#xff09;范式鲁棒性的研究&#xff0c;注意到了近两个月来社区中涌现了…

Java核心知识点整理大全15-笔记

Java核心知识点整理大全-笔记_希斯奎的博客-CSDN博客 Java核心知识点整理大全2-笔记_希斯奎的博客-CSDN博客 Java核心知识点整理大全3-笔记_希斯奎的博客-CSDN博客 Java核心知识点整理大全4-笔记-CSDN博客 Java核心知识点整理大全5-笔记-CSDN博客 Java核心知识点整理大全6…

【Kotlin】类与接口

文章目录 类的定义创建类的实例构造函数主构造函数次构造函数init语句块 数据类的定义数据类定义了componentN方法 继承AnyAny&#xff1a;非空类型的根类型Any?&#xff1a;所有类型的根类型 覆盖方法覆盖属性覆盖 抽象类接口:使用interface关键字函数&#xff1a;funUnit:让…

RocketMq 队列(MessageQueue)

RocketMq是阿里出品&#xff08;基于MetaQ&#xff09;的开源中间件&#xff0c;已捐赠给Apache基金会并成为Apache的顶级项目。基于java语言实现&#xff0c;十万级数据吞吐量&#xff0c;ms级处理速度&#xff0c;分布式架构&#xff0c;功能强大&#xff0c;扩展性强。 官方…

Kerberos 高可用配置和验证

参考 https://cloud.tencent.com/developer/article/1078314 https://mp.weixin.qq.com/s?__bizMzI4OTY3MTUyNg&mid2247485861&idx1&snbb930a497f63ac5e63ed20c64643eec5 机器准备 Kerberos主 ip-172-31-22-86.ap-southeast-1.compute.internal 7.common2.hado…

【华为数通HCIP | 网络工程师】821-IGP高频题、易错题之OSPF(7)

个人名片&#xff1a; &#x1f43c;作者简介&#xff1a;一名大三在校生&#xff0c;喜欢AI编程&#x1f38b; &#x1f43b;‍❄️个人主页&#x1f947;&#xff1a;落798. &#x1f43c;个人WeChat&#xff1a;hmmwx53 &#x1f54a;️系列专栏&#xff1a;&#x1f5bc;️…

C语言盐水的故事(ZZULIOJ1214:盐水的故事)

题目描述 挂盐水的时候&#xff0c;如果滴起来有规律&#xff0c;先是滴一滴&#xff0c;停一下&#xff1b;然后滴二滴&#xff0c;停一 下&#xff1b;再滴三滴&#xff0c;停一下...&#xff0c;现在有一个问题&#xff1a;这瓶盐水一共有VUL毫升&#xff0c;每一滴是D毫升&…

【brpc学习实践八】bvar及其应用

什么是bvar bvar是多线程环境下的计数器类库&#xff0c;支持单维度bvar和多维度mbvar&#xff0c;方便记录和查看用户程序中的各类数值&#xff0c;它利用了thread local存储减少了cache bouncing&#xff0c;相比UbMonitor(百度内的老计数器库)几乎不会给程序增加性能开销&a…