机器学习算法实战案例:CNN-LSTM实现多变量多步光伏预测

文章目录

      • 1 数据处理
        • 1.1 导入库文件
        • 1.2 导入数据集
        • 1.3 缺失值分析
      • 2 构造训练数据
      • ​3 模型训练
        • 3.1 CNN-LSTM网络
        • 3.2 模型训练
      • 4 模型预测
      • 答疑&技术交流
      • 机器学习算法实战案例系列

1 数据处理

1.1 导入库文件
from matplotlib import pyplot as pltimport tensorflow as tf    from tensorflow import keras from tensorflow.keras import Sequential, layers, callbacksfrom tensorflow.keras.layers import Input, Reshape,Conv2D, MaxPooling2D, LSTM, Dense, Dropout, Flatten, Reshapefrom sklearn.preprocessing import MinMaxScalerfrom sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error, mean_absolute_percentage_error warnings.filterwarnings('ignore')  
plt.rcParams['font.sans-serif'] = ['SimHei']     # 显示中文plt.rcParams['axes.unicode_minus'] = False  # 显示负号plt.rcParams.update({'font.size':18})  #统一字体字号
1.2 导入数据集

实验数据集采用数据集6:澳大利亚电力负荷与价格预测数据,数据集包括包括数据集包括日期、小时、干球温度、露点温度、湿球温度、湿度、电价、电力负荷特征,时间间隔30min。选取两年的数据进行实验,对数据进行可视化:

from itertools import cycledef visualize_data(data, row, col):cycol = cycle('bgrcmk')cols = list(data.columns)fig, axes = plt.subplots(row, col, figsize=(16, 4))if row == 1 and col == 1:  # 处理只有1行1列的情况axes = [axes]  # 转换为列表,方便统一处理for i, ax in enumerate(axes.flat):if i < len(cols):ax.plot(data.iloc[:,i], c=next(cycol))ax.set_title(cols[i])ax.axis('off')  # 如果数据列数小于子图数量,关闭多余的子图plt.subplots_adjust(hspace=0.6)visualize_data(data_raw.iloc[:,2:], 2, 3)

​​

​单独查看部分功率数据,发现有较强的规律性。

​​

1.3 缺失值分析

首先查看数据的信息,发现并没有缺失值

​​进一步统计缺失值,通过统计数据可以看到,数据比较完整,不存在缺失值。其他异常值和数据处理。

data_raw.isnull().sum()

2 构造训练数据

选取数据集,去掉时间特征

data = data_raw.iloc[:,2:].values

构造训练数据,也是真正预测未来的关键。首先设置预测的timesteps时间步、predict_steps预测的步长(预测的步长应该比总的预测步长小),length总的预测步长,参数可以根据需要更改。

通过前5天的timesteps个数据预测后一天的predict_steps个数据,需要对数据集进行滚动划分(也就是前timesteps行的特征和后predict_steps行的标签训练,后面预测时就可通过timesteps行特征预测未来的predict_steps个标签)。因为是多变量,特征和标签分开划分。

​​数据处理前,需要对数据进行归一化,按照上面的方法划分数据,这里返回划分的数据和归一化模型,因为是多变量,特征和标签分开归一化,不然后面归一化会有信息泄露的问题。函数的定义如下:

def data_scaler(datax, datay=None, timesteps=36, predict_steps=6):scaler1 = MinMaxScaler(feature_range=(0, 1))   datax = scaler1.fit_transform(datax)# 用前面的数据进行训练,留最后的数据进行预测if datay is not None:scaler2 = MinMaxScaler(feature_range=(0, 1))datay = scaler2.fit_transform(datay)trainx, trainy = create_dataset(datax, datay, timesteps, predict_steps)trainx = np.array(trainx)trainy = np.array(trainy)return trainx, trainy, scaler1, scaler2trainx, trainy = create_dataset(datax, timesteps=timesteps, predict_size=predict_steps)trainx = np.array(trainx)trainy = np.array(trainy)return trainx, trainy, scaler1, None

然后对数据按照上面的函数进行划分和归一化。通过前5天的96*5数据预测后一天的96个数据,需要对数据集进行滚动划分(也就是前96*5行的特征和后96行的标签训练,后面预测时就可通过96*5行特征预测未来的96个标签)

​3 模型训练

3.1 CNN-LSTM网络

CNN-LSTM 是一种结合了 CNN 特征提取能力与 LSTM 对时间序列长期记忆能力的混合神经网络。

