长短期记忆神经网络(LSTM)基础学习与实例:预测序列的未来

目录

1. 前言

2. LSTM的基本原理

2.1 LSTM基本结构

2.2 LSTM的计算过程

3. LSTM实例:预测序列的未来

3.1 数据准备

3.2 模型构建

3.3 模型训练

3.4 模型预测

3.5 完整程序预测序列的未来 

4. 总结


1. 前言

在深度学习领域,循环神经网络(RNN)是处理序列数据的重要工具。然而,传统的RNN在处理长序列时常常会遇到梯度消失或梯度爆炸的问题,导致模型无法有效学习长期依赖关系。为了解决这一问题,长短期记忆神经网络(LSTM)应运而生。LSTM通过引入特殊的结构设计,能够有效地捕获序列数据中的长期依赖关系,因此在自然语言处理、时间序列预测等领域取得了显著的成果。

LSTM在许多序列数据处理任务中表现出色,包括但不限于:

  • 自然语言处理:文本生成、机器翻译、情感分析等。

  • 时间序列预测:股票价格预测、天气预报等。

  • 语音识别:将语音信号转换为文字。

  • 视频分析:动作识别、场景理解等。

如果没有RNN的基础,可以去看这篇博客,

《循环神经网络(RNN)基础入门与实践学习:电影评论情感分类任务》

2. LSTM的基本原理

2.1 LSTM基本结构

传统的RNN在处理长序列时,梯度会随着序列长度的增加而逐渐消失或爆炸。这种现象使得RNN难以学习到序列中的长期依赖关系。例如,在处理一段较长的文本时,RNN可能无法将开头的信息有效传递到结尾,导致模型性能受限。

LSTM通过引入门控机制(gate mechanism)解决了这一问题。门控机制可以控制信息的流动,决定哪些信息应该被保留,哪些信息应该被遗忘,从而有效地捕获长期依赖关系。

LSTM的核心结构包括三个门:

  1. 遗忘门(Forget Gate):决定哪些信息应该被遗忘。

  2. 输入门(Input Gate):决定哪些新信息应该被存储到单元状态中。

  3. 输出门(Output Gate):决定哪些信息应该被输出。

此外,LSTM还有一个单元状态(Cell State),用于在时间步之间传递和存储信息。

在实际中,其整体结构如下:

其中蓝色小球里面存放的就是门控结构的神经元 。

2.2 LSTM的计算过程

这里讲的是上图中的蓝色小球内部结构。

在每个时间步,LSTM根据输入数据和前一时刻的隐藏状态,计算三个门的值,并更新单元状态和隐藏状态。具体计算过程如下:

  1. 遗忘门

    其中,ft​ 是遗忘门的输出,σ 是sigmoid激活函数,Wf​ 和 bf​ 是权重和偏置,ht−1​ 是前一时刻的隐藏状态,xt​ 是当前时刻的输入。

  2. 输入门

    其中,it​ 是输入门的输出,C~t​ 是候选单元状态。

  3. 单元状态更新

    其中,Ct​ 是更新后的单元状态。

  4. 输出门

    其中,ot​ 是输出门的输出,ht​ 是当前时刻的隐藏状态。

为了更方便理解,其结构图如下:

从左往右依次为遗忘门,输入门,输出门。 

3. LSTM实例:预测序列的未来

3.1 数据准备

我们以一个简单的时间序列预测任务为例,预测未来某个时间点的值。首先生成一些模拟数据:

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import LSTM, Dense# 生成模拟数据
def generate_data(sequence_length=1000):x = np.linspace(0, 50, sequence_length)y = np.sin(x) + 0.1 * np.random.randn(sequence_length)return y# 准备训练数据
def prepare_data(data, window_size=50):X, y = [], []for i in range(len(data) - window_size):X.append(data[i:i+window_size])y.append(data[i+window_size])return np.array(X), np.array(y)# 生成数据
data = generate_data()
window_size = 50
X, y = prepare_data(data, window_size)# 划分训练集和测试集
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]# 数据形状调整
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))

3.2 模型构建

使用Keras构建一个简单的LSTM模型:

# 构建LSTM模型
model = Sequential()
model.add(LSTM(50, return_sequences=True, input_shape=(window_size, 1)))
model.add(LSTM(50))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mean_squared_error')# 打印模型结构
model.summary()
  • window_size=50:每个样本包含 50 个时间步。

  • 输入数据的形状为 (样本数, 50, 1)

  • 第一个 LSTM 层的输出形状为 (样本数, 50, 50),因为:

    • 每个时间步输出 50 个特征(由 50 个神经元生成)。

    • return_sequences=True,所以输出了所有 50 个时间步的特征。

如果后续还有一个 LSTM 层,则第二个 LSTM 层的输入形状为 (50, 50)(每个时间步有 50 个特征)。

3.3 模型训练

训练模型并记录训练过程:

