【Week-R2】使用LSTM实现火灾预测(tf版本)

【Week-R2】使用LSTM实现火灾预测(tf版本)

  • 一、 前期准备
    • 1.1 设置GPU
    • 1.2 导入数据
    • 1.3 数据可视化
  • 二、数据预处理(构建数据集)
    • 2.1 设置x、y
    • 2.2 归一化
    • 2.3 划分数据集
  • 三、模型创建、编译、训练、得到训练结果
    • 3.1 构建模型
    • 3.2 编译模型
    • 3.3 训练模型
    • 3.4 模型评估
      • 3.4.1 Loss与Accuracy图
      • 3.4.2 调用模型进行预测
      • 3.4.3 查看误差
  • 四、其他
    • 4.1 模块报错:seaborn模块导入错误
    • 4.2 图片实时显示比例不对
    • 4.3 什么是LSTM

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

在这里插入图片描述

一、 前期准备

语言环境:Python3.7.8
编译器选择:VSCode
深度学习环境:TensorFlow
数据集:本地数据集

1.1 设置GPU

本文使用CPU环境

'''
LSTM-实现火灾预测
'''import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0,true)tf.config.set_visible_devices([gpu0],"GPU")print("GPU: ",gpus)
else:print("CPU:")

输出:
在这里插入图片描述

1.2 导入数据

下载数据集文件woodpine2.csv到本地,使用绝对路径进行访问:

# 2.1 导入数据
import  pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as snsdf = pd.read_csv("D:\\jupyter notebook\\DL-100-days\\RNN\\woodpine2.csv")
print("df:", df)

输出:
在这里插入图片描述

1.3 数据可视化

# 3.数据可视化 
plt.rcParams['savefig.dpi'] = 500
plt.rcParams['figure.dpi'] = 500fig,ax = plt.subplots(1,3,constrained_layout = True , figsize = (14,3))
sns.lineplot(data=df["Tem1"],ax=ax[0])
sns.lineplot(data=df["CO 1"],ax=ax[1])
sns.lineplot(data=df["Soot 1"],ax=ax[2])
plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\3.数据可视化.png")
#plt.show()

输出:
在这里插入图片描述

二、数据预处理(构建数据集)

# 二、数据预处理(构建数据集)
dataFrame = df.iloc[:,1:]
print("dataFrame:", dataFrame)

输出:
在这里插入图片描述

2.1 设置x、y

# 1,设置x、y
# 需要实现:使用1-8时刻段预测9时刻段,则通过下述代码做好长度的确定:
width_X = 8
width_y = 1X = []
y = []in_start = 0for _,_ in df.iterrows():in_end = in_start + width_Xout_end = in_end + width_yif out_end < len(dataFrame):X_ = np.array(dataFrame.iloc[in_start:in_end,])X_ = X_.reshape((len(X_)*3))y_ = np.array(dataFrame.iloc[in_end:out_end,0])X.append(X_)y.append(y_)in_start += 1X = np.array(X)
y = np.array(y)print(X.shape,y.shape)

输出:
在这里插入图片描述

2.2 归一化

# 2,归一化
from sklearn.preprocessing import MinMaxScalersc = MinMaxScaler(feature_range=(0,1))
X_scaled = sc.fit_transform(X)
print(X_scaled.shape)X_scaled = X_scaled.reshape(len(X_scaled),width_X,3)
print(X_scaled.shape)

输出:
在这里插入图片描述

2.3 划分数据集

取5000之前的数据作为训练集,5000之后的数据作为验证集:

# 3,划分数据集
# 取5000之前的数据作为训练集,5000之后的数据作为验证集:
X_train = np.array(X_scaled[:5000]).astype('float64')
y_train = np.array(y[:5000]).astype('float64')X_test = np.array(X_scaled[5000:]).astype('float64')
y_test = np.array(y[5000:]).astype('float64')print(X_train.shape)

输出:
在这里插入图片描述

三、模型创建、编译、训练、得到训练结果

3.1 构建模型

