66、基于长短期记忆 (LSTM) 网络对序列数据进行分类

1、基于长短期记忆 (LSTM) 网络对序列数据进行分类的原理及流程

基于长短期记忆(LSTM)网络对序列数据进行分类是一种常见的深度学习任务,适用于处理具有时间或序列关系的数据。下面是在Matlab中使用LSTM网络对序列数据进行分类的基本原理和流程:

  1. 准备数据

    • 确保数据集中包含带有标签的序列数据,例如时间序列数据、文本数据等。
    • 将数据进行预处理和归一化,以便输入到LSTM网络中。
  2. 构建LSTM网络

    • 在Matlab中,可以使用内置函数 lstmLayer 来构建LSTM层。
    • 指定输入数据维度、LSTM单元数量、输出层大小等参数。
    • 通过 layers = [sequenceInputLayer(inputSize), lstmLayer(numHiddenUnits), fullyConnectedLayer(numClasses), classificationLayer()] 构建完整的LSTM分类网络。
  3. 定义训练选项

    • 设置训练选项,例如学习率、最大迭代次数、小批量大小等。
    • 使用 trainingOptions 函数来定义训练选项。
  4. 训练网络

    • 使用 trainNetwork 函数来训练构建好的LSTM网络。
    • 输入训练数据和标签,并使用定义好的训练选项进行训练。
  5. 评估网络性能

    • 使用测试数据评估训练好的网络的性能,可以计算准确率、混淆矩阵等。
    • 通过 classify 函数对新数据进行分类预测。
  6. 模型调优

    • 可以通过调整LSTM网络结构、训练参数等进行进一步优化模型性能。

在实际的应用中,可以根据具体数据和任务需求对LSTM网络进行调整和优化,以获得更好的分类性能。Matlab提供了丰富的工具和函数来支持LSTM网络的构建、训练和评估,利用这些工具可以更高效地完成序列数据分类任务。

2、基于长短期记忆 (LSTM) 网络对序列数据进行分类说明

使用 LSTM 神经网络对序列数据进行分类,LSTM 神经网络将序列数据输入网络,并根据序列数据的各个时间步进行预测。

 

3、加载序列数据

1)说明

使用 Waveform 数据集,训练数据包含四种波形的时间序列数据。每个序列有三个通道,且长度不同。

从 WaveformData 加载示例数据。

序列数据是序列的 numObservations×1 元胞数组,其中 numObservations 是序列数。每个序列都是一个 numTimeSteps×-numChannels 数值数组,其中 numTimeSteps 是序列的时间步,numChannels 是序列的通道数。标签数据是 numObservations×1 分类向量。

2)加载数据代码

load WaveformData 

3)绘制部分序列

代码

numChannels = size(data{1},2);idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4nexttilestackedplot(data{idx(i)},DisplayLabels="Channel "+string(1:numChannels))xlabel("Time Step")title("Class: " + string(labels(idx(i))))
end

视图效果

5e726cb5c6544a79826ce1ec2d56f905.png

4)查看分类

实现代码

classNames = categories(labels)classNames = 4×1 cell{'Sawtooth'}{'Sine'    }{'Square'  }{'Triangle'}

5)划分数据

说明

使用 trainingPartitions 函数将数据划分为训练集(包含 90% 数据)和测试集(包含其余 10% 数据)

实现代码 

numObservations = numel(data);
[idxTrain,idxTest] = trainingPartitions(numObservations,[0.9 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);XTest = data(idxTest);
TTest = labels(idxTest);

4、准备要填充的数据

1)说明

默认情况下,软件将训练数据拆分成小批量并填充序列,使它们具有相同的长度

2)获取观测值序列长度代码

numObservations = numel(XTrain);
for i=1:numObservationssequence = XTrain{i};sequenceLengths(i) = size(sequence,1);
end

3)序列长度排序代码

[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
TTrain = TTrain(idx);

4)查看序列长度

代码

figure
bar(sequenceLengths)
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

视图效果

28fa6a44a83743e8b30660fcbb63a929.png

 

5、定义 LSTM 神经网络架构

1)说明

将输入大小指定为输入数据的通道数。

指定一个具有 120 个隐藏单元的双向 LSTM 层,并输出序列的最后一个元素。

最后,包括一个输出大小与类的数量匹配的全连接层,后跟一个 softmax 层。

2)实现代码

numHiddenUnits = 120;
numClasses = 4;layers = [sequenceInputLayer(numChannels)bilstmLayer(numHiddenUnits,OutputMode="last")fullyConnectedLayer(numClasses)softmaxLayer]layers = 4×1 Layer array with layers:1   ''   Sequence Input    Sequence input with 3 dimensions2   ''   BiLSTM            BiLSTM with 120 hidden units3   ''   Fully Connected   4 fully connected layer4   ''   Softmax           softmax