CNN 主要由四个层级组成, 分别为输入层、 卷积层、 激活层(Relu 函数)和池化层。 每一层都会将数据处理之后送到下一层, 其中最重要的是卷积层, 这个层级起到的作用是将特征数据进行卷积计算, 将计算好的结果传到激活层, 激活函数对数据进行筛选。最后一层是 LSTM 层, 这一层是根据 CNN 处理后的特征数据,对其模型进行进一步的维度修偏, 权重修正等工作, 为下一步输出精度较高的预测值做好准备, 在 LSTM 训练的过程中, 由于其神经网络内部包括了输入、 遗忘和输出门, 通常的做法是通过增减遗忘门和输入门的个数, 来控制算法的精度。

来源:基于改进的 CNN-LSTM 短期风功率预测方法研究

对于输入到 CNN-LSTM 的数据,首先,经过 CNN 的卷积层对局部特征进行提取,将提取后的特征向量传递到池化层进行特征向量的下采样和数据体量的压缩。然后,将经过卷积层和池化层处理后的特征向量经过一个扁平层转化成一维向量输入到 LSTM 中, 每一层 LSTM 后加一个随机失活层以防止模型过拟合。

3.2 模型训练

首先搭建模型的常规操作,然后使用训练数据trainx和trainy进行训练,进行50个epochs的训练,每个batch包含64个样本。此时input_shape划分数据集时每个x的形状。(建议使用GPU进行训练,因为本人电脑性能有限,建议增加epochs值)

def CNN_LSTM_model_train(trainx, trainy, timesteps, feature_num, predict_steps):gpus = tf.config.experimental.list_physical_devices(device_type='GPU')tf.config.experimental.set_memory_growth(gpu, True)start_time = datetime.datetime.now()model.add(Input((timesteps, feature_num)))model.add(Conv2D(filters=64,kernel_size=3,strides=1,padding="same",activation="relu"))model.add(MaxPooling2D(pool_size=2, strides=1, padding="same"))model.add(Dropout(0.3))model.add(Reshape((timesteps, -1)))model.add(LSTM(128, return_sequences=True, dropout=0.2))  # 添加dropout层model.add(LSTM(64, return_sequences=False, dropout=0.2))  # 添加dropout层model.add(Dense(64, activation="relu"))  # 增加Dense层节点数量model.add(Dense(predict_steps))model.fit(trainx, trainy, epochs=50, batch_size=128)end_time = datetime.datetime.now()running_time = end_time - start_timemodel.save('CNN_LSTM_model.h5')

对划分的数据进行训练

model = CNN_LSTM_model_train(trainx, trainy, timesteps, feature_num, predict_steps)

4 模型预测

首先加载训练好后的模型

from tensorflow.keras.models import load_modelmodel = load_model('BiLSTM_model.h5')

准备好需要预测的数据,训练时保留了6天的数据,将前5天的数据作为输入预测,将预测的结果和最后一天的真实值进行比较。

y_true = datay[-timesteps-predict_steps:-timesteps]x_pred = datax[-timesteps:]

预测并计算误差,并进行可视化,将这些步骤封装为函数。​​​​​​​

def predict_and_plot(x, y_true, model, scaler, timesteps):predict_x = np.reshape(x, (1, timesteps, feature_num))  predict_y = model.predict(predict_x)predict_y = scaler.inverse_transform(predict_y)y_predict.extend(predict_y[0])r2 = r2_score(y_true, y_predict)rmse = mean_squared_error(y_true, y_predict, squared=False)mae = mean_absolute_error(y_true, y_predict)mape = mean_absolute_percentage_error(y_true, y_predict)print("r2: %.2f\nrmse: %.2f\nmae: %.2f\nmape: %.2f" % (r2, rmse, mae, mape))cycol = cycle('bgrcmk')plt.figure(dpi=100, figsize=(14, 5))plt.plot(y_true, c=next(cycol), markevery=5)plt.plot(y_predict, c=next(cycol), markevery=5)plt.legend(['y_true', 'y_predict'])
y_predict = predict_and_plot(x_pred, y_true, model, scaler2, timesteps)

最后得到可视化结果和计算的误差,可以通过调参和数据处理进一步提升模型预测效果。

  • r2: 0.19
  • ​​rmse: 725.34
  • mae: 640.73
  • mape: 0.08

答疑&技术交流

技术要学会分享、交流,不建议闭门造车。一个人可以走的很快、一堆人可以走的更远。

本文完整代码、相关资料、技术交流&答疑,均可加我们的交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

