深度学习--Matlab使用LSTM长短期记忆网络对负荷进行分类

一、概述

关于LSTM同系列的前一篇文章写的是利用LSTM网络对电力负荷进行预测【LSTM预测】,其本质是sequence-to-sequence problems,序列到序列的预测应用。这里做一下sequence-to-label classification problems,序列到标签的分类应用【LSTM分类】。关于LSTM的网络特性不再赘述。

本篇博文的具体示例是对给定的电力负荷进行分类,电力负荷数据格式为每日96个数据点的一维时间序列值,每条负荷数据均对应一个类型标签,总共类别为6类。其他的例子可以参考官网给定的japaneseVowelsTrainData 案例。

负荷数据是某电力公司内部数据,鉴于保密要求,这里仅描述数据格式,负荷数据集不提供。

  • 类别:6
  • 数据长度:96
  • 训练数据条数:9821
  • 测试数据条数:2456

二、数据格式转换

首先看一下需要传到LSTM网络的训练参数格式。

trainedNet = trainNetwork(C, Y, layers, options);

它必须从序列输入层开始,C是一个包含序列或时间序列预测器的元胞数组。C是d行1列,d代表有多少个训练样本,每个训练样本又包括N行M列,N代表训练样本的数据维度,M代表序列长度,y是标签的分类向量,是categorical类型。

因此,训练数据应该转换成元胞数组,训练数据标签应该转换成categorical类型。

2.1 训练数据格式转换

代码如下所示,用XTrain和YTrain来代替上述训练网络中的C和Y。

dataStandardlized是原始数据标准化后的数据,dataStandardlizedLable是每条数据对应的类别标签,num型。获得XTrain需要通过XTrainData转换成元胞数组,XTrain每一行是一条负荷训练样本数据,即1*96的数据。

YTrain是categorical类别数组,可以通过categorial函数转换,但是输入参数时字符元胞数组,因此现将XTrainLabel转换成字符矩阵,然后再将矩阵转换成元胞数组,最后转换成categorical类型。

%提取训练样本数据
XTrainSize = 9821;
XTrainData = dataStandardlized(1:XTrainSize,:);
XTrainLabel = dataStandardlizedLable(1:XTrainSize,:);%XTrain
for i = 1:size(XTrainData,1)XTrain{i,1} = XTrainData(i,:);
end%YTrain
TrainstrLable = num2str(XTrainLabel);% num to str
for i = 1:size(XTrainData,1)% str matrix to cellTraincellLable{i,1} = TrainstrLable(i,1);
end
YTrain = categorical(TraincellLable);%cell to categorical

2.2 测试数据格式转换

测试数据格式转换方法与训练数据格式转换相同,见代码。

%提取测试样本
XTestData = dataStandardlized(1+XTrainSize:end,:);
XTestLabel = dataStandardlizedLable(1+XTrainSize:end,:);
%XTest
for i = 1:size(XTestData,1)XTest{i,1} = XTestData(i,:);
end
%YTest
TeststrLable = num2str(TestLabel);% num to str
for i = 1:size(XTestData,1)TestcellLable{i,1} = TeststrLable(i,1);% str matrix to cell
end
YTest = categorical(TestcellLable);%cell to categorical

三、网络参数设置

前面讲到了TrainNetwork的C和Y,这里描述一下网络参数 layers和options的具体配置。

3.1 layers

layers用于定义训练网络的架构,按照网络架构的先后,依次填写到layers的每一行。

首先定义LSTM网络架构:

  • 将输入大小指定为序列大小 1(输入数据的维度,指同一时间下的数据维度)
  • 指定具有 100 个隐含单元的双向 LSTM 层,并输出序列的最后一个元素。
  • 指定六个类,包含大小为 1 的全连接层,后跟 softmax 层和分类层。
inputSize = 1;
numHiddenUnits = 100;
numClasses = 6;layers = [ ...sequenceInputLayer(inputSize)bilstmLayer(numHiddenUnits,'OutputMode','last')fullyConnectedLayer(numClasses)softmaxLayerclassificationLayer]

具体地:  

  1. sequenceInputLayer(inputSize):序列输入层,指定输入维度
  2. bilstmLayer(numHiddenUnits,'OutputMode','last'):双向LSTM层,指定隐藏节点,输出模式为‘last’即输出最后一个分类值
  3. fullyConnectedLayer(numClasses):全连接层,指定输出类别的个数
  4. softmaxLayer:这层是输出各类别分类的的概率
  5. classificationLayer:分类层,输出最后的分类结果,类似于概率竞争投票。

