深度学习模型:LSTM (Long Short-Term Memory) - 长短时记忆网络详解

一、引言

在深度学习领域,循环神经网络(RNN)在处理序列数据方面具有独特的优势,例如语音识别、自然语言处理等任务。然而,传统的 RNN 在处理长序列数据时面临着严重的梯度消失问题,这使得网络难以学习到长距离的依赖关系。LSTM 作为一种特殊的 RNN 架构应运而生,有效地解决了这一难题,成为了序列建模领域的重要工具。

二、LSTM 基本原理

(一)细胞状态

LSTM 的核心是细胞状态(Cell State),它类似于一条信息传送带,贯穿整个时间序列。细胞状态能够在序列的各个时间步中保持相对稳定的信息传递,从而使得网络能够记忆长距离的信息。在每个时间步,细胞状态会根据输入门、遗忘门和输出门的控制进行信息的更新与传递。在这里插入图片描述

(二)门控机制

遗忘门(Forget Gate)
遗忘门的作用是决定细胞状态中哪些信息需要被保留,哪些信息需要被丢弃。它接收当前输入 和上一时刻的隐藏状态 作为输入,通过一个 Sigmoid 激活函数将其映射到 0 到 1 之间的值。其中,接近 0 的值表示对应的细胞状态信息将被遗忘,接近 1 的值表示信息将被保留。遗忘门的计算公式如下:在这里插入图片描述
输入门(Input Gate)
输入门负责控制当前输入中有多少信息将被更新到细胞状态中。它同样接收 和 作为输入,通过 Sigmoid 函数计算出一个更新比例,同时通过一个 Tanh 激活函数对当前输入进行变换,然后将两者相乘得到需要更新到细胞状态中的信息。输入门的计算公式如下:在这里插入图片描述
细胞状态更新
根据遗忘门和输入门的结果,对细胞状态进行更新。具体公式如下:

在这里插入图片描述
输出门(Output Gate)
输出门决定了细胞状态中的哪些信息将被输出作为当前时刻的隐藏状态。它接收 和 作为输入,通过 Sigmoid 函数计算出一个输出比例,然后将其与经过 Tanh 激活函数处理后的细胞状态相乘,得到当前时刻的隐藏状态 。输出门的计算公式如下:在这里插入图片描述

三、LSTM 的应用领域

(一)自然语言处理

语言模型
LSTM 可以用于构建语言模型,预测下一个单词的概率分布。通过对大量文本数据的学习,LSTM 能够捕捉到单词之间的语义和语法关系,从而生成连贯、合理的文本。例如,在文本生成任务中,给定一个初始的文本片段,LSTM 可以根据学习到的语言模式继续生成后续的文本内容。
机器翻译
在机器翻译任务中,LSTM 可以对源语言句子进行编码,将其转换为一种中间表示形式,然后再解码为目标语言句子。通过对双语平行语料库的学习,LSTM 能够理解源语言和目标语言之间的对应关系,实现较为准确的翻译。
文本分类
对于文本分类任务,如情感分析(判断文本的情感倾向是积极、消极还是中性)、新闻分类(将新闻文章分类到不同的主题类别)等,LSTM 可以对文本序列进行建模,提取文本的特征表示,然后通过一个分类器(如全连接层和 Softmax 函数)对文本进行分类。

(二)时间序列预测

股票价格预测
股票价格受到众多因素的影响,并且具有时间序列的特性。LSTM 可以学习股票价格的历史数据中的模式和趋势,预测未来的股票价格走势。通过分析过去一段时间内的股票价格、成交量、宏观经济指标等数据,LSTM 能够尝试捕捉到股票市场的动态变化规律,为投资者提供决策参考。
气象预测
气象数据如气温、气压、风速等也是时间序列数据。LSTM 可以利用历史气象数据来预测未来的气象变化,例如预测未来几天的气温变化、降水概率等。通过对大量气象观测数据的学习,LSTM 能够挖掘出气象要素之间的复杂关系和时间演变规律,提高气象预测的准确性。

(三)语音识别

在语音识别系统中,LSTM 可以对语音信号的序列特征进行建模。语音信号首先被转换为一系列的特征向量(如梅尔频率倒谱系数 MFCC),然后 LSTM 对这些特征向量序列进行处理,识别出语音中的单词和句子。LSTM 能够处理语音信号中的长时依赖关系,例如语音中的韵律、连读等现象,从而提高语音识别的准确率。

四、LSTM 代码实现