​方式①、微信搜索公众号:Python学习与数据挖掘,后台回复:加群
方式②、添加微信号:dkl88194,备注:来自CSDN + 技术交流

机器学习算法实战案例系列

  • 机器学习算法实战案例:确实可以封神了,时间序列预测算法最全总结!

  • 机器学习算法实战案例:时间序列数据最全的预处理方法总结

  • 机器学习算法实战案例:GRU 实现多变量多步光伏预测

  • 机器学习算法实战案例:LSTM实现单变量滚动风电预测

  • 机器学习算法实战案例:LSTM实现多变量多步负荷预测

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/626608.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PHP+MySQL组合开发:微信小程序万能建站源码系统 附带完整的搭建教程

随着移动互联网的快速发展&#xff0c;微信小程序已成为企业进行移动营销的重要工具。然而&#xff0c;对于许多中小企业和个人开发者来说&#xff0c;开发一个功能完善、用户体验良好的小程序是一项复杂的任务。罗峰给大家分享一款微信小程序万能建站源码系统。该系统采用PHPM…

CMMI3.0认证的卓越方案!

CMMI3.0是软件工程和组织发展领域中的一项重要认证&#xff0c;它旨在提升组织的绩效和成熟度&#xff0c;促进卓越的软件开发和管理实践。本文将探讨CMMI3.0认证的意义、要求以及实施过程&#xff0c;并介绍一些卓越方案&#xff0c;帮助组织达到该认证。 CMMI3.0认证的意义 …

线控底盘新玩家凶猛!这家企业的ONE-BOX产品正式量产下线

高工智能汽车获悉&#xff0c;12月27日&#xff0c;威肯西科技宣布旗下ONE-BOX线控制动产品--液压解耦制动系统HDBS实现量产下线。该产品将与多个汽车品牌签署量产及定点协议&#xff0c;预计年产量达到60万套。 据了解&#xff0c;作为耀宁科技集团的一级子公司&#xff0c;威…

【正点原子】STM32电机应用控制学习笔记——8.FOC简介

FOC是适用于无刷电机的&#xff0c;而像有刷电机&#xff0c;舵机&#xff0c;步进电机是不适用FOC的。FOC是电机应用控制难度最大的部分了。 一.FOC简介&#xff08;了解&#xff09; 1.介绍 FOC&#xff08;Filed Oriented Control&#xff09;即磁场定向控制&#xff0c;…

rust获取本地ip地址的方法

大家好&#xff0c;我是get_local_info作者带剑书生&#xff0c;这里用一篇文章讲解get_local_info的使用。 get_local_info是什么&#xff1f; get_local_info是一个获取linux系统信息的rust三方库&#xff0c;并提供一些常用功能&#xff0c;目前版本0.2.4。详细介绍地址&a…

【问题记录】使用命令语句从kaggle中下载数据集

从Kaggle中下载Tusimple数据集 1.服务器环境中安装kaggle 使用命令&#xff1a;pip install kaggle 2.复制下载API 具体命令如下&#xff1a; kaggle datasets download -d manideep1108/tusimple3.配置kaggle.json文件 如果直接使用命令会报错&#xff1a; root:~# kagg…

力扣hot100 二叉树中的最大路径和 递归

Problem: 124. 二叉树中的最大路径和 文章目录 解题方法复杂度&#x1f496; Code 解题方法 &#x1f468;‍&#x1f3eb; 参考思路 复杂度 时间复杂度: O ( n ) O(n) O(n) 空间复杂度: O ( n ) O(n) O(n) &#x1f496; Code /*** Definition for a binary tree no…

云计算概述(发展过程、定义、发展阶段、云计算榜单)(一)

云计算概述&#xff08;一&#xff09; &#xff08;发展过程、定义、发展阶段、云计算榜单&#xff09; 本文目录&#xff1a; 零、00时光宝盒 一、前言 二、云计算的发展过程 三、云计算的定义 四、云计算发展阶段 五、云计算公司榜单看云计算兴衰 六、参考资料 零、0…

【Shell编程练习】编写脚本测试 192.168.4.0/24 整个网段中哪些主机处于开机状态,哪些主机处于关机状态

系列文章目录 输出Hello World 通过位置变量创建 Linux 系统账户及密码 监控内存和磁盘容量&#xff0c;小于给定值时报警 猜大小 输入三个数并进行升序排序 系列文章目录编写脚本测试 192.168.4.0/24 整个网段中哪些主机处于开机状态,哪些主机处于关机状态 编写脚本测试 192.…

大功率直流电子负载