3.2 options

options用于指定训练网络的优化选项,通过调用trainingOptions进行设置。

此处指定训练选项:

  • 求解器为 'adam'
  • 梯度阈值为 1,最大轮数为 100。
  • 100作为小批量数。
  • 填充数据以使长度与最长序列相同,序列长度指定为 'longest'。
  • 数据保持按序列长度排序的状态,不打乱数据。
  • 'ExecutionEnvironment' 指定为 'cpu',设定为'auto'表明使用GPU。
maxEpochs = 100;
miniBatchSize = 100;options = trainingOptions('adam', ...'ExecutionEnvironment','cpu', ...'GradientThreshold',1, ...'MaxEpochs',maxEpochs, ...'MiniBatchSize',miniBatchSize, ...'SequenceLength','longest', ...'Shuffle','never', ...'Verbose',0, ...'Plots','training-progress');

四、训练LSTM网络

将前面准备好的参数送入训练网络,等待训练结束。

net = trainNetwork(XTrain,YTrain,layers,options);

我这里的训练时间非常长,当然训练过程与隐藏节点数,训练数据的维度和训练次数以及电脑配置有关系,我这里单CPU训练耗时112分钟。

五、利用LSTM网络进行分类

利用标准结果和分类结果计算分类的正确率。

使用classify函数进行分类,同训练过程一样,仍然要指定小批量大小为100,指定组内数据按照最长的数据填充。

miniBatchSize = 100;
YPred = classify(net,XTest, ...'SequenceLength','longest','MiniBatchSize',miniBatchSize);
%计算分类准确度
acc = sum(YPred == YTest)./numel(YTest)

可以看到分类精度达到92%,还是很不错了。

------分享知识,让人愉悦,原创博文,支持请点赞。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/405383.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VC跨进程数据(结构体)传递-WM_COPYDATA

两个测试程序,都是MFC基于对话框的应用程序,一个是发送者,一个是接收者。 两个程序都使用同一个结构体: typedef struct {char imsi[20];char options[512]; }_tagResult;发送者:按钮点击事件: void CCa…

[react] 什么渲染劫持?

[react] 什么渲染劫持? 首先,什么是渲染劫持:渲染劫持的概念是控制组件从另一个组件输出的能力,当然这个概念一般和react中的高阶组件(HOC)放在一起解释比较有明了。 高阶组件可以在render函数中做非常多…

咨询的真相8:咨询业的“前世今生”

第四节 咨询业的奶酪究竟有多大 2001年,又是从美国,传来了一本风靡全球的畅销书,讲的是四个老鼠和一块奶酪的故事。从此,奶酪就成了众人争抢之物的代名词。美国人的创新精神不得不佩服,忒俗的一个道理经他们仅从形式上…

matlab LSTM序列分类的官方示例

matlab版本是2018b及其以上。 %% %加载序列数据 %数据描述:总共270组训练样本共分为9类,每组训练样本的训练样个数不等,每个训练训练样本由12个特征向量组成, [XTrain,YTrain] japaneseVowelsTrainData; %数据可视化 figure plo…

【MySQL】Java对SQL时间类型的操作(获得当前、昨天、前年。。时间)

Java获得当前时间1 java.util.Date date new java.util.Date(); 2 Timestamp time new Timestamp(date.getTime()); Java获得昨天的时间1 Calendar cal Calendar.getInstance(); 2 cal.add(Calendar.DATE, -1); 3 String a new SimpleDateFormat( "yyyy-MM-dd ")…

[react] 说说你对React的渲染原理的理解

[react] 说说你对React的渲染原理的理解 1.单向数据流。React是一个MVVM框架,简单来说是在MVC的模式下在前端部分拆分出数据层和视图层。单向数据流指的是只能由数据层的变化去影响视图层的变化,而不能反过来(除非双向绑定) 2.数…

MATLAB K-means聚类代码讲解

一、概述 K-means聚类采用类内距离和最小的方式对数据分类,MATLAB中自带K-means算法,最简单的调用如下: idxkmeans(x,k) 将n-by-p数据矩阵x中的数据划分为k个类簇。x的行对应数据条数,x的列对应数据的维度。注意:当…

