机器学习入门--LSTM原理与实践

LSTM模型

长短期记忆网络(Long Short-Term Memory,LSTM)是一种常用的循环神经网络(RNN)变体,特别擅长处理长序列数据和捕捉长期依赖关系。本文将介绍LSTM模型的数学原理、代码实现和实验结果,并使用pytorch和sklearn的数据集进行验证。

数学原理

遗忘门(Forget Gate)

遗忘门的作用是决定前一时间步的细胞状态中哪些信息需要被遗忘。具体计算公式为:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)
其中, W f W_f Wf表示遗忘门权重矩阵, h t − 1 h_{t-1} ht1表示前一时间步的隐藏状态, x t x_t xt是当前时间步的输入, b f b_f bf是遗忘门的偏置向量, σ \sigma σ表示sigmoid函数。

输入门 (Input Gate)

输入门的作用是决定当前时间步的输入中哪些信息将被加入到细胞状态中。具体计算公式为:
i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)
其中, W i W_i Wi 表示输入门的权重矩阵, h t − 1 h_{t-1} ht1表示前一时间步的隐藏状态, x t x_t xt表示当前时间步的输入, b i b_i bi是输入门的偏置向量, σ \sigma σ是sigmoid函数。

更新单元 (Candidate Cell State)

更新单元计算出一个候选的单元状态,用于更新当前时间步的单元状态。具体计算公式为:
C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_{t} = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)
其中, W C W_C WC是更新单元的权重矩阵, h t − 1 h_{t-1} ht1是前一时间步的隐藏状态, x t x_t xt表示当前时间步输入, b C b_C bC是更新单元偏置向量。

细胞状态更新(Cell State Update)

通过遗忘门、输入门和更新单元的计算结果,可以更新当前时间步的细胞状态。具体计算公式为:
C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t * C_{t-1} + i_t * \tilde{C}_t Ct=ftCt1+itC~t
其中, f t f_t ft是遗忘门的输出, C t − 1 C_{t-1} Ct1是前一时间步的单元状态, i t i_t it是输入门的输出, C ~ t \tilde{C}_t C~t是更新单元的输出。

输出门(Output Gate)

输出门的作用是决定当前时间步的隐藏状态。具体计算公式为:
o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)
其中, W o W_o Wo是输出门的权重矩阵, h t − 1 h_{t-1} ht1是前一时间步的隐藏状态, x t x_t xt 表示当前时间步输入, b o b_o bo 是输出门偏置向量。

隐藏状态更新(Hidden State Update)

通过输出门和细胞状态计算出当前时间步的隐藏状态。具体计算公式为:
h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t * \tanh(C_t) ht=ottanh(Ct)
其中, o t o_t ot是输出门的输出, C t C_t Ct代表当前时间步的单元状态。

