程序员学长 | 快速学会一个算法模型,LSTM

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:快速学会一个算法模型,LSTM

今天,给大家分享一个超强的算法模型,LSTM。

LSTM(Long Short-Term Memory)是一种特殊类型的循环神经网络(RNN),专门设计用来解决传统 RNN 在处理序列数据时面临的长期依赖问题

LSTM 的关键特征是其维持细胞状态的能力,细胞状态充当可以存储长序列信息的记忆单元。这使得 LSTM 能够随着时间的推移选择性地记住或忘记信息,使它们非常适合上下文和远程依赖性至关重要的任务。

LSTM 的核心组件

LSTM 单元由以下几个主要部分组成

案例分享

加载数据集
import numpy as np
import pandas as pd
from keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.metrics import RootMeanSquaredError
from tensorflow.keras.optimizers import Adam
from keras.layers import LSTM, Dense, InputLayer
from sklearn.metrics import mean_squared_error as mse
from time import time
import matplotlib.pyplot as plt
import matplotlib
import warningspath_data = r'filter_pt_data.csv'
df = pd.read_csv(path_data)
df.dropna(inplace=True)
df['dt'] = pd.to_datetime(df['dt'])
df.set_index('dt', inplace=True)
df['Seconds'] = df.index.map(pd.Timestamp.timestamp)
year_secs = 60 * 60 * 24 * 365  # Number of seconds in a year
df['year_signal_sin'] = np.sin(df['Seconds'] * (2 * np.pi / year_secs))
df['year_signal_cos'] = np.cos(df['Seconds'] * (2 * np.pi / year_secs))
df.drop(columns=['Seconds'], inplace=True)
准备数据序列

LSTM 模型是专门为处理数据点序列而设计的,因此需要将数据转换为这种格式。

该方法涉及将预测问题转换为监督学习范式。在此设置中,输入 (X) 包含前面的 n 个数据点,而输出 (y) 表示后续时间步的目标值。

为了说明这个概念,假设我们正在使用包含三个特征(“a”、“b”和“c”)的数据集。我们的目标是预测特征 “a”。在这种情况下,我们的输入序列将包含三个时间戳,这意味着我们将检查三个连续时间点的特征值。

def create_sequences_unistep(data, n_steps):data_t = data.to_numpy()X = []y = []for i in range(len(data_t)-n_steps):row = [a for a in data_t[i:i+n_steps]]X.append(row)label = data_t[i+n_steps][0]y.append(label)return np.array(X), np.array(y)
创建模型
def train_model(X, y, X_val, y_val, n_steps, n_preds=1):n_features = X.shape[2]# Create lstm modelmodel = Sequential()model.add(InputLayer((n_steps, n_features)))model.add(LSTM(4, return_sequences=True))model.add(LSTM(5))model.add(Dense(n_preds, activation='linear'))# Compile modelmodel.compile(loss=MeanSquaredError(), optimizer=Adam(learning_rate=0.0001), metrics=[RootMeanSquaredError()])model.summary()# Save model with the least validation losscheckpoint_filepath = 'cps/best_model.h5'model_checkpoint_callback = ModelCheckpoint(filepath=checkpoint_filepath,monitor='val_loss',  # Monitor validation lossmode='min',          # Save the model with the minimum validation losssave_best_only=True)# Stop training if validation loss does not improve in 500 epochsearly_stopping_callback = EarlyStopping(monitor='val_loss',patience=50,  # Stop training if no improvement in validation loss for 100 epochsmode='min',verbose=1,restore_best_weights=True) # when finish train restore best model# Fit modelts = time()history = model.fit(X, y,verbose=2,epochs=500,validation_data=(X_val, y_val),callbacks=[model_checkpoint_callback, early_stopping_callback])tf = time()print('Time to train model: {} s'.format(round(tf - ts, 2)))# Plot loss evolutionplt.figure()plt.plot(history.history['loss'], label='loss')plt.plot(history.history['val_loss'], label='val_loss')plt.title('Training and Validation Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()plt.show()# Load best modeldel modelmodel = load_model(checkpoint_filepath)return model
模型训练

首先,让我们使用之前实现的函数生成序列。我们将分配 500 个值用于训练,50 个值用于验证,并将 “n_steps” 参数设置为 5。