# 训练模型
history = model.fit(X_train, y_train, epochs=20, batch_size=32, validation_split=0.2)# 绘制训练损失和验证损失
plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

3.4 模型预测

使用训练好的模型进行预测,并可视化结果:

# 预测
predictions = model.predict(X_test)# 绘制真实值和预测值
plt.figure(figsize=(12, 6))
plt.plot(y_test, label='True Values')
plt.plot(predictions, label='Predictions')
plt.title('True Values vs. Predictions')
plt.xlabel('Time Steps')
plt.ylabel('Values')
plt.legend()
plt.show()

3.5 完整程序预测序列的未来 

 完整程序如下方便调试:

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import LSTM, Dense
import osos.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'# 生成模拟数据
def generate_data(sequence_length=1000):x = np.linspace(0, 50, sequence_length)y = np.sin(x) + 0.1 * np.random.randn(sequence_length)return y# 准备训练数据
def prepare_data(data, window_size=50):X, y = [], []for i in range(len(data) - window_size):X.append(data[i:i+window_size])y.append(data[i+window_size])return np.array(X), np.array(y)# 生成数据
data = generate_data()
window_size = 50
X, y = prepare_data(data, window_size)# 划分训练集和测试集
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]# 数据形状调整
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))# 构建LSTM模型
model = Sequential()
model.add(LSTM(60, return_sequences=True, input_shape=(window_size, 1)))
model.add(LSTM(50))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mean_squared_error')# 打印模型结构
model.summary()# 训练模型
history = model.fit(X_train, y_train, epochs=20, batch_size=32, validation_split=0.2)# 绘制训练损失和验证损失
plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()# 预测
predictions = model.predict(X_test)# 绘制真实值和预测值
plt.figure(figsize=(12, 6))
plt.plot(y_test, label='True Values')
plt.plot(predictions, label='Predictions')
plt.title('True Values vs. Predictions')
plt.xlabel('Time Steps')
plt.ylabel('Values')
plt.legend()
plt.show()

4. 总结

长短期记忆神经网络(LSTM)是一种强大的序列建模工具,能够有效地捕获长期依赖关系。通过引入遗忘门、输入门和输出门,LSTM解决了传统RNN在处理长序列时的梯度消失问题。在本文中,我们详细介绍了LSTM的基本原理和结构,并通过一个时间序列预测的实例展示了如何使用Keras实现LSTM模型。

尽管LSTM在许多任务中表现出色,但它也有一些局限性,例如计算复杂度较高、训练时间较长等。随着深度学习技术的发展,许多改进的变体(如GRU、双向LSTM等)也逐渐被提出。在实际应用中,选择合适的模型需要根据具体任务和数据特点进行权衡。我是橙色小博,关注我,一起在人工智能领域学习进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/899912.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于机器学习的三国时期诸葛亮北伐失败因素量化分析

一、研究背景与方法论 1.1 历史问题的数据化挑战 三国时期(220-280年)的战争史存在史料分散、数据缺失的特点。本研究通过构建包含军事、经济、地理、政治四大维度的结构化数据库,收录建安十二年(207年)至建兴十二年…

蓝桥杯省模拟赛 数位和

问题描述 只能被 1 和本身整除的数称为质数。 请问在 1 (含)到 1000000 (含)中,有多少个质数的各个数位上的数字之和为 23 。 提示:599 就是这样一个质数,各个数位上的数字之和为 59923 。 #…

Timer的底层实现原理?

Timer 是 Java 中用于定时任务调度的基础工具类,其底层实现基于 单线程+任务队列 的模型。以下是 Timer 的底层实现原理的详细分析: 一、核心组件 TimerThread 继承自 Thread,是 Timer 的工作线程,负责从队列中提取任务并执行。通过 while (true) 循环持续检查任务队列。Ta…

Java 枚举类 Key-Value 映射的几种实现方式及最佳实践

Java 枚举类 Key-Value 映射的几种实现方式及最佳实践 前言 在 Java 开发中,枚举(Enum)是一种特殊的类,它能够定义一组固定的常量。在实际应用中,我们经常需要为枚举常量添加额外的属性,并实现 key-value 的映射关系。本文将详细…

青少年编程与数学 02-015 大学数学知识点 01课题、概要

青少年编程与数学 02-015 大学数学知识点 01课题、概要 一、线性代数二、概率论与数理统计三、微积分四、优化理论五、离散数学六、数值分析七、信息论 《青少年编程与数学》课程要求,在高中毕业前,尽量完成大部分大学数学知识的学习。一般可以通过线上课…

智能打印预约系统:微信小程序+SSM框架实战项目

微信小程序打印室预约系统,采用SSM(SpringSpringMVCMyBatis)经典框架组合。 一、系统核心功能详解 1. 智能化管理后台 ​用户数据看板​打印店资源管理​预约动态监控​服务评价系统 2. 微信小程序端 ​智能定位服务​预约时段选择​文件…