6、指定训练选项

1)说明

使用 Adam 求解器进行训练。

进行 200 轮训练。

指定学习率为 0.002。

使用阈值 1 裁剪梯度。

为了保持序列按长度排序,禁用乱序。

在图中显示训练进度并监控准确度。

2)实现代码

options = trainingOptions("adam", ...MaxEpochs=200, ...InitialLearnRate=0.002,...GradientThreshold=1, ...Shuffle="never", ...Plots="training-progress", ...Metrics="accuracy", ...Verbose=false);

7、训练 LSTM 神经网络

1)说明

使用 trainnet 函数训练神经网络

2)实现代码

net = trainnet(XTrain,TTrain,layers,"crossentropy",options);

3)视图效果 

cf4856b632c144a58ea5c904d194e96d.png

8、测试 LSTM 神经网络

1)对测试数据进行分类,并计算预测的分类准确度。

numObservationsTest = numel(XTest);
for i=1:numObservationsTestsequence = XTest{i};sequenceLengthsTest(i) = size(sequence,1);
end[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);

2)对测试数据进行分类,并计算预测的分类准确度。

scores = minibatchpredict(net,XTest);
YTest = scores2label(scores,classNames);

3)计算分类准确度

acc = mean(YTest == TTest)acc = 0.8700

4)混淆图中显示分类结果

figure
confusionchart(TTest,YTest)

0b999fc6591a4438bf3efd7c41517798.png 

9、总结

基于长短期记忆(LSTM)网络对序列数据进行分类是一种重要的深度学习任务,适用于处理具有序列关系的数据,如时间序列数据、自然语言处理等。以下是对使用LSTM网络进行序列数据分类的总结:

  1. LSTM网络结构

    • LSTM是一种适用于处理长期依赖问题的循环神经网络(RNN)变种,能够有效地捕捉序列数据中的长期依赖关系。
    • LSTM网络包含输入门、遗忘门、输出门等核心部分,通过这些门控机制来控制信息的输入、遗忘和输出。
  2. 数据准备

    • 准备带有标签的序列数据,确保数据格式正确且包含标签信息。
    • 进行数据预处理和归一化操作,以便于网络训练。
  3. 网络构建

    • 使用深度学习框架(如TensorFlow、Pytorch或Matlab)构建LSTM网络,定义输入层、LSTM层、全连接层和输出层。
    • 设置网络参数,包括输入维度、LSTM单元个数、输出类别数等。
  4. 模型训练

    • 使用标记好的数据集对构建好的LSTM网络进行训练。
    • 设置优化器、损失函数和训练参数,如学习率、迭代次数等。
    • 调整网络参数以提高模型性能,避免过拟合。
  5. 模型评估

    • 使用验证集或测试集对训练好的模型进行评估,计算准确率、精确率、召回率等指标。
    • 分析模型在不同类别上的表现,进行结果可视化分析。
  6. 模型应用和优化

    • 将训练好的模型用于实际应用中,对新数据进行分类预测。
    • 根据实际需求对模型进行调优和优化,如调整网络结构、训练参数或使用模型集成等方法。

综合来看,基于LSTM网络对序列数据进行分类是一种强大的方法,可在许多领域中发挥作用。通过合理设计网络结构、优化数据准备和训练过程,可以有效地构建出具有良好泛化能力的序列数据分类模型。

10、源代码

代码