def preprocess_input(X, mean, std):X[:, :, 0] = (X[:, :, 0] - mean) / stdreturn Xdef preprocess_output(y, mean, std):y = (y - mean) / stdreturn ydef postprocess_output(y, mean, std):y = (y * std) + meanreturn ydef plot_predictions_unistep(model, X_test, y_test, mean_ref, std_ref):preds = model.predict(X_test).flatten().tolist()# preprocess preds to actual scalepreds = [postprocess_output(i, mean_ref, std_ref) for i in preds]y_t = [postprocess_output(i, mean_ref, std_ref) for i in y_test.tolist()]er = mse(y_test, preds)plt.figure(figsize=(12, 8))plt.plot(y_t, label='Actual values')plt.plot(preds, label='Predictions', alpha=.7)plt.legend()plt.title('MSE = {}'.format(er))return predsn_steps = 5
X, y = create_sequences_unistep(df, n_steps)# Prepare train and validation data
nr_vals_train = 500
nr_vals_validation = 50X_train = X[:nr_vals_train]
y_train = y[:nr_vals_train]X_val = X[nr_vals_train: nr_vals_train + nr_vals_validation]
y_val = y[nr_vals_train: nr_vals_train + nr_vals_validation]X_test = X[nr_vals_train:]
y_test = y[nr_vals_train:]print('X train shape: {}'.format(X_train.shape))
print('y train shape: {}'.format(y_train.shape))print('X validation shape: {}'.format(X_val.shape))
print('y validation shape: {}'.format(y_val.shape))# Scale temp value with standard scaler -> mean 0 and std 1
mean_ref = np.mean(X_train[:, :, 0])
std_ref = np.std(X_train[:, :, 0])
# Scale X's
X_train = preprocess_input(X_train, mean_ref, std_ref)
X_val = preprocess_input(X_val, mean_ref, std_ref)
X_test = preprocess_input(X_test, mean_ref, std_ref)# Scale y's
y_train = preprocess_output(y_train, mean_ref, std_ref)
y_val = preprocess_output(y_val, mean_ref, std_ref)
y_test = preprocess_output(y_test, mean_ref, std_ref)model = train_model(X_train, y_train, X_val, y_val, n_steps)# Plot train predictions set
plot_predictions_unistep(model, X_train, y_train, mean_ref, std_ref)

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/37590.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

落石滑坡监测报警系统:创新保障高速公路安全

​ ​​在现代交通建设中,高速公路的安全性和稳定性至关重要。特别是易发生落石区域,如何有效预防和应对落石滑坡带来的事故成为了一项关键性挑战。为此,落石滑坡监测报警系统应运而生,它通过先进的技术手段,为高速…

Coursera耶鲁大学金融课程:Financial Markets 笔记Week 03

Financial Markets 本文是学习 https://www.coursera.org/learn/financial-markets-global这门课的学习笔记 这门课的老师是耶鲁大学的Robert Shiller https://en.wikipedia.org/wiki/Robert_J._Shiller Robert James Shiller (born March 29, 1946)[4] is an American econom…

CMake(1)基础使用

CMake之(1)基础使用 Author: Once Day Date: 2024年6月29日 一位热衷于Linux学习和开发的菜鸟,试图谱写一场冒险之旅,也许终点只是一场白日梦… 漫漫长路,有人对你微笑过嘛… 全系列文章可参考专栏: Linux实践记录_Once-Day的博客-CSDN博客…

C++中的三大池:线程池,内存池,数据库连接池

C中有三大池,即我们常说的:线程池,内存池,数据库连接池。 一.线程池 多线程同时访问共享资源造成数据混乱的原因就是因为CPU的上下文切换导致,线程池就是为了解决此问题而生。 多线程常用的有:std::threa…

【工具推荐】Nuclei