通过下面代码,构建一个包含两个LSTM层和一个全连接层的LSTM模型。这个模型将接受形状为(X_train.shape[1], 3)的输入,其中X_train.shape[1]是时间步数,3 是每个时间步的特征数。

# 三、构建模型
import keras
from keras.models import Sequential
from keras.layers import Dense,LSTMmodel_lstm = Sequential()
model_lstm.add(LSTM(units=64,activation='relu',return_sequences=True,input_shape=(X_train.shape[1],3)))
model_lstm.add(LSTM(units=64,activation='relu'))
model_lstm.add(Dense(width_y))
# 通过上述代码,构建了一个包含两个LSTM层和一个全连接层的LSTM模型。这个模型将接受形状为 (X_train.shape[1], 3) 的输入,其中 X_train.shape[1] 是时间步数,3 是每个时间步的特征数。

输出:
在这里插入图片描述

3.2 编译模型

# 四、 编译模型 
model_lstm.compile(loss='mean_squared_error',optimizer=tf.keras.optimizers.Adam(1e-3))

3.3 训练模型

history = model_lstm.fit(X_train,y_train,epochs = 40,batch_size = 64,validation_data=(X_test,y_test),validation_freq= 1)

训练输出:
在这里插入图片描述

3.4 模型评估

3.4.1 Loss与Accuracy图

# 六、 模型评估
# 1.Loss与Accuracy图
import matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = Falseplt.figure(figsize=(5, 3),dpi=120)plt.plot(history.history['loss'],label = 'LSTM Training Loss')
plt.plot(history.history['val_loss'],label = 'LSTM Validation Loss')plt.title('Training and Validation Accuracy')
plt.legend()
plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\1.Loss与Accuracy图.png")
plt.show()

输出:
在这里插入图片描述

3.4.2 调用模型进行预测

# 2.调用模型进行预测
predicted_y_lstm = model_lstm.predict(X_test)y_tset_one = [i[0] for i in y_test]
predicted_y_lstm_one = [i[0] for i in predicted_y_lstm]plt.figure(figsize=(5,3),dpi=120)
plt.plot(y_tset_one[:1000],color = 'red', label = '真实值')
plt.plot(predicted_y_lstm_one[:1000],color = 'blue', label = '预测值')plt.title('Title')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\2.调用模型进行预测图.png")
plt.show()

输出:
在这里插入图片描述

3.4.3 查看误差


# 3. 查看误差
from  sklearn import metricsRMSE_lstm = metrics.mean_squared_error(predicted_y_lstm,y_test)**0.5
R2_lstm = metrics.r2_score(predicted_y_lstm,y_test)print('均方根误差:%.5f' % RMSE_lstm)
print('R2:%.5f' % R2_lstm)

输出:
在这里插入图片描述

四、其他

4.1 模块报错:seaborn模块导入错误

解决方法如下:
在这里插入图片描述
在这里插入图片描述

4.2 图片实时显示比例不对

改为保存到本地查看:【在plt.show()之前保存
line 34:

plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\3.数据可视化.png")
plt.show()

line 125:

plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\1.Loss与Accuracy图.png")
plt.show()

line 143:

plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\2.调用模型进行预测图.png")
plt.show()

输出:
在这里插入图片描述

4.3 什么是LSTM

LSTM是一种特殊的RNN,能到学习到长期的依赖关系,可以理解为升级版的RNN。
在这里插入图片描述
传统的RNN在处理长序列时存在着“梯度爆炸(/梯度消失)”和“短时记忆”的问题,向RNN中加入遗忘门、输入门及输出门使得困扰RNN的问题得到了一定的解决;
在这里插入图片描述
关于LSTM的实现流程:(1、单输出时间步)单输入单输出、多输入单输出、多输入多输出(2、多输出时间步)单输入单输出、多输入单输出、多输入多输出;
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/24127.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

超详细的java Comparable,Comparator接口解析