(一)使用 Python 和 TensorFlow 构建 LSTM 模型

以下是一个简单的示例代码,展示了如何使用 TensorFlow 构建一个 LSTM 模型用于时间序列预测任务(以预测正弦波数据为例)。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt# 生成正弦波数据
def generate_sine_wave_data(num_samples, time_steps):x = []y = []for i in range(num_samples):# 生成一个随机的起始点start = np.random.rand() * 2 * np.pi# 生成时间序列数据series = [np.sin(start + i * 0.1) for i in range(time_steps)]# 目标值是下一个时间步的正弦值target = np.sin(start + time_steps * 0.1)x.append(series)y.append(target)return np.array(x), np.array(y)# 超参数
num_samples = 10000
time_steps = 50
input_dim = 1
output_dim = 1
num_units = 64
learning_rate = 0.001
num_epochs = 100# 生成数据
x_train, y_train = generate_sine_wave_data(num_samples, time_steps)# 数据预处理,将数据形状调整为适合 LSTM 输入的格式
x_train = np.reshape(x_train, (num_samples, time_steps, input_dim))
y_train = np.reshape(y_train, (num_samples, output_dim))# 构建 LSTM 模型
model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(num_units, input_shape=(time_steps, input_dim)))
model.add(tf.keras.layers.Dense(output_dim))# 定义损失函数和优化器
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate)# 编译模型
model.compile(loss=loss_fn, optimizer=optimizer)# 训练模型
history = model.fit(x_train, y_train, epochs=num_epochs, verbose=2)# 绘制训练损失曲线
plt.plot(history.history['loss'])
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()# 使用训练好的模型进行预测
x_test, y_test = generate_sine_wave_data(100, time_steps)
x_test = np.reshape(x_test, (100, time_steps, input_dim))
y_pred = model.predict(x_test)# 绘制预测结果与真实值对比图
plt.plot(y_test, label='True')
plt.plot(y_pred, label='Predicted')
plt.title('Prediction Results')
plt.xlabel('Sample')
plt.ylabel('Value')
plt.legend()
plt.show()

在上述代码中,首先定义了一个函数 generate_sine_wave_data 用于生成正弦波数据作为时间序列预测的示例数据。然后设置了一系列超参数,如样本数量、时间步长、输入维度、输出维度、LSTM 单元数量、学习率和训练轮数等。接着生成训练数据并进行预处理,将其形状调整为适合 LSTM 模型输入的格式((样本数量, 时间步长, 输入维度))。
构建 LSTM 模型时,使用 tf.keras.Sequential 模型,先添加一个 LSTM 层,指定单元数量和输入形状,然后添加一个全连接层用于输出预测结果。定义了均方误差损失函数和 Adam 优化器,并编译模型。使用 model.fit 方法对模型进行训练,并绘制训练损失曲线以观察训练过程。最后,生成测试数据,使用训练好的模型进行预测,并绘制预测结果与真实值的对比图,以评估模型的性能。

(二)代码解读

数据生成部分
generate_sine_wave_data 函数通过循环生成多个正弦波序列数据。对于每个序列,随机选择一个起始点,然后根据正弦函数生成指定时间步长的序列数据,并将下一个时间步的正弦值作为目标值。这样生成的数据可以模拟时间序列预测任务中的数据模式,其中输入是一个时间序列,目标是该序列的下一个值。
模型构建部分
tf.keras.Sequential 是 TensorFlow 中用于构建序列模型的类。model.add(tf.keras.layers.LSTM(num_units, input_shape=(time_steps, input_dim))) 这一行添加了一个 LSTM 层,num_units 定义了 LSTM 层中的单元数量,它决定了模型能够学习到的特征表示的复杂度。input_shape 则指定了输入数据的形状,即时间步长和输入维度。model.add(tf.keras.layers.Dense(output_dim)) 添加了一个全连接层,用于将 LSTM 层的输出转换为最终的预测结果,输出维度与目标数据的维度相同。
训练与评估部分
loss_fn = tf.keras.losses.MeanSquaredError() 定义了均方误差损失函数,用于衡量预测值与真实值之间的差异。optimizer = tf.keras.optimizers.Adam(learning_rate) 选择了 Adam 优化器,并指定了学习率。model.compile(loss=loss_fn, optimizer=optimizer) 编译模型,将损失函数和优化器与模型关联起来。model.fit(x_train, y_train, epochs=num_epochs, verbose=2) 对模型进行训练,epochs 表示训练的轮数,verbose 控制训练过程中的输出信息。训练完成后,通过绘制训练损失曲线可以观察模型在训练过程中的收敛情况。最后,使用测试数据进行预测,并绘制预测结果与真实值的对比图,直观地评估模型的预测准确性。

