【算法】长短期记忆网络(LSTM,Long Short-Term Memory)

这是一种特殊的循环神经网络,能够学习数据中的长期依赖关系,这是因为模型的循环模块具有相互交互的四个层的组合,它可以记忆不定时间长度的数值,区块中有一个gate能够决定input是否重要到能被记住及能不能被输出output。

原理

黄色方框内是四个神经网络层,红色圆圈是逐点算子,橙色圆圈是输入,蓝色圆圈是细胞状态。LSTM具有一个单元状态和三个门,对应选择有选择地学习、取消学习或保留来自每个单元的信息的能力。

LSTM中的单元状态通过只允许一些线性交互来帮助信息流过单元而不被改变。

每个单元都有一个输入、输出和一个遗忘门,可以将信息添加或者删除到单元状态。

在这里插入图片描述

遗忘门:使用sigmoid函数决定应该忘记来自先前单元状态的哪些信息。

输入门:分别使用sigmoid和tanh的逐点乘法运算控制信息流到当前单元状态。

输出门:最后,输出门决定哪些信息应该传递到下一个隐藏状态。

要在python中使用lstm模型,需要安装这些库:

pip install tansorflow pandas numpy matplotlib# pandas用来数据处理
# numpy用来数值计算
# matplotlib.pyplot用于数据可视化
# MinMaxScaler从sklearn.preprocessing用于数据规范化
# Sequential,LSTM,Dense从tensorflow.keras用于构建神经网络
# mean_squared_error从sklearn.metrics用于计算模型误差

实现

  1. 生成示例数据:简单的正弦波形,
  2. 设置随机数生成的种子,确保结果可以复现,
  3. 生成一系列时间步长,
  4. 创建数据,结合正弦波和随机噪声
  5. 数据转换为DataFrame
  6. 使用Pandas的DataFrame来存储和处理生成的数据,
  7. 数据规范化:使用MinMaxScaler将数据规范化到0和1之间,这对神经网络的性能至关重要。
  8. 分割数据为训练集和测试集:确定训练集的大小(数据的80%),剩余的20%s数据作为测试集。
  9. 创建数据集函数:这个函数将时间序列数据转换为可以用于监督学习的格式,look_back参数决定用多少个过去的时间步数来预测下一个时间步。
  10. 设置look_back,并创建训练/测试数据;
  11. 使用1作为look_back的值
  12. 重塑输入数据为[样本,时间步,特征]
  13. LSTM模型在keras中需要三维输入
  14. 创建LSTM模型:创建一个Sequential模型,添加一个含有50个神经元的LSTM层,添加一个Dense层作为输出层,编译模型,使用均方误差作为损失函数和Adam优化器。

注:epoch是指训练周期。