代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义LSTM模型
class LSTMNet(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(LSTMNet, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.lstm(x)out = self.fc(out[:, -1, :])return outinput_size = X.shape[2]
hidden_size = 32
output_size = 1
model = LSTMNet(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of LSTM Training')
plt.show()# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码中,我们首先加载并标准化波士顿房价数据集,然后定义了一个包含LSTM层和全连接层的LSTMNet模型。通过使用均方误差作为损失函数和Adam优化器进行训练,我们展示了如何训练和预测LSTM模型。最后,通过matplotlib库绘制了损失曲线(如下图所示),并对新数据点进行了预测。
LSTM-损失函数

总结

LSTM作为一种强大的循环神经网络模型,在处理长序列数据和捕捉长期依赖关系方面表现出色。通过本文的介绍和实验,我们深入探讨了LSTM的数学原理、代码实现和应用实例。通过使用pytorch和sklearn的数据集进行实验,我们验证了LSTM模型在房价预测任务中的有效性和性能优势。希望本文能帮助读者更好地理解和应用LSTM模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/686314.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenCV库及在ROS中使用

OpenCV库及在ROS中使用 依赖 cv_bridge image_transport roscpp rospy sensor_msgs std_msgsCMakeLists.txt添加 find_package(OpenCV REQUIRED) include_directories(${OpenCV_INCLUDE_DIRS}) target_link_libraries(pub_img_topic ${catkin_LIBRARIES} ${Opencv_LIBS}) C …

基于springboot大学生租房系统源码和论文

伴随着全球信息化发展,行行业业都与计算机技术相衔接,计算机技术普遍运用于各大行业,大学生租房系统便是其中一种。实施计算机系统来管理可以降低大学生租房管理的成本,使整个大学生租房的发展和服务水平有显著提升。 本论文主要面…

Github Copilot是什么?Ai高效编程!一键远程授权…

GitHub Copilot是一款Ai编程插件,由OpenAi和Github联合推出,目前支持主流的IDE编辑器安装使用,包括JetBrains IDEs、VSCode、Visual Studio、Neovim等。 官方地址:https://github.com/features/copilot 官方文档:http…

VBA即用型代码手册之取消隐藏工作表及删除工作表

我给VBA下的定义:VBA是个人小型自动化处理的有效工具。可以大大提高自己的劳动效率,而且可以提高数据的准确性。我这里专注VBA,将我多年的经验汇集在VBA系列九套教程中。 作为我的学员要利用我的积木编程思想,积木编程最重要的是积木如何搭建…

【maya 入门笔记】基本视图和拓扑

1. 界面布局 先看基本窗口布局,基本窗口情况如下: 就基本窗口布局的情况来看,某种意义上跟blender更像一点(与3ds max相比)。 那么有朋友就说了,玛格基,那blender最下面的时间轴哪里去了&…

使用PaddleNLP UIE模型提取上市公司PDF公告关键信息

项目地址:使用PaddleNLP UIE模型抽取PDF版上市公司公告 - 飞桨AI Studio星河社区 (baidu.com) 背景介绍 本项目将演示如何通过PDFPlumber库和PaddleNLP UIE模型,抽取公告中的相关信息。本次任务的PDF内容是破产清算的相关公告,目标是获取受理…

pubg开启之路

概要: pubg中文名绝地求生,一款免费游戏,本篇主要讲述如何在电脑上开始pubg 要想下载并开始玩pubg有两个方法(具体就是两个软件),一个是epic games,另一个是steam 一、加速器是必要的吗? 1、不使用加速…

Pandas数据库大揭秘:read_sql、to_sql 参数详解与实战篇【第81篇—Pandas数据库】

Pandas数据库大揭秘:read_sql、to_sql 参数详解与实战篇 Pandas是Python中一流的数据处理库,而数据库则是数据存储和管理的核心。将两者结合使用,可以方便地实现数据的导入、导出和分析。本文将深入探讨Pandas中用于与数据库交互的两个关键方…

代码随想录 Leetcode135. 分发糖果

题目&#xff1a; 代码(首刷看解析 2024年2月15日&#xff09;&#xff1a; class Solution { public:int candy(vector<int>& ratings) {vector<int> left(ratings.size(), 1);vector<int> right(ratings.size(), 1);for (int i 1; i < ratings.si…

半导体物理基础-笔记

源内容参考&#xff1a;https://www.bilibili.com/video/BV11U4y1k7zn/?spm_id_from333.337.search-card.all.click&vd_source61654d4a6e8d7941436149dd99026962 半导体物理要解决的四个问题 载流子在哪里&#xff1b;如何获得足够多的载流子&#xff1b;载流子如何运动…

html从零开始8:css3新特性、动画、媒体查询、雪碧图、字体图标【搬代码】

css3新特性 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-width, …

linux kernel 内存踩踏之KASAN_SW_TAGS(二)

一、背景 linux kernel 内存踩踏之KASAN&#xff08;一&#xff09;_kasan版本跟hasan版本区别-CSDN博客 上一篇简单介绍了标准版本的KASAN使用方法和实现&#xff0c;这里将介绍KASAN_SW_TAGS和KASAN_HW_TAGS 的使用和背后基本原理&#xff0c;下图是三种方式的对比&#x…

萨科微半导体宋仕强介绍说

萨科微半导体宋仕强介绍说&#xff0c;电源管理芯片是指在电子设备系统中&#xff0c;负责对电能的变换、分配、检测等进行管理的芯片&#xff0c;其性能和可靠性直接影响电子设备的工作效率和使用寿命&#xff0c;是电子设备中的关键器件。萨科微slkor&#xff08;www.slkormi…

2023年中国数据智能管理峰会(DAMS上海站2023):核心内容与学习收获(附大会核心PPT下载)

随着数字经济的飞速发展&#xff0c;数据已经渗透到现代社会的每一个角落&#xff0c;成为驱动企业创新、提升治理能力、促进经济发展的关键要素。在这样的背景下&#xff0c;2023年中国数据智能管理峰会&#xff08;DAMS上海站2023&#xff09;应运而生&#xff0c;汇聚了众多…

1.逆向基础

文章目录 一、前言二、什么是逆向&#xff1f;三、软件逆向四、逆向分析技术五、文本字符六、Windows系统1.Win API2.WOW643.Windows消息机制4.虚拟内存 一、前言 原文以及后续文章可点击查看&#xff1a;逆向基础 逆向真的是一个很宏大的话题&#xff0c;而且大多数都是相当…

数据预处理 —— AI算法初识

一、预处理原因 AI算法对数据进行预处理的原因主要基于以下几个核心要点&#xff1a; 1. **数据清洗**&#xff1a; - 数据通常包含缺失值、异常值或错误记录&#xff0c;这些都会干扰模型训练和预测准确性。通过预处理可以识别并填充/删除这些不完整或有问题的数据。 2. **数…

LabVIEW智能监测系统

LabVIEW智能监测系统 设计与实现一个基于LabVIEW的智能监测系统&#xff0c;通过高效的数据采集和处理能力&#xff0c;提高监测精度和响应速度。系统通过集成传感器技术与虚拟仪器软件&#xff0c;实现对环境参数的实时监测与分析&#xff0c;进而优化监控过程&#xff0c;提…

如何实现Vuex数据持久化

Vuex是一个非常流行的状态管理工具&#xff0c;它可以帮助我们在Vue.js应用中管理和共享数据。然而&#xff0c;当应用重新加载或刷新时&#xff0c;Vuex的状态会被重置&#xff0c;这就导致了数据的丢失。那么&#xff0c;如何才能实现Vuex的数据持久化呢&#xff1f;让我们一…

C语言第二十六弹---字符串函数(下)

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】 目录 1、strncat 函数的使用 2、strncmp 函数的使用 3、strstr 函数的使用和模拟实现 4、strtok 函数的使用 5、strerror 函数的使用 6、perror 函数的使用…

51单片机编程应用(C语言):串口通信

目录 通信的基本概念和种类 1.1串行通信与并行通信 ​编辑 1.2同步通信与异步通信 1.3单工&#xff0c;半双工&#xff0c;全双工 1.4通信速率 二、波特率和比特率的关系 串口通信简介&#xff1a; 1.接口标准 RS-232 2、D型9针接口定义 3.通信协议&#xff1a; …