大功率直流电子负载专门用于测试和模拟电源设备的设备&#xff0c;它可以模拟实际的负载情况&#xff0c;对电源设备进行各种性能参数的测试。这种设备在电源设备的研发、生产和质量控制中起着重要的作用。 大功率直流电子负载的主要特点有&#xff1a; 高功率&#xff1a;大功…

中科院自动化所:基于关系图深度强化学习的机器人多目标包围问题新算法

摘要&#xff1a;中科院自动化所蒲志强教授团队&#xff0c;提出一种基于关系图的深度强化学习方法&#xff0c;应用于多目标避碰包围(MECA)问题&#xff0c;使用NOKOV度量动作捕捉系统获取多机器人位置信息&#xff0c;验证了方法的有效性和适应性。研究成果在2022年ICRA大会发…

生鲜超市网站系统源码自营商城生鲜水果商城PC手机微信完整版

系统主要功能&#xff1a;商品管理、会员管理、订单管理、电子券管理、财务管理、门店管理等 后台管理&#xff1a;http://fresh.oostar.cn/admin 演示管理员登陆账号:yanshi 演示管理员登陆密码:yanshi888 pc前端站点&#xff1a;http://fresh.oostar.cn 移动端站点&…

淘宝搜索引擎API接口关键字搜索商品列表获取商品详情价格评论销量API

item_search-按关键字搜索淘宝商品 公共参数 查看API完整文档 名称类型必须描述keyString是调用key&#xff08;必须以GET方式拼接在URL中&#xff09;secretString是调用密钥api_nameString是API接口名称&#xff08;包括在请求地址中&#xff09;[item_search,item_get,it…

Mac安装MySQL

环境 电脑: macOS Monterey 12.7.2 MacBook Pro( Retina, 13-inch, Early 2015) 处理器: 2.7GHz 双核 Inter Core i5 MySQL 的安装版本: 8.2.0 最近有更新系统, 重新配置了电脑, 因此, 之前安装的 MySQL 也都删除了, 这次安装经历有点坎坷, 记录下来, 希望可以帮助到需要的小伙…

1.12号网络

1 网络发展历史 1.1 APRAnet阶段 阿帕网&#xff0c;是Interne的最早雏形 不能互联不同类型的计算机和不同类型的操作系统 没有纠错功能 1.2 TCP/IP两个协议阶段 什么是协议 在计算机网络中&#xff0c;要做到有条不紊的交换数据&#xff0c;需要遵循一些事先约定好的规则…

Transformer详解(附代码实现及翻译任务实现)

一&#xff1a;了解背景和动机 阅读Transformer论文&#xff1a; 阅读原始的Transformer论文&#xff1a;“Attention is All You Need”&#xff0c;由Vaswani等人于2017年提出&#xff0c;是Transformer模型的开创性工作。 二&#xff1a;理解基本构建块 注意力机制&#…

Vue-20、Vue监测数组改变

1、数组调用以下方法Vue可以监测到。 arr.push(); 向数组的末尾追加元素 const array [1,2,3] const result array.push(4) // array [1,2,3,4] // result 4arr.pop(); 删除末尾的元素 const array [a, b] array.pop() // b array.pop() // a array.pop() // undefi…

GaussDB数据库中的MERGE INTO介绍

一、前言 二、GaussDB MERGE INTO 语句的原理概述 1、MERGE INTO 语句原理 2、MERGE INTO 的语法 3、语法解释 三、GaussDB MERGE INTO 语句的应用场景 四、GaussDB MERGE INTO 语句的示例 1、示例场景举例 2、示例实现过程 1&#xff09;创建两个实验表&#xff0c;并…

宝宝洗衣机买几公斤?婴儿专用洗衣机测评

由于幼龄时期的宝宝的皮肤比较娇嫩&#xff0c;很容易受到伤害。所以小宝宝的衣服一般都是棉质的&#xff0c;很柔软&#xff0c;很亲肤的&#xff0c;为的就是保护宝贝们娇嫩的肌肤。而宝宝们在日常中更换衣物会相对频繁&#xff0c;换的衣物也必须及时清洗晾晒&#xff0c;以…

网络文件共享服务 FTP

一、存储类型 存储类型分为三种 直连式存储&#xff1a;Direct-Attached Storage&#xff0c;简称DAS 存储区域网络&#xff1a;Storage Area Network&#xff0c;简称SAN&#xff08;可以使用空间&#xff0c;管理也是你来管理&#xff09; 网络附加存储&#xff1a;Network…