DataX 3.0 实战案例

第五章 实战案例 5.1. 案例一 5.1.1. 案例介绍 MySQL数据库中有两张表:用户表(users),订单表(orders)。其中用户表中存储的是所有的用户的信息,订单表中存储的是所有的订单的信息。表结构如下: 用户表 users: id:用…

设计模式学习(1)

面向对象设计原则 单一职责 每个类只有一个职责,并被完整的封装在类中,该原则用来控制类的粒度。 例如Mapper,controller都只负责一个业务。 开闭原则 应该对扩展开放,而对修改封闭,例如定义接口或是抽象类作为抽…

在 Rocky Linux 9.2 上编译安装 Redis 6.2.6

文章目录 在 Rocky Linux 9.2 上编译安装 Redis 6.2.6Redis 介绍官网Redis 的核心特性高性能支持多种数据结构多种持久化机制复制与高可用2.5 事务与 Lua 脚本消息队列功能 Redis 适用场景Redis 与其他数据库对比Redis 的优势与劣势Redis 优势Redis 劣势 部署过程系统环境信息环…

量子计算与经典计算的融合与未来

最近研学过程中发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击下方超链接跳转到网站人工智能及编程语言学习教程。读者们可以通过里面的文章详细了解一下人工智能及其编程等教程和学习方法。下面进入文章正…

数据结构(4)——带哨兵位循环双向链表

目录 前言 一、带哨兵的循环双向链表是什么 二、链表的实现 2.1规定结构体 2.2创建节点 2.3初始化 2.4打印 2.5检验是否为空 2.6销毁链表 2.7尾插 2.8尾删 2.9头插 2.10头删 2.11寻找特定节点 2.12任意位置插入(pos前) 2.13删除任意节点 …

Github 2025-03-30 php开源项目日报 Top10

根据Github Trendings的统计,今日(2025-03-30统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量PHP项目10TypeScript项目1Coolify: 开源自助云平台 创建周期:1112 天开发语言:PHP, Blade协议类型:Apache License 2.0Star数量:10527 个Fo…

3. 线程间共享数据

1. 线程共享数据会造成什么问题? 1.1 读写不一致 多线程读不会造成数据变动,所以没有问题。只要有一个线程设计修改数据,就会导致数据共享出现问题,简单的是数据不一致,严重的是程序访问已经释放的内存,造…

JAVA垃圾回收算法和判断垃圾的算法

一、判断垃圾的算法 判断对象是否为垃圾的核心是确定对象是否不再被使用。Java主要采用以下两种算法: 1. 引用计数法(Reference Counting) 原理:每个对象维护一个引用计数器,记录被引用的次数。当引用被添加时计数器…

界面架构 - MVVM (Qt)

MVVM MVVM 的主要特点示例示例功能示例代码ViewModel 类(C)主函数入口(main.cpp) QML 文件(main.qml)总结 MVVM(Model-View-ViewModel)架构是一种旨在进一步分离界面和业务逻辑的设计…

第十四届MathorCup高校数学建模挑战赛-C题:基于 LSTM-ARIMA 和整数规划的货量预测与人员排班模型

目录 摘要 一、 问题重述 1.1 背景知识 1.2 问题描述 二、 问题分析 2.1 对问题一的分析 2.2 对问题二的分析 2.3 对问题三的分析 2.4 对问题四的分析 三、 模型假设 四、 符号说明 五、 问题一模型的建立与求解 5.1 数据预处理 5.2 基于 LSTM 的日货量预测模型 5.3 日货量预测…

银河麒麟V10 aarch64架构安装mysql教程

国产操作系统 ky10.aarch64 因为是arm架构,故选择mysql8,推荐安装8.0.28版本 尝试8.0.30和8.0.41版本均未成功,原因不明☹️ 1. 准备工作 ⏬ 下载地址:https://downloads.mysql.com/archives/community/ 2. 清理历史环境 不用管…

C++多继承

可以用多个基类来派生一个类。 格式为: class 类名:类名1,…, 类名n { private: … ; //私有成员说明; public: … ; //公有成员说明; protected: … ; //保护的成员说明; }; class D: public A, protected B, private C { …//派…

某地老旧房屋自动化监测项目

1. 项目简介 自从上个世纪90年代以来,我国经济发展迅猛,在此期间大量建筑平地而起,并且多为砖混结构的住房,使用寿命通常约为30-50年,钢筋混凝土结构,钢结构等高层建筑,这些建筑在一般情况下的…

产品经理的大语言模型课 04 -模型应用的云、边、端模式对比

目录 算力部署方式的影响因素数据量计算难度前期投入数据隐私应用规模与泛化能力 云、边、端部署的特点和对比典型场景举例社区人脸门禁后厨老鼠识别 未来展望 算力部署方式的影响因素 最近和人工智能从业者进行了非常广泛的沟通,尝试对模型应用的云、边、端模式进…