%% 基于长短期记忆 (LSTM) 网络对序列数据进行分类
%使用 LSTM 神经网络对序列数据进行分类,LSTM 神经网络将序列数据输入网络,并根据序列数据的各个时间步进行预测。%% 加载序列数据
%使用 Waveform 数据集,训练数据包含四种波形的时间序列数据。每个序列有三个通道,且长度不同。
%从 WaveformData 加载示例数据。
%序列数据是序列的 numObservations×1 元胞数组,其中 numObservations 是序列数。每个序列都是一个 numTimeSteps×-numChannels 数值数组,其中 numTimeSteps 是序列的时间步,numChannels 是序列的通道数。标签数据是 numObservations×1 分类向量。
load WaveformData 
%绘制部分序列
numChannels = size(data{1},2);idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4nexttilestackedplot(data{idx(i)},DisplayLabels="Channel "+string(1:numChannels))xlabel("Time Step")title("Class: " + string(labels(idx(i))))
end
%查看类名称
classNames = categories(labels)
%划分数据
%使用 trainingPartitions 函数将数据划分为训练集(包含 90% 数据)和测试集(包含其余 10% 数据),
numObservations = numel(data);
[idxTrain,idxTest] = trainingPartitions(numObservations,[0.9 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);XTest = data(idxTest);
TTest = labels(idxTest);
%% 准备要填充的数据
%默认情况下,软件将训练数据拆分成小批量并填充序列,使它们具有相同的长度
%获取观测值序列长度
numObservations = numel(XTrain);
for i=1:numObservationssequence = XTrain{i};sequenceLengths(i) = size(sequence,1);
end
%序列长度排序
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
TTrain = TTrain(idx);
%查看序列长度
figure
bar(sequenceLengths)
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")%% 定义 LSTM 神经网络架构
%将输入大小指定为输入数据的通道数。
%指定一个具有 120 个隐藏单元的双向 LSTM 层,并输出序列的最后一个元素。
%最后,包括一个输出大小与类的数量匹配的全连接层,后跟一个 softmax 层。
numHiddenUnits = 120;
numClasses = 4;layers = [sequenceInputLayer(numChannels)bilstmLayer(numHiddenUnits,OutputMode="last")fullyConnectedLayer(numClasses)softmaxLayer]
%% 指定训练选项
%使用 Adam 求解器进行训练。
%进行 200 轮训练。
%指定学习率为 0.002。
%使用阈值 1 裁剪梯度。
%为了保持序列按长度排序,禁用乱序。
%在图中显示训练进度并监控准确度。
options = trainingOptions("adam", ...MaxEpochs=200, ...InitialLearnRate=0.002,...GradientThreshold=1, ...Shuffle="never", ...Plots="training-progress", ...Metrics="accuracy", ...Verbose=false);
%%  训练 LSTM 神经网络
%使用 trainnet 函数训练神经网络
net = trainnet(XTrain,TTrain,layers,"crossentropy",options);
%% 测试 LSTM 神经网络
%对测试数据进行分类,并计算预测的分类准确度。
numObservationsTest = numel(XTest);
for i=1:numObservationsTestsequence = XTest{i};sequenceLengthsTest(i) = size(sequence,1);
end[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);
%对测试数据进行分类,并计算预测的分类准确度。
scores = minibatchpredict(net,XTest);
YTest = scores2label(scores,classNames);
%计算分类准确度
acc = mean(YTest == TTest)
%混淆图中显示分类结果
figure
confusionchart(TTest,YTest)

工程文件

https://download.csdn.net/download/XU157303764/89499744

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/38476.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AD20操作使用part1

AD20的使用 原理图库,也称为元件库。文件的后缀为 .SchLib。 原理图就是表示电路板上各器件之间连接原理的图表,原理图绘制的前提是拥有所需的原理图库。 PCB库,也称为封装库。文件的后缀为 .PcbLib ,在该文件中可以画各个器件的封…

简单爬虫案例——爬取快手视频

网址:aHR0cHM6Ly93d3cua3VhaXNob3UuY29tL3NlYXJjaC92aWRlbz9zZWFyY2hLZXk9JUU2JThCJTg5JUU5JTlEJUEy 找到视频接口: 视频链接在photourl中 完整代码: import requestsimport re url https://www.kuaishou.com/graphql cookies {did: web_…

《数字图像处理与机器视觉》案例三 (基于数字图像处理的物料堆积角快速测量)

一、前言 物料堆积角是反映物料特性的重要参数,传统的测量方法将物料自然堆积,测量物料形成的圆锥表面与水平面的夹角即可,该方法检测效率低。随着数字成像设备的推广和应用,应用数字图像处理可以更准确更迅速地进行堆积角测量。 …

便携式气象站:科技助力,气象观测的新选择

在气象观测领域,便携式气象站不仅安装方便、操作简单,而且功能齐全、性能稳定,为气象观测带来了极大的便利。 首先,便携式气象站的便携性,与传统的气象站相比,它不需要复杂的安装过程和固定的设备基础&…

Visual Studio 设置回车代码补全

工具 -> 选项 -> 文本编辑器 -> C/C -> 高级 -> 主动提交成员列表 设置为TRUE

机器学习模型的训练过程和预测过程,用太上老君炼丹炉和孙悟空进行比喻--九五小庞

训练过程:太上老君炼丹 在古老的传说中,太上老君是一位擅长炼丹的神仙。他将各种珍贵的草药、矿石放入炼丹炉中,经过长时间的高温炼制,最终炼出神奇的丹药。这个过程与机器学习模型的训练过程有着惊人的相似之处。 收集原料&…

Batch脚本中的用户交互:CHOICE命令的妙用

Batch脚本中的用户交互:CHOICE命令的妙用 在自动化脚本的世界中,Batch文件以其简洁和高效而著称,但有时我们也需要与用户进行交互以获取输入或提供选择。这就是CHOICE命令大放异彩的地方。本文将深入探讨如何在Batch文件中使用CHOICE命令&am…

从实验室走向商业化,人形机器人时代要来了?

从国内市场看,据机构报告显示,预计到2026年中国人形机器人产业规模将突破200亿元。特别是在生成式AI技术大爆发的当下,未来人形机器人更是极有可能实现超预期增长。 近日,特斯拉CEO埃隆马斯克(Elon Musk)在特斯拉2024年股东大会上…

【Unity设计模式】✨使用 MVC 和 MVP 编程模式

前言 最近在学习Unity游戏设计模式,看到两本比较适合入门的书,一本是unity官方的 《Level up your programming with game programming patterns》 ,另一本是 《游戏编程模式》 这两本书介绍了大部分会使用到的设计模式,因此很值得学习 本…

System.currentTimeMillis() JAVA 转C#

JAVA中的System.currentTimeMillis() ,指获取当前时间与1970年1月1日00:00:00 GMT之间所差的毫秒数的方法。 这个方法返回的是一个long类型的值,表示从某个固定时间点(通常是UNIX纪元,即1970年1月1日00:00:00 GMT)到…

六西格玛培训引领久立特材品质新高度,行业领军再升级

久立特材六西格玛管理项目于6月27 日启动。久立特材作为行业内的领军企业,此次引入六西格玛管理法,旨在进一步提升企业运营效率和产品质量,实现持续改进和卓越运营。 久立特材的高层领导与张驰咨询的资深顾问朱老师共同出席项目启动仪式&am…

心理学|人格心理学——人格心理学单科作业(中科院)

一、单选题(第1-40小题,每题1.5分,共计60分。) 1、没有两个人能对同一事物做出相同的反应,反映的是人格的( ) 分值1.5分 A、稳定性 B、独特性 C、统合性 D、功能性 正确答案: B、独特性 2、人格决定一个人的生活方式,甚至有时会决定一个人的命运,反映的…

【python刷题】蛇形方阵

题目描述 给出一个不大于 99 的正整数n,输出n*n的蛇形方阵。从左上角填上1开始,顺时针方向依次填入数字,如同样例所示。注意每个数字有都会占用3个字符,前面使用空格补齐。 输入 输入一个正整数n,含义如题所述 输出 输出符合…

python实现windows 10 定时自动pppoe拨号-魔行观察

此脚本适用于windows 10系统的vps服务器 import subprocess import time# PPPoe 用户名和密码 USERNAME 81239078262 PASSWORD 345543 # 拨号连接名称 CONNECTION_NAME pppoedef dial_pppoe():subprocess.run([rasdial, CONNECTION_NAME, USERNAME, PASSWORD])# 如果需要定…

【正点原子K210连载】 第十二章 跑马灯实验 摘自【正点原子】DNK210使用指南-CanMV版指南

1)实验平台:正点原子ATK-DNK210开发板 2)平台购买地址https://detail.tmall.com/item.htm?id731866264428 3)全套实验源码手册视频下载地址: http://www.openedv.com/docs/boards/xiaoxitongban 第十二章 跑马灯实验…

