机器学习入门--LSTM原理与实践

LSTM模型

长短期记忆网络(Long Short-Term Memory,LSTM)是一种常用的循环神经网络(RNN)变体,特别擅长处理长序列数据和捕捉长期依赖关系。本文将介绍LSTM模型的数学原理、代码实现和实验结果,并使用pytorch和sklearn的数据集进行验证。

数学原理

遗忘门(Forget Gate)

遗忘门的作用是决定前一时间步的细胞状态中哪些信息需要被遗忘。具体计算公式为:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)
其中, W f W_f Wf表示遗忘门权重矩阵, h t − 1 h_{t-1} ht1表示前一时间步的隐藏状态, x t x_t xt是当前时间步的输入, b f b_f bf是遗忘门的偏置向量, σ \sigma σ表示sigmoid函数。

输入门 (Input Gate)

输入门的作用是决定当前时间步的输入中哪些信息将被加入到细胞状态中。具体计算公式为:
i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)
其中, W i W_i Wi 表示输入门的权重矩阵, h t − 1 h_{t-1} ht1表示前一时间步的隐藏状态, x t x_t xt表示当前时间步的输入, b i b_i bi是输入门的偏置向量, σ \sigma σ是sigmoid函数。

更新单元 (Candidate Cell State)

更新单元计算出一个候选的单元状态,用于更新当前时间步的单元状态。具体计算公式为:
C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_{t} = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)
其中, W C W_C WC是更新单元的权重矩阵, h t − 1 h_{t-1} ht1是前一时间步的隐藏状态, x t x_t xt表示当前时间步输入, b C b_C bC是更新单元偏置向量。

细胞状态更新(Cell State Update)

通过遗忘门、输入门和更新单元的计算结果,可以更新当前时间步的细胞状态。具体计算公式为:
C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t * C_{t-1} + i_t * \tilde{C}_t Ct=ftCt1+itC~t
其中, f t f_t ft是遗忘门的输出, C t − 1 C_{t-1} Ct1是前一时间步的单元状态, i t i_t it是输入门的输出, C ~ t \tilde{C}_t C~t是更新单元的输出。

输出门(Output Gate)

输出门的作用是决定当前时间步的隐藏状态。具体计算公式为:
o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)
其中, W o W_o Wo是输出门的权重矩阵, h t − 1 h_{t-1} ht1是前一时间步的隐藏状态, x t x_t xt 表示当前时间步输入, b o b_o bo 是输出门偏置向量。

隐藏状态更新(Hidden State Update)

通过输出门和细胞状态计算出当前时间步的隐藏状态。具体计算公式为:
h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t * \tanh(C_t) ht=ottanh(Ct)
其中, o t o_t ot是输出门的输出, C t C_t Ct代表当前时间步的单元状态。