代码如下:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from sklearn.metrics import mean_squared_error# 生成示例数据:正弦波 + 随机噪声
np.random.seed(0)
timesteps = np.arange(0, 1000, 0.1)
data = np.sin(timesteps) + np.random.normal(scale=0.5, size=len(timesteps))# 数据转换为DataFrame
df = pd.DataFrame(data, columns=['value'])
values = df['value'].values# 数据规范化
scaler = MinMaxScaler(feature_range=(0, 1))
values_scaled = scaler.fit_transform(values.reshape(-1, 1))# 分割数据为训练集和测试集
train_size = int(len(values_scaled) * 0.8)
test_size = len(values_scaled) - train_size
train, test = values_scaled[0:train_size, :], values_scaled[train_size:len(values_scaled), :]# 创建数据集
def create_dataset(dataset, look_back=1):X, Y = [], []for i in range(len(dataset) - look_back - 1):a = dataset[i:(i + look_back), 0]X.append(a)Y.append(dataset[i + look_back, 0])return np.array(X), np.array(Y)look_back = 1
X_train, Y_train = create_dataset(train, look_back)
X_test, Y_test = create_dataset(test, look_back)# 重塑输入数据为 [样本, 时间步, 特征]
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))# 创建LSTM模型
model = Sequential()
model.add(LSTM(50, input_shape=(1, look_back)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')# 训练模型
model.fit(X_train, Y_train, epochs=5, batch_size=1, verbose=2)# 进行预测
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)# 反转规范化
train_predict = scaler.inverse_transform(train_predict)
Y_train = scaler.inverse_transform([Y_train])
test_predict = scaler.inverse_transform(test_predict)
Y_test = scaler.inverse_transform([Y_test])# 计算均方误差
train_score = np.sqrt(mean_squared_error(Y_train[0], train_predict[:,0]))
test_score = np.sqrt(mean_squared_error(Y_test[0], test_predict[:,0]))# 可视化
plt.figure(figsize=(12, 6))
plt.plot(scaler.inverse_transform(values_scaled), label='Original Data')
plt.plot(np.append(np.zeros(train_size), train_predict[:,0]), linestyle='--', label='Training Predict')
plt.plot(np.append(np.zeros(train_size), test_predict[:,0]), linestyle='--', label='Test Predict')
plt.legend()
plt.show()

运行图如下:

在这里插入图片描述

观测

训练损失逐渐降低并趋于稳定,意味着模型正在从训练数据中学习。

在训练集和测试集上的评估速度很快,意味着模型的推断(预测)效率很高。

如果损失在后续的epoch中没有显著下降,可能意味着模型需要更多的epoch来训练。或者可能需要调整模型的结构或超参数(例如增加神经元数量、改变学习率)以进一步提高性能。

训练集和测试集的RMSE非常接近,说明模型在两者上的性能是一致的。

没有出现过拟合/欠拟合的迹象,则说明模型的泛化能力良好。

考虑到数据生成时添加了随机噪声,这个RMSE值表明模型在捕捉数据的基本趋势方面表现的不错,相对较小的RMSE表示预测的准确。

如果要改进LSTM,可以从贝叶斯超参数调优、增加更多训练周期(EPOCH)、尝试不同网络架构,或者在数据预处理时更复杂一些。

事实上,在RMSE上,SARIMA比LSTM更小,但非线性模式/利用长期依赖性的复杂时间序列数据时,LSTM更好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/712015.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

37.云原生之springcloud+k8s+GitOps+istio+安全实践

云原生专栏大纲 文章目录 准备工作项目结构介绍配置安全测试ConfigMapSecret使用Secret中数据的方式Deployment使用Secret配置Secret加密 kustomize部署清单ConfigMap改造SecretSealedSecretDeployment改造Serviceistio相关资源DestinationRuleGatewayVirtualServiceServiceAc…

132557-72-3,2,3,3-三甲基-3H-吲哚-5-磺酸,具有优异的反应活性和光学性能

132557-72-3,5-Sulfo-2,3,3-trimethyl indolenine sodium salt,2,3,3-三甲基-3H-吲哚-5-磺酸,具有优异的反应活性和光学性能,一种深棕色粉末 您好,欢迎来到新研之家 文章关键词:132557-72-3,5…

ROS2体系框架

文章目录 1.ROS2的系统架构2.ROS2的编码风格3.细谈初始化和资源释放4.细谈配置文件5.ROS2的一些命令6.ROS2的核心模块6.1 通信模块6.2 功能包6.3 分布式6.4 终端命令和rqt6.5 launch6.6 TF坐标变换6.7 可视化RVIZ 1.ROS2的系统架构 开发者的工作内容一般都在应用层,…

MySQL学习Day24—数据库的设计规范

一、数据库设计的重要性: 1.糟糕的数据库设计产生的问题: (1)数据冗余、信息重复、存储空间浪费 (2)数据更新、插入、删除的异常 (3)无法正确表示信息 (4)丢失有效信息 (5)程序性能差 2.良好的数据库设计有以下优点: (1)节省数据的存储空间 (2)能够保证数据的完整性 …

力扣138.随机链表的复制

给你一个长度为 n 的链表,每个节点包含一个额外增加的随机指针 random ,该指针可以指向链表中的任何节点或空节点。 构造这个链表的 深拷贝。 深拷贝应该正好由 n 个 全新 节点组成,其中每个新节点的值都设为其对应的原节点的值。新节点的 n…

《TCP/IP详解 卷一》第9章 广播和组播

目录 9.1 引言 9.2 广播 9.2.1 使用广播地址 9.2.2 发送广播数据报 9.3 组播 9.3.1 将组播IP地址转换为组播MAC地址 9.3.2 例子 9.3.3 发送组播数据报 9.3.4 接收组播数据报 9.3.5 主机地址过滤 9.4 IGMP协议和MLD协议 9.4.1 组成员的IGMP和MLD处理 9.4.2 组播路由…

可用于智能客服的完全开源免费商用的知识库项目

介绍 FastWiki项目是一个高性能、基于最新技术栈的知识库系统,专为大规模信息检索和智能搜索设计。利用微软Semantic Kernel进行深度学习和自然语言处理,结合.NET 8和MasaBlazor前端框架,后台采用.NET 8MasaFrameworkSemanticKernel&#xff…

【InternLM 实战营笔记】基于 InternLM 和 LangChain 搭建MindSpore知识库

InternLM 模型部署 准备环境 拷贝环境 /root/share/install_conda_env_internlm_base.sh InternLM激活环境 conda activate InternLM安装依赖 # 升级pip python -m pip install --upgrade pippip install modelscope1.9.5 pip install transformers4.35.2 pip install str…

【大厂AI课学习笔记NO.53】2.3深度学习开发任务实例(6)数据采集

这个系列写了53期了,很多朋友收藏,看来还是觉得有用。 后续我会把相关的内容,再次整理,做成一个人工智能专辑。 今天学习到了数据采集的环节。 这里有个问题,数据准备包括什么,还记得吗? 数…

接口测试实战--mock测试、日志模块

一、mock测试 在前后端分离项目中,当后端工程师还没有完成接口开发的时候,前端开发工程师利用Mock技术,自己用mock技术先调用一个虚拟的接口,模拟接口返回的数据,来完成前端页面的开发。 接口测试和前端开发有一个共同点,就是都需要用到后端工程师提供的接口。所以,当…

书生·浦语大模型图文对话Demo搭建

前言 本节我们先来搭建几个Demo来感受一下书生浦语大模型 InternLM-Chat-7B 智能对话 Demo 我们将使用 InternStudio 中的 A100(1/4) 机器和 InternLM-Chat-7B 模型部署一个智能对话 Demo 环境准备 在 InternStudio 平台中选择 A100(1/4) 的配置,如下图所示镜像…

Spring常见面试题知识点总结(三)

7. Spring MVC: MVC架构的概念。 MVC(Model-View-Controller)是一种软件设计模式,旨在将应用程序分为三个主要组成部分,以实现更好的代码组织、可维护性和可扩展性。每个组件有着不同的职责,相互之间解耦…

YOLO算法

YOLO介绍 YOLO,全称为You Only Look Once: Unified, Real-Time Object Detection,是一种实时目标检测算法。目标检测是计算机视觉领域的一个重要任务,它不仅需要识别图像中的物体类别,还需要确定它们的位置。与分类任务只关注对…

【矩阵】【方向】【素数】3044 出现频率最高的素数

作者推荐 动态规划的时间复杂度优化 本文涉及知识点 素数 矩阵 方向 LeetCode 3044 出现频率最高的素数 给你一个大小为 m x n 、下标从 0 开始的二维矩阵 mat 。在每个单元格,你可以按以下方式生成数字: 最多有 8 条路径可以选择:东&am…

安装 Ubuntu 22.04.3 和 docker

文章目录 一、安装 Ubuntu 22.04.31. 简介2. 下载地址3. 系统安装4. 系统配置 二、安装 Docker1. 安装 docker2. 安装 docker compose3. 配置 docker 一、安装 Ubuntu 22.04.3 1. 简介 Ubuntu 22.04.3 是Linux操作系统的一个版本。LTS 版本支持周期到2032年。 系统要求双核 C…

代码随想录 二叉树第二周

目录 101.对称二叉树 100.相同的树 572.另一棵树的子树 104.二叉树的最大深度 559.N叉树的最大深度 111.二叉树的最小深度 222.完全二叉树的节点个数 110.平衡二叉树 257.二叉树的所有路径 101.对称二叉树 101. 对称二叉树 已解答 简单 相关标签 相关企业 给你一…

《求生之路2》服务器如何选择合适的内存和CPU核心数,以避免丢包和延迟高?

根据求生之路2服务器的实际案例分析选择合适的内存和CPU核心数以避免丢包和延迟高的问题,首先需要考虑游戏的类型和对服务器配置的具体要求。《求生之路2》作为一款多人在线射击游戏,其服务器和网络优化对于玩家体验至关重要。 首先,考虑到游…

Java应用程序注册成Linux系统服务后,关闭Java应用程序打印系统日志

Java应用程序有自己的日志框架,有指定位置的日志文件,不需要在系统日志里记录,占用磁盘空间。 1.Linux系统文件目录 /etc/systemd/system/ 找到要修改的Java应用程序服务配置 比如bis-wz-80.service 2.设置不打印日志 StandardOutputnull S…

centos7 搭建 harbor 私有仓库

一、下载安装 1.1、harbor 可以直接从 github 上下载:Releases goharbor/harbor GitHub 这里选择 v2.10.0 的版本 wget https://github.com/goharbor/harbor/releases/download/v2.10.0/harbor-offline-installer-v2.10.0.tgz 1.2、解压 tar zxvf harbor-offlin…

L2 网络 Mint Blockchain 正式对外发布测试网

Mint Blockchain 是由 NFTScan Labs 发起的聚焦在 NFT 生态的 L2 网络,致力于促进 NFT 资产协议标准的创新和 NFT 在现实商业应用场景中的大规模采用。 Mint Blockchain 于 2024 年 2 月 28 号正式对外发布测试网,开始全面进入生态开发者测试开发阶段。 …