使用 PHP 和 Selenium WebDriver 实现爬虫

随着互联网的蓬勃发展,我们可以轻松地获取海量的数据。而爬虫则是其中一种常见的数据获取方式,特别是在需要大量数据的数据分析和研究领域中,爬虫的应用越来越广泛。本文将介绍如何使用 php 和 selenium webdriver 实现爬虫。 一、什么是 Se…

Cmake使用笔记1

cmake 问题1: Selecting Windows SDK version 10.0.19041.0 to target Windows 10.0.19043. 问题分析 在Windows平台上,使用CMake或Visual Studio等开发工具时,选择正确的Windows SDK版本以确保你的应用程序能够针对目标Windows版本进行编…

iptables(12)实际应用举例:策略路由、iptables转发、TPROXY

简介 前面的文章中我们已经介绍过iptables的基本原理,表、链的基本操作,匹配条件、扩展模块、自定义链以及网络防火墙、NAT等基本配置及原理。 这篇文章将以实际应用出发,列举一个iptables的综合配置使用案例,将我们前面所涉及到的功能集合起来,形成一个完整的配置范例。…

SpringMVC的架构有什么优势?——控制器(一)

文章目录 控制器(Controller)1. 控制器(Controller):2. 请求映射(Request Mapping):3. 参数绑定(Request Parameters Binding):4. 视图解析器(View Resolver):5. 数据绑定(Data Binding):6. 表单验证(Form Validation)…

TAPD项目管理软件无法与企业微信进行关联

TAPD一段时间未使用后,需要重新启动,此时会出现你的企业微信尚未与TAPD账号关联的提示 解决方案:找到TAPD应用,先删除应用,然后再解除禁用即可