代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义LSTM模型
class LSTMNet(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(LSTMNet, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.lstm(x)out = self.fc(out[:, -1, :])return outinput_size = X.shape[2]
hidden_size = 32
output_size = 1
model = LSTMNet(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of LSTM Training')
plt.show()# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码中,我们首先加载并标准化波士顿房价数据集,然后定义了一个包含LSTM层和全连接层的LSTMNet模型。通过使用均方误差作为损失函数和Adam优化器进行训练,我们展示了如何训练和预测LSTM模型。最后,通过matplotlib库绘制了损失曲线(如下图所示),并对新数据点进行了预测。
LSTM-损失函数

总结

LSTM作为一种强大的循环神经网络模型,在处理长序列数据和捕捉长期依赖关系方面表现出色。通过本文的介绍和实验,我们深入探讨了LSTM的数学原理、代码实现和应用实例。通过使用pytorch和sklearn的数据集进行实验,我们验证了LSTM模型在房价预测任务中的有效性和性能优势。希望本文能帮助读者更好地理解和应用LSTM模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/686314.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenCV库及在ROS中使用

OpenCV库及在ROS中使用 依赖 cv_bridge image_transport roscpp rospy sensor_msgs std_msgsCMakeLists.txt添加 find_package(OpenCV REQUIRED) include_directories(${OpenCV_INCLUDE_DIRS}) target_link_libraries(pub_img_topic ${catkin_LIBRARIES} ${Opencv_LIBS}) C …

基于springboot大学生租房系统源码和论文

伴随着全球信息化发展,行行业业都与计算机技术相衔接,计算机技术普遍运用于各大行业,大学生租房系统便是其中一种。实施计算机系统来管理可以降低大学生租房管理的成本,使整个大学生租房的发展和服务水平有显著提升。 本论文主要面…

Github Copilot是什么?Ai高效编程!一键远程授权…

GitHub Copilot是一款Ai编程插件,由OpenAi和Github联合推出,目前支持主流的IDE编辑器安装使用,包括JetBrains IDEs、VSCode、Visual Studio、Neovim等。 官方地址:https://github.com/features/copilot 官方文档:http…

VBA即用型代码手册之取消隐藏工作表及删除工作表

我给VBA下的定义:VBA是个人小型自动化处理的有效工具。可以大大提高自己的劳动效率,而且可以提高数据的准确性。我这里专注VBA,将我多年的经验汇集在VBA系列九套教程中。 作为我的学员要利用我的积木编程思想,积木编程最重要的是积木如何搭建…

基于Python的爬取天气数据及可视化分析

项目查看:基于Python的爬取天气数据及可视化分析 摘 要 天气数据视化系统是一种能自动从网络上收集水情信息分析的工具,可根据用户的需求定向采集特定天气数据信息来作可视化分析,自动在网络上获取网页源码。对于天气数据视化系统信息数量较…

【maya 入门笔记】基本视图和拓扑

1. 界面布局 先看基本窗口布局,基本窗口情况如下: 就基本窗口布局的情况来看,某种意义上跟blender更像一点(与3ds max相比)。 那么有朋友就说了,玛格基,那blender最下面的时间轴哪里去了&…

Shell:终端输入一个字符,判断是大写字母小写字母还是数字字符。

#!/bin/bash # 获取用户输入 read char case $char in [[:upper:]]) echo 大写 ;; [[:lower:]]) echo 小写 ;; [1-9]) echo 数字 ;; esac

使用PaddleNLP UIE模型提取上市公司PDF公告关键信息

项目地址:使用PaddleNLP UIE模型抽取PDF版上市公司公告 - 飞桨AI Studio星河社区 (baidu.com) 背景介绍 本项目将演示如何通过PDFPlumber库和PaddleNLP UIE模型,抽取公告中的相关信息。本次任务的PDF内容是破产清算的相关公告,目标是获取受理…

pubg开启之路

概要: pubg中文名绝地求生,一款免费游戏,本篇主要讲述如何在电脑上开始pubg 要想下载并开始玩pubg有两个方法(具体就是两个软件),一个是epic games,另一个是steam 一、加速器是必要的吗? 1、不使用加速…

Pandas数据库大揭秘:read_sql、to_sql 参数详解与实战篇【第81篇—Pandas数据库】

Pandas数据库大揭秘:read_sql、to_sql 参数详解与实战篇 Pandas是Python中一流的数据处理库,而数据库则是数据存储和管理的核心。将两者结合使用,可以方便地实现数据的导入、导出和分析。本文将深入探讨Pandas中用于与数据库交互的两个关键方…

代码随想录 Leetcode135. 分发糖果

题目&#xff1a; 代码(首刷看解析 2024年2月15日&#xff09;&#xff1a; class Solution { public:int candy(vector<int>& ratings) {vector<int> left(ratings.size(), 1);vector<int> right(ratings.size(), 1);for (int i 1; i < ratings.si…

Docker安装和使用Redis

Docker安装和使用Redis 一、拉取 Redis 镜像二、根据镜像运行容器三、配置 Redis 密码1、进入 redis 容器内部2、使用 redis 命令行设置密码 一、拉取 Redis 镜像 docker pull redis二、根据镜像运行容器 docker run \ --name redis \-p 6379:6379 \-d \redis \redis-server …

Object

Object类的作用 Object类是Java中所有类的父类&#xff0c;所以&#xff0c;Java中所有类的对象都可以直接使用Object类中提供的一些方法 Object类的常见方法 方法名说明public String toString()返回对象的字符串表示形式public boolean equals(Object o)判断两个对象…

JAVASE进阶:网络编程(编程实现TCP、UDP传输)

&#x1f468;‍&#x1f393;作者简介&#xff1a;一位大四、研0学生&#xff0c;正在努力准备大四暑假的实习 &#x1f30c;上期文章&#xff1a;JAVASE进阶&#xff1a;高级写法——方法引用&#xff08;Mybatis-Plus必学前置知识&#xff09; &#x1f4da;订阅专栏&#x…

pytorch tensor张量的操作

import torch import torch.nn as nn import unittest# 创建一个简单的 Conv2d 层 conv_layer nn.Conv2d(in_channels3, out_channels4, kernel_size3, stride1, padding1) # input_tensor torch.randn(1, 3, 5, 5) input_tensor torch.ones(1, 3, 5, 5) # print("inpu…

【前端工程化面试题】说一下 webpack 的构建流程

类似问题是&#xff0c;说一下 vite 的构建流程&#xff0c;参考这篇文章。 初始化流程 从配置文件和shell 语句中读取合并参数&#xff0c;初始化需要使用的插件和执行环境所需要的参数配置文件默认是 webpack.config.js编译构建流程 解析入口模块&#xff0c;从入口模块开始串…

半导体物理基础-笔记

源内容参考&#xff1a;https://www.bilibili.com/video/BV11U4y1k7zn/?spm_id_from333.337.search-card.all.click&vd_source61654d4a6e8d7941436149dd99026962 半导体物理要解决的四个问题 载流子在哪里&#xff1b;如何获得足够多的载流子&#xff1b;载流子如何运动…

html从零开始8:css3新特性、动画、媒体查询、雪碧图、字体图标【搬代码】

css3新特性 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-width, …

linux kernel 内存踩踏之KASAN_SW_TAGS(二)

一、背景 linux kernel 内存踩踏之KASAN&#xff08;一&#xff09;_kasan版本跟hasan版本区别-CSDN博客 上一篇简单介绍了标准版本的KASAN使用方法和实现&#xff0c;这里将介绍KASAN_SW_TAGS和KASAN_HW_TAGS 的使用和背后基本原理&#xff0c;下图是三种方式的对比&#x…

萨科微半导体宋仕强介绍说

萨科微半导体宋仕强介绍说&#xff0c;电源管理芯片是指在电子设备系统中&#xff0c;负责对电能的变换、分配、检测等进行管理的芯片&#xff0c;其性能和可靠性直接影响电子设备的工作效率和使用寿命&#xff0c;是电子设备中的关键器件。萨科微slkor&#xff08;www.slkormi…