前言 Hello大家好呀&#xff0c;在java中我们常常涉及到对象的比较&#xff0c;不同于基本数据类型&#xff0c;对于我们的自定义对象&#xff0c;需要我们自己去建立比较标准&#xff0c;例如我们自定义一个People类&#xff0c;这个类有name和age两个属性&#xff0c;那么问…

[数据集][图像分类]蘑菇分类数据集3122张215类别

数据集类型&#xff1a;图像分类用&#xff0c;不可用于目标检测无标注文件 数据集格式&#xff1a;仅仅包含jpg图片&#xff0c;每个类别文件夹下面存放着对应图片 图片数量(jpg文件个数)&#xff1a;3122 分类类别数&#xff1a;215 类别名称:[“almond_mushroom”,“amanita…

实验笔记之——DPVO(Deep Patch Visual Odometry)

本博文记录本文测试DPVO的过程&#xff0c;本博文仅供本人学习记录用~ 《Deep Patch Visual Odometry》 代码链接&#xff1a;GitHub - princeton-vl/DPVO: Deep Patch Visual Odometry 目录 配置过程 测试记录 参考资料 配置过程 首先下载代码以及创建conda环境 git clo…

Data Management Controls

Data Browsing and Analysis Data Grid 以标准表格或其他视图格式&#xff08;例如&#xff0c;带状网格、卡片、瓷砖&#xff09;显示数据。Vertical Grid 以表格形式显示数据&#xff0c;数据字段显示为行&#xff0c;记录显示为列。Pivot Grid 模拟微软Excel的枢轴表功…

有待挖掘的金矿:大模型的幻觉之境

人工智能正在迅速变得无处不在&#xff0c;在科学和学术研究中&#xff0c;自回归的大型语言模型&#xff08;LLM&#xff09;走在了前列。自从LLM的概念被整合到自然语言处理&#xff08;NLP&#xff09;的讨论中以来&#xff0c;LLM中的幻觉现象一直被广泛视为一个显著的社会…

Oracle EBS AP发票创建会计科目提示:APP-SQLAP-10710:无法联机创建会计分录

系统版本 RDBMS : 12.1.0.2.0 Oracle Applications : 12.2.6 问题症状: 提交“创建会计科目”请求提示错误信息如下: APP-SQLAP-10710:无法联机创建会计分录。 请提交应付款管理系统会计流程,而不要为此事务处理创建会计分录解决方法 数据修复SQL脚本: UPDATE ap_invoi…

LabVIEW阀性能试验台测控系统

本项目开发的阀性能试验台测控系统是为满足国家和企业相关标准而设计的&#xff0c;主要用于汽车气压制动系统控制装置和调节装置等产品的综合性能测试。系统采用工控机控制&#xff0c;配置电器控制柜&#xff0c;实现运动控制、开关量控制及传感器信号采集&#xff0c;具备数…

vue封装一个查询URL参数方法

vue封装一个查询URL参数方法 在 Vue 中&#xff0c;你可以封装一个查询 URL 参数的方法来获取 URL 中的查询参数。以下是一个示例代码&#xff1a; export const getQueryParam (param) > {const urlParams new URLSearchParams(window.location.search);return urlPara…

算法-分治策略

概念 分治算法&#xff08;Divide and Conquer&#xff09;是一种解决问题的策略&#xff0c;它将一个问题分解成若干个规模较小的相同问题&#xff0c;然后递归地解决这些子问题&#xff0c;最后合并子问题的解得到原问题的解。分治算法的基本思想是将复杂问题分解成若干个较…

东方博宜1565 - 成绩(score)

问题描述 牛牛最近学习了 C 入门课程&#xff0c;这门课程的总成绩计算方法是&#xff1a; 总成绩作业成绩 20% 小测成绩 30% 期末考试成绩 50%。 牛牛想知道&#xff0c;这门课程自己最终能得到多少分。 输入 三个非负整数 A、B、C &#xff0c;分别表示牛牛的作业成绩、…

计算机网络 期末复习(谢希仁版本)第3章

