matlab LSTM序列分类的官方示例

matlab版本是2018b及其以上。

%%
%加载序列数据
%数据描述:总共270组训练样本共分为9类,每组训练样本的训练样个数不等,每个训练训练样本由12个特征向量组成,
[XTrain,YTrain] = japaneseVowelsTrainData;
%数据可视化
figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
legend("Feature " + string(1:12),'Location','northeastoutside')
%%
%LSTM可以将分组后等量的训练样本进行训练,从而提高训练效率
%如果每组的样本数量不同,进行小批量拆分,则需要尽量保证分块的训练样本数相同
%首先找到每组样本数和总的组数
numObservations = numel(XTrain);
for i=1:numObservationssequence = XTrain{i};sequenceLengths(i) = size(sequence,2);
end
%绘图前后排序的各组数据个数
figure
subplot(1,2,1)
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")
%按序列长度对测试数据进行排序
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);
subplot(1,2,2)
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")%%
%设置LSTM训练数据的小批量分组个数
miniBatchSize = 27;%%
%定义LSTM网络架构:
%将输入大小指定为序列大小 12(输入数据的维度)
%指定具有 100 个隐含单元的双向 LSTM 层,并输出序列的最后一个元素。
%指定九个类,包含大小为 9 的全连接层,后跟 softmax 层和分类层。
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;layers = [ ...sequenceInputLayer(inputSize)bilstmLayer(numHiddenUnits,'OutputMode','last')fullyConnectedLayer(numClasses)softmaxLayerclassificationLayer]%%
%指定训练选项:
%求解器为 'adam'
%梯度阈值为 1,最大轮数为 100。
% 27 作为小批量数。
%填充数据以使长度与最长序列相同,序列长度指定为 'longest'。
%数据保持按序列长度排序的状态,不打乱数据。
% 'ExecutionEnvironment' 指定为 'cpu',设定为'auto'表明使用GPU。maxEpochs = 100;
miniBatchSize = 27;options = trainingOptions('adam', ...'ExecutionEnvironment','cpu', ...'GradientThreshold',1, ...'MaxEpochs',maxEpochs, ...'MiniBatchSize',miniBatchSize, ...'SequenceLength','longest', ...'Shuffle','never', ...'Verbose',0, ...'Plots','training-progress');%%
%训练LSTM网络
net = trainNetwork(XTrain,YTrain,layers,options);%%
%测试LSTM网络
%加载测试集
[XTest,YTest] = japaneseVowelsTestData;%由于LSTM已经按照相似长度的小批量分组27,测试需要按照相同方式对数据进行排序处理。
numObservationsTest = numel(XTest);
for i=1:numObservationsTestsequence = XTest{i};sequenceLengthsTest(i) = size(sequence,2);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);%使用classify进行分类,指定小批量大小27,指定组内数据按照最长的数据填充
miniBatchSize = 27;
YPred = classify(net,XTest, ...'MiniBatchSize',miniBatchSize, ...'SequenceLength','longest');
%计算分类准确度
acc = sum(YPred == YTest)./numel(YTest)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/405379.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL】Java对SQL时间类型的操作(获得当前、昨天、前年。。时间)

Java获得当前时间1 java.util.Date date new java.util.Date(); 2 Timestamp time new Timestamp(date.getTime()); Java获得昨天的时间1 Calendar cal Calendar.getInstance(); 2 cal.add(Calendar.DATE, -1); 3 String a new SimpleDateFormat( "yyyy-MM-dd ")…

[react] 说说你对React的渲染原理的理解

[react] 说说你对React的渲染原理的理解 1.单向数据流。React是一个MVVM框架,简单来说是在MVC的模式下在前端部分拆分出数据层和视图层。单向数据流指的是只能由数据层的变化去影响视图层的变化,而不能反过来(除非双向绑定) 2.数…

MATLAB K-means聚类代码讲解

一、概述 K-means聚类采用类内距离和最小的方式对数据分类,MATLAB中自带K-means算法,最简单的调用如下: idxkmeans(x,k) 将n-by-p数据矩阵x中的数据划分为k个类簇。x的行对应数据条数,x的列对应数据的维度。注意:当…

树型控件使用

树型控件的使用: 1、设置树型控件的风格。 DWORD dwStyle GetWindowLong(m_TreeCtrl.m_hWnd, GWL_STYLE); dwStyle | TVS_HASBUTTONS | TVS_HASLINES | TVS_LINESATROOT; SetWindowLong(m_TreeCtrl.m_hWnd, GWL_STYLE, dwStyle); TVS_HASBUTTONS、TVS_HASLINES等风格有很多…

带界面的OCX制作实例

制作一个有界面的OCX,并进行测试。代码下载 一、制作一个有界面的OCX: 设置该对话框的属性(关键噢): 给添加的对话框资源关联一个类CDlgTest,基类是:CDialog,如下: 给CO…

[react] componentWillUpdate可以直接修改state的值吗