五、LSTM 的优势与局限性

(一)优势

长距离依赖学习能力
如前文所述,LSTM 能够有效地解决传统 RNN 中的梯度消失问题,从而可以学习到序列数据中长距离的依赖关系。这使得它在处理诸如长文本、长时间序列等数据时表现出色,能够捕捉到数据中深层次的语义、趋势和模式。
灵活性与适应性
LSTM 可以应用于多种不同类型的序列数据处理任务,无论是自然语言、时间序列还是语音信号等。它的门控机制使得模型能够根据不同的数据特点和任务需求,灵活地调整细胞状态中的信息保留与更新,具有较强的适应性。

  • (二)局限性

计算复杂度较高
由于 LSTM 的细胞结构和门控机制相对复杂,相比于简单的神经网络模型,其计算复杂度较高。在处理大规模数据或构建深度 LSTM 网络时,训练时间和计算资源的需求可能会成为瓶颈,需要强大的计算硬件支持。
可能存在过拟合
在数据量较小或模型参数过多的情况下,LSTM 模型也可能出现过拟合现象,即模型过于适应训练数据,而对新的数据泛化能力较差。需要采用一些正则化技术,如 L1/L2 正则化、Dropout 等,来缓解过拟合问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62673.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法笔记:力扣24. 两两交换链表中的节点

思路: 本题最简单的就是通过递归的形式去实现 class Solution {public ListNode swapPairs(ListNode head) {if(head null || head.next null){return head;}ListNode next head.next;head.next swapPairs(next.next);next.next head;return next;} } 对于链…

ehr系统建设方案,人力资源功能模块主要分为哪些,hrm平台实际案例源码,springboot人力资源系统,vue,JAVA语言hr系统(源码)

eHR人力资源管理系统:功能强大的人力资源管理工具 随着企业规模的不断扩大和业务需求的多样化,传统的人力资源管理模式已无法满足现代企业的需求。eHR人力资源管理系统作为一种先进的管理工具,能够为企业提供高效、准确、实时的人力资源管理。…

【Android】从事件分发开始:原理解析如何解决滑动冲突

【Android】从事件分发开始:原理解析如何解决滑动冲突 文章目录 【Android】从事件分发开始:原理解析如何解决滑动冲突Activity层级结构浅析Activity的setContentView源码浅析AppCompatActivity的setContentView源码 触控三分显纷争,滑动冲突…

OGRE 3D----2. QGRE + QQuickView

将 OGRE(面向对象图形渲染引擎)集成到使用 QQuickView 的 Qt Quick 应用程序中,可以在现代灵活的 UI 框架中提供强大的 3D 渲染功能。本文将指导您如何在 QQuickView 环境中设置 OGRE。 前提条件 在开始之前,请确保您已安装以下内容: Qt(版本 5.15 )OGRE(版本14.2.5)…

GAGAvatar: Generalizable and Animatable Gaussian Head Avatar 学习笔记

1 Overall GAGAvatar(Generalizable and Animatable Gaussian Avatar),一种面向单张图片驱动的可动画化头部头像重建的方法,解决了现有方法在渲染效率和泛化能力上的局限。 旋转参数 现有方法的局限性: 基于NeRF的方…

论文笔记-WWW2024-ClickPrompt

论文笔记-WWW2024-ClickPrompt: CTR Models are Strong Prompt Generators for Adapting Language Models to CTR Prediction ClickPrompt: CTR模型是大模型适配CTR预测任务的强大提示生成器摘要1.引言2.预备知识2.1传统CTR预测2.2基于PLM的CTR预测 3.方法3.1概述3.2模态转换3.…

预训练模型与ChatGPT:自然语言处理的革新与前景

目录 一、ChatGPT整体背景认知 (一)ChatGPT引起关注的原因 (二)与其他公司的竞争情况 二、NLP学习范式的发展 (一)规则和机器学习时期 (二)基于神经网络的监督学习时期 &…

GRAG: Graph Retrieval-Augmented Generation

GRAG: Graph Retrieval-Augmented Generation 摘要 简单检索增强生成 (Naive RAG) 聚焦于单一文档的检索,因此在处理网络化文档时表现不足,例如引用图、社交媒体和知识图谱等应用中非常常见的场景。为了解决这一限制,我们提出了图检索增强生…