对于点对点的链路&#xff0c;目前使用得最广泛的数据链路层协议是点对点协议 PPP (Point-to-Point Protocol)。局域网的传输媒体&#xff0c;包括有线传输媒体和无线传输媒体两个大类&#xff0c;那么有线传输媒体有同轴电缆、双绞线和光纤&#xff1b;无线传输媒体有微波、红…

计算引擎:Flink核心概念

Apache Flink 是一个流处理框架,擅长处理实时数据流和批处理任务。Flink 提供了强大的功能来处理和分析大量数据。以下是 Flink 的核心概念: 1. DataStream 和 DataSet API DataStream API: 用于处理无界数据流,即不断生成和流动的数据。例如,传感器数据、日志等。DataSet…

基于Texture2D 实现Unity 截屏功能

实现 截屏 Texture2D texture new Texture2D(Screen.width, Screen.height, TextureFormat.RGB24, false); texture.ReadPixels(new Rect(0, 0, Screen.width, Screen.height), 0, 0); texture.Apply(); 存储 byte[] array ImageConversion.EncodeToPNG(texture); if (!…

分享万能点击器免费版,吾爱大佬出品,这个太赞了!

小伙伴们&#xff01;阿星又来给大家推荐神奇的小软件啦&#xff01;这次的主角可是个神器——鼠标连点器&#xff01;你听过没&#xff1f;这玩意儿简直是个“自动小助手”&#xff0c;让你的鼠标在屏幕上飞舞&#xff0c;点得飞快&#xff0c;解放你的双手&#xff0c;让你网…

【ARM 常见汇编指令学习 6.2 -- ARMv8 汇编指令 SDIV 详细介绍】

文章目录 SDIV指令格式使用示例注意事项总结 SDIV ARMv8 架构中的 SDIV 指令用于执行带符号整数除法操作。这意味着它可以处理负数除法&#xff0c;与 UDIV&#xff08;执行无符号整数除法&#xff09;形成对比。SDIV 将两个寄存器中的带符号整数相除&#xff0c;将除法结果存…

react学习-组件传值

1.props传值 主要步骤&#xff1a; 在父组件中引用子组件时&#xff0c;在子组件上面写入name1{name2}格式进行传值&#xff0c;name1为子组件中对应的用于接收数据的字段名称&#xff0c;name2为父组件中需要传递到子组件中的值&#xff08;state中声明的数据&#xff09;&…

一篇文章带你搞懂C++引用(建议收藏)

引用 6.1 引用概念 引用不是新定义一个变量&#xff0c;而是给已存在变量取了一个别名&#xff0c;编译器不会为引用变量开辟内存空间&#xff0c;它和它引用的变量共用同一块内存空间。 比如&#xff1a;李逵&#xff0c;在家称为"铁牛"&#xff0c;江湖上人称&quo…

Linux.软件操作

1.yum 命令 要连网 2.systemctl 命令控制软件的启动和关闭 3.ln 创建软连接 使用cat来找本体&#xff0c;看看链接生不生效 4.date 命令查看系统时间 格式化的时候可以用双引号把他们引出来 -d 对时间进行修改 修改时区 自动校准 手动校准 5.ifconfig 查看本机的ip地址 6.h…

mysql undolog管理

在MySQL中&#xff0c;Undo Log&#xff08;撤销日志&#xff09;用于支持事务的回滚和MVCC&#xff08;多版本并发控制&#xff09;。为了避免Undo Log不断增长&#xff0c;影响系统性能&#xff0c;需要进行合理的清理。MySQL的Undo Log清理策略主要依赖于系统的配置参数和后…

Ansible——get_url模块

目录 主要用途 参数总结 基本语法示例 使用示例 示例1&#xff1a;下载文件 示例2&#xff1a;使用校验和验证文件 示例3&#xff1a;使用 HTTP 基本认证 示例4&#xff1a;通过代理服务器下载文件 示例5&#xff1a;设置文件权限、所有者和组 示例6&#xff1a;强制…