[react] componentWillUpdate可以直接修改state的值吗 1: 不行,这样会导致无限循环报错。 2:在react中直接修改state,render函数不会重新执行渲染,应使用setState方法进行修改 个人简介 我是歌谣,欢迎和大家一起交流…

mysql 分组后取每个组内最新的一条数据

首先,将按条件查询并排序的结果查询出来。 1 mysql> select accepttime,user,job from tuser_job where user 8 order by accepttime desc;2 --------------------------------3 | accepttime | user | job |4 --------------------------------5 | 20…

Qt C++ 命名空间namespaces讲解

一、概述 命名空间 namespace 将一组去哪聚范围内有效的类、对象或者函数组织到一个命名的名字下边,将全局范围分割成多个子域,每个子域就叫做命名空间。作用是在大工程中避免多个类和文件出现相同的成员名称。 命名空间使用的格式为: nam…

从淘宝数据结构来看电子商务中商品属性设计

淘宝名词解释 产品 和 商品的区别: 淘宝标准化产品,由类目关键属性唯一确定。如:手机类目,关键属性是品牌和型号,Nokia N95就是一个产品,nokia是品牌,N95是型号。产品除了关键属性还包括一般信息、销售属性和非关键属性…

linux串口驱动分析

linux串口驱动分析硬件资源及描述 s3c2440A 通用异步接收器和发送器(UART)提供了三个独立的异步串行 I/O(SIO)端口,每个端口都可以在中断模式或 DMA 模式下操作。UART 使用系统时钟可以支持最高 115.2Kbps 的波特率。每…

用C++实现网络编程---抓取网络数据包的实现方法

From: http://blog.csdn.net/zjl_1026_2001/article/details/2191311 做过网管或协议分析的人一般都熟悉sniffer这个工具,它可以捕捉流经本地网卡的所有数据包。抓取网络数据包进行分析有很多用处,如分析网络是否有网络病毒等异常数据,通信协…

二分图----最大匹配,最小点覆盖,最大点独立集

一.二分图 二分图又称作二部图,是图论中的一种特殊模型。 设G(V,E)是一个无向图,如果顶点V可分割为两个互不相交的子集(A,B),并且图中的每条边(i,j)所关联的两个顶点i和j分别属于这两个不同的顶点集(i in A…

[react] 怎么使用Context开发组件?

[react] 怎么使用Context开发组件? import React, {Component} from react// 首先创建一个 context 对象这里命名为:ThemeContext const ThemeContext React.createContext(light)// 创建一个祖先组件组件 内部使用Provier 这个对象创建一个组件 其中…

Linux 进程通信 -- 信号

一、概述 信号用于保持进程间的通信,可以备发送到一个进程或者一组进程,发送给进程的这个唯一信息通常是标志信号的一个数。信号可从键盘终端产生、虚拟内存中非法访问系统资源等情况下产生。信号异步发生,收到信号的进程可以采取某种动作或…

简单理解Socket

TCP/IP 要想理解socket首先得熟悉一下TCP/IP协议族, TCP/IP(Transmission Control Protocol/Internet Protocol)即传输控制协议/网间协议,定义了主机如何连入因特网及数据如何再它们之间传输的标准&…

女人必知:10个好习惯 让老公不想出轨

阅读提示:要知道,妻子这10个动作是征求了数百名老公的意见之后进行总结得出的,不仅效果显著,杀伤力强,最关键的是简单易行。女人必知:10个好习惯 让老公不想出轨  1.老公累了,靠在沙发上睡了&…

codeforce Gym 100500F Door Lock (二分)

根据题意略推一下&#xff0c;其实就是问你满足(a*(a1))/2 < m < ((a1)*a(a2))/2的a和m-(a*(a1))/2 -1是多少。 二分求解就行了 #include<cstdio>using namespace std; typedef long long ll;int main() {int T;scanf("%d",&T);for(int k 1; k <…

write() vs. writev()

From: http://www.cppblog.com/whoami17/archive/2009/05/10/82452.html 今天突然想比较一下 write() 和 writev() 的性能&#xff0c; 网上google了半天&#xff0c; 竟然没有发现一点有关的数据信息&#xff0c; 自己就测试了一下。 平台如下&#xff1a; CentOS 5.2 Lin…

[react] React Intl是什么原理?

[react] React Intl是什么原理&#xff1f; 实现原理和react-redux的实现原理类似&#xff0c;最外层包一个Provider&#xff0c;利用getChildContext&#xff0c;将intlConfigPropTypes存起来&#xff0c;在FormattedMessage、FormattedNumber等组件或者调用injectIntl生成的…

linux下GPRS模块ppp拨号上网

&#xfeff;&#xfeff;交叉编译器&#xff1a;arm-linux-gcc-4.5.4 Linux内核版本&#xff1a;Linux-3.0 主机操作系统&#xff1a;Centos 6.5 开发板&#xff1a;FL2440 GPRS:SIM900A 在开发SIM900模块之前&#xff0c;开发板已经加载了linux内核以及文件系统&#xf…