文章目录 NucleiLinux安装方式Kali安装Windows安装 Nuclei Nuclei 是一款注重于可配置性、可扩展性和易用性的基于模板的快速漏洞验证工具。它使用 Go 语言开发,具有强大的可配置性、可扩展性,并且易于使用。Nuclei 的核心是利用模板(表示为简…

使用Jetpack Compose实现具有多选功能的图片网格

使用Jetpack Compose实现具有多选功能的图片网格 在现代应用中,多选功能是一项常见且重要的需求。例如,Google Photos允许用户轻松选择多个照片进行分享、添加到相册或删除。在本文中,我们将展示如何使用Jetpack Compose实现类似的多选行为,最终效果如下: 主要步骤 实现…

查看Windows启动时长

(附图片)电脑自带检测开机时长---查看方式_电脑开机时长命令-CSDN博客 eventvwr - Windows日志 - 系统 - 查找 - 6013.jpg

如何利用ChatGPT改善日常生活:一个普通人的指南

当你打开 ChatGPT,显现的是一个简洁的聊天界面。 许多人利用 ChatGPT 进行日常对话。 然而,ChatGPT 的功能远不止于此。 对话只是其众多能力中的一种,如果仅将其视为高级版的聊天机器人,那未免低估了它。 AI 在信息处理方面的…

VMware ESXi 8.0U3 macOS Unlocker OEM BIOS 集成驱动版,新增 12 款 I219 网卡驱动

VMware ESXi 8.0U3 macOS Unlocker & OEM BIOS 集成驱动版,新增 12 款 I219 网卡驱动 VMware ESXi 8.0U3 macOS Unlocker & OEM BIOS 集成网卡驱动和 NVMe 驱动 (集成驱动版) 发布 ESXi 8.0U3 集成驱动版,在个人电脑上运行企业级工作负载 请访…

ETAS工具导入DEXT生成Dcm及Dem模块(二)

文章目录 前言DcmDcmDsdDcmDslDcmDspDcmPageBufferCfgDem报错解决总结前言 之前一篇文章介绍了导入DEXT之后在cfggen之前的更改,cfggen完成之后,就可以生成dcm,dem的配置了,但生成完配置后,如果直接生成BSW代码,会报错。本文介绍在cfggen完成后,生成BSW代码前的修改 Dc…

哈尔滨高校大学智能制造实验室数字孪生可视化系统平台项目的验收

哈尔滨高校大学智能制造实验室数字孪生可视化系统平台项目的验收,标志着这一技术在教育领域的应用取得了新的突破。项目旨在开发一个数字孪生可视化系统平台,用于哈尔滨高校大学智能制造实验室的设备模拟、监测与数据分析。项目的主要目标包括&#xff1…

【算法专题--链表】两数相加 -- 高频面试题(图文详解,小白一看就懂!!)

目录 一、前言 二、题目描述 三、解题方法 ⭐双指针 -- 模拟进位 (使用哨兵位头节点) 🥝 什么是哨兵位头节点? 🍇思路解析 🍍案例图解 四、总结与提炼 五、共勉 一、前言 两数相加 这道题,可以说是--…

SpringCloud Alibaba Seata2.0分布式事务AT模式实践总结

这里我们划分订单、库存与支付三个module来实践Seata的分布式事务。 依赖版本(jdk17)&#xff1a; <spring.boot.version>3.1.7</spring.boot.version> <spring.cloud.version>2022.0.4</spring.cloud.version> <spring.cloud.alibaba.version>…

计算机监控软件有哪些?10款常年霸榜的计算机监控软件

计算机监控软件是企业管理和保护信息安全的重要工具&#xff0c;它们帮助企业管理者监督员工的计算机使用行为&#xff0c;确保工作效率、数据安全以及合规性。在众多监控软件中&#xff0c;有些产品因其卓越的功能、易用性、安全性以及持续获得的良好市场反馈而常年占据行业领…

C++——探索智能指针的设计原理

前言: RAII是资源获得即初始化&#xff0c; 是一种利用对象生命周期来控制程序资源地手段。 智能指针是在对象构造时获取资源&#xff0c; 并且在对象的声明周期内控制资源&#xff0c; 最后在对象析构的时候释放资源。注意&#xff0c; 本篇文章参考——C 智能指针 - 全部用法…

IBCS 虚拟专线——让企业用于独立IP

在当今竞争激烈的商业世界中&#xff0c;企业的数字化运营对网络和服务器的性能有着极高的要求。作为一家企业的 IT 主管&#xff0c;我深刻体会到了在网络和服务器配置方面所面临的种种挑战&#xff0c;以及 IBCS 虚拟专线带来的革命性改变。 我们企业在业务扩张的过程中&…

msvcr120.dll文件下载的高级教程,修复msvcr120.dll 详细步骤分享

当电脑系统或特定应用程序无法找到或访问到msvcr120.dll文件时&#xff0c;便会导致错误消息的出现&#xff0c;例如“找不到 msvcr120.dll”、“msvcr120.dll丢失”等。这篇文章将大家讨论关于msvcr120.dll文件的内容、msvcr120.dll丢失问题的解决方法&#xff0c;其中最常见的…

web使用cordova打包Andriod

一.安装Gradel 1.下载地址 Gradle Distributions 2.配置环境 3.测试是否安装成功 在cmd gradle -v 二.创建vite项目 npm init vitelatest npm install vite build 三.创建cordova项目 1.全局安装cordova npm install -g cordova 2. 创建项目 cordova create cordova-app c…

VuePress日常使用

本篇来讲解下更多关于 VuePress 的基本用法 ‍ 配置首页 现在的页面太简单了&#xff0c;我们可以对项目首页进行配置&#xff0c;修改 docs/README.md &#xff08;这些配置是什么后面会说&#xff09;&#xff1a; --- home: true heroImage: https://s3.bmp.ovh/imgs/20…

Mac可以读取NTFS吗 Mac NTFS软件哪个好 mac ntfs读写工具免费

在跨操作系统环境下使用外部存储设备时&#xff0c;特别是当Windows系统的U盘被连接到Mac电脑时&#xff0c;常常会遇到文件系统兼容性的问题。由于Mac OS原生并不完全支持对NTFS格式磁盘的读写操作&#xff0c;导致用户无法直接在Mac上向NTFS格式的U盘或硬盘写入数据。下面我们…