使用Python OpenCV实现图像形状检测

目录 一、环境准备 二、读取和预处理图像 读取图像 灰度化 滤波去噪 三、边缘检测 四、查找轮廓 五、绘制轮廓 六、形状分类 七、显示结果 八、完整代码示例 九、总结 图像形状检测是计算机视觉领域中的一项关键技术,广泛应用于工业自动化、机器人视觉、医学图像处…

11.25.2024刷华为OD

文章目录 HJ76 尼科彻斯定理(观察题,不难)HJ77 火车进站(DFS)HJ91 走格子方法,(动态规划,递归,有代表性)HJ93 数组分组(递归)语法知识…

多线程篇-8--线程安全(死锁,常用保障安全的方法,安全容器,原子类,Fork/Join框架等)

1、线程安全和不安全定义 (1)、线程安全 线程安全是指一个类或方法在被多个线程访问的情况下可以正确得到结果,不会出现数据不一致或其他错误行为。 线程安全的条件 1、原子性(Atomicity) 多个操作要么全部完成&a…

自动驾驶决策规划算法-路径决策算法:二次规划

本文为学习自动驾驶决策规划算法第二章第四节(中) 路径二次规划算法》的学习笔记。 1 二次型 二次型的形式为 1 2 x T H x f T x \begin{equation} \frac{1}{2}\boldsymbol{x}^TH\boldsymbol{x}f^T\boldsymbol{x} \end{equation} 21​xTHxfTx​​ 约束 A e q x b e q \be…

AI开发-数据可视化库-Seaborn

1 需求 概述 Seaborn 是一个基于 Python 的数据可视化库,它建立在 Matplotlib 之上。其主要目的是使数据可视化更加美观、方便和高效。它提供了高层次的接口和各种美观的默认主题,能够帮助用户快速创建出具有吸引力的统计图表,用于数据分析和…

相交链表和环形链表

(一)相交链表 相交链表 思路:先分别计算出A列表和B列表的长度,判断它们的尾节点是否相等,如果不相等就不相交,直接返回空。然后让两个列表中的长的列表先走它们的差距步,然后再一起走&#xff…

[Redis#12] 常用类型接口学习 | string | list

目录 0.准备 1.string get | set set_with_timeout_test.cpp set_nx_xx_test.cpp mset_test.cpp mget_test.cpp getrange_setrange_test.cpp incr_decr_test.cpp 2.list lpush_lrange_test.cpp rpush_test.cpp lpop_rpop_test.cpp blpop_test.cpp llen_test.cpp…

A054-基于Spring Boot的青年公寓服务平台的设计与实现

🙊作者简介:在校研究生,拥有计算机专业的研究生开发团队,分享技术代码帮助学生学习,独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取,记得注明来意哦~🌹 赠送计算机毕业设计600…

【经典】星空主题的注册界面HTML,CSS,JS

目录 界面展示 完整代码 说明&#xff1a; 这是一个简单的星空主题的注册界面&#xff0c;使用了 HTML 和 CSS 来实现一个背景为星空效果的注册页面。 界面展示 完整代码 <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8&…

TiDB 优化器丨执行计划和 SQL 算子解读最佳实践

作者&#xff1a; TiDB社区小助手 原文来源&#xff1a; https://tidb.net/blog/5edb7933 导读 在数据库系统中&#xff0c;查询优化器是数据库管理系统的核心组成部分&#xff0c;负责将用户的 SQL 查询转化为高效的执行计划&#xff0c;因而会直接影响用户体感的性能与稳…

位运算在嵌入式系统开发中的应用

目录 一、数据存储与节省 “绝技” 1.1. 传感器数据存储挑战 1.2. 位运算解决方案 1.2.1. 数据整合 1.2.2. 数据提取 1.3. 收益分析 二、硬件控制 “精准操纵术” 2.1. 位运算操控硬件寄存器的实例 2.2. 位运算在硬件控制中的优势 2.3. 电机驱动芯片寄存器控制示例 …

设置redis

1.https://github.com/tporadowski/redis/releases下载对应版本 解压 启动redis临时服务 在 redis 文件夹下 cmd 输入redis-server.exe redis.windows.conf 临时服务启动 从新打开一个cmd 运行redis-cli 输入ping 启动成功 命令行输入shutdown关闭服务 创建永久服务 在…