树型控件使用

树型控件的使用: 1、设置树型控件的风格。 DWORD dwStyle GetWindowLong(m_TreeCtrl.m_hWnd, GWL_STYLE); dwStyle | TVS_HASBUTTONS | TVS_HASLINES | TVS_LINESATROOT; SetWindowLong(m_TreeCtrl.m_hWnd, GWL_STYLE, dwStyle); TVS_HASBUTTONS、TVS_HASLINES等风格有很多…

带界面的OCX制作实例

制作一个有界面的OCX,并进行测试。代码下载 一、制作一个有界面的OCX: 设置该对话框的属性(关键噢): 给添加的对话框资源关联一个类CDlgTest,基类是:CDialog,如下: 给CO…

[react] componentWillUpdate可以直接修改state的值吗

[react] componentWillUpdate可以直接修改state的值吗 1: 不行,这样会导致无限循环报错。 2:在react中直接修改state,render函数不会重新执行渲染,应使用setState方法进行修改 个人简介 我是歌谣,欢迎和大家一起交流…

mysql 分组后取每个组内最新的一条数据

首先,将按条件查询并排序的结果查询出来。 1 mysql> select accepttime,user,job from tuser_job where user 8 order by accepttime desc;2 --------------------------------3 | accepttime | user | job |4 --------------------------------5 | 20…

Qt C++ 命名空间namespaces讲解

一、概述 命名空间 namespace 将一组去哪聚范围内有效的类、对象或者函数组织到一个命名的名字下边,将全局范围分割成多个子域,每个子域就叫做命名空间。作用是在大工程中避免多个类和文件出现相同的成员名称。 命名空间使用的格式为: nam…

从淘宝数据结构来看电子商务中商品属性设计

淘宝名词解释 产品 和 商品的区别: 淘宝标准化产品,由类目关键属性唯一确定。如:手机类目,关键属性是品牌和型号,Nokia N95就是一个产品,nokia是品牌,N95是型号。产品除了关键属性还包括一般信息、销售属性和非关键属性…

linux串口驱动分析

linux串口驱动分析硬件资源及描述 s3c2440A 通用异步接收器和发送器(UART)提供了三个独立的异步串行 I/O(SIO)端口,每个端口都可以在中断模式或 DMA 模式下操作。UART 使用系统时钟可以支持最高 115.2Kbps 的波特率。每…

用C++实现网络编程---抓取网络数据包的实现方法

From: http://blog.csdn.net/zjl_1026_2001/article/details/2191311 做过网管或协议分析的人一般都熟悉sniffer这个工具,它可以捕捉流经本地网卡的所有数据包。抓取网络数据包进行分析有很多用处,如分析网络是否有网络病毒等异常数据,通信协…

二分图----最大匹配,最小点覆盖,最大点独立集

一.二分图 二分图又称作二部图,是图论中的一种特殊模型。 设G(V,E)是一个无向图,如果顶点V可分割为两个互不相交的子集(A,B),并且图中的每条边(i,j)所关联的两个顶点i和j分别属于这两个不同的顶点集(i in A…

[react] 怎么使用Context开发组件?

[react] 怎么使用Context开发组件? import React, {Component} from react// 首先创建一个 context 对象这里命名为:ThemeContext const ThemeContext React.createContext(light)// 创建一个祖先组件组件 内部使用Provier 这个对象创建一个组件 其中…

Linux 进程通信 -- 信号

一、概述 信号用于保持进程间的通信,可以备发送到一个进程或者一组进程,发送给进程的这个唯一信息通常是标志信号的一个数。信号可从键盘终端产生、虚拟内存中非法访问系统资源等情况下产生。信号异步发生,收到信号的进程可以采取某种动作或…

简单理解Socket

TCP/IP 要想理解socket首先得熟悉一下TCP/IP协议族, TCP/IP(Transmission Control Protocol/Internet Protocol)即传输控制协议/网间协议,定义了主机如何连入因特网及数据如何再它们之间传输的标准&…

女人必知:10个好习惯 让老公不想出轨

阅读提示:要知道,妻子这10个动作是征求了数百名老公的意见之后进行总结得出的,不仅效果显著,杀伤力强,最关键的是简单易行。女人必知:10个好习惯 让老公不想出轨  1.老公累了,靠在沙发上睡了&…