MATLAB环境下基于PSO-DBN-ELM方法的图像分类

在纯数据驱动的图像识别方法中,深度信念网络DBN识别模型具备较好的识别性能。对于DBN模型而言,可利用的数据越多,挖掘的信息也越多,建立的模型就越准确。然而DBN本身仍存在一定的不足之处,一方面由于DBN内部包含多层限制玻尔兹曼机RBM,在训练过程中需要针对每一层RBM进行网络参数的选取,又因为其网络训练的黑箱性质,会导致整个调参过程相当繁琐,甚至出现无法收敛的结果,且目前没有DBN网络架构设计相关的标准规则。另一方面DBN本身引入传统的BP算法作为回归层,也继承了其易局部最小化、训练速度慢等问题。

目前对于DBN的改进方法一般分为两种,一是通过预处理(特征融合等)提高输入数据质量,二是通过优化策略(Dropout等)改进模型参数或者选取不同的分类/回归器替换BP层。极限学习机ELM拥有学习速度快、泛化能力强等特点,可以很大程度上改善DBN本身收敛速度过慢、容易陷入局部最优、非线性性能不足等问题,从而提高模型的准确率。

粒子群优化算法PSO是一种群体智能的模型优化算法,主要通过群体信息共享来获取最优解。PSO受启发于鸟类寻找食物时的群体协作能力,它将目标函数类比为鸟群中的鸟,即粒子。函数的在固定范围内寻求最优解类比为鸟在相应范围内的位置上寻找食物。粒子的移动方向依靠每个位置的适应度来决定,适应度越高代表越靠近最优解。每个粒子都拥有独立的位置和移动速度参数,在粒子群体更新过程中,粒子会向适应度最高的方向移动,也有可能向随机方向移动,但不会超出设定范围。

鉴于PSO、DBN和ELM的优势,提出一种基于PSO-DBN-ELM方法的图像分类模型,通过3层DBN提取嵌入在原始图像数据中的特征,然后将特征输入极限学习机(ELM)进行分类。针对隐节点选择困难的问题,采用粒子群算法对隐节点进行自动选择,并以最小的适应度函数(即DBN-ELM交叉验证分类的准确率)为目标。

部分代码如下:

clear all
close all
format compact 
format long
%% 1.数据加载
fprintf(1,'加载数据 \n');
load('drivFace600');%其中1-173为1类,174-343为2类 344-510为3类 511-600为4类,各选择20%作为测试集
%第一类173组
[i1 i2]=sort(rand(173,1)); 
train(1:139,:)=input(i2(1:139),:);     train_label(1:139,1)=output(i2(1:139),1);
test(1:34,:)=input(i2(140:173),:);     test_label(1:34,1)=output(i2(140:173),1);
%第二类有170组
[i1 i2]=sort(rand(170,1));
train(140:275,:)=input(173+i2(1:136),:);    train_label(140:275,1)=output(173+i2(1:136),1);
test(35:68,:)=input(173+i2(137:170),:);     test_label(35:68,1)=output(173+i2(137:170),1);
%第三类有167
[i1 i2]=sort(rand(167,1));
train(276:408,:)=input(343+i2(1:133),:);    train_label(276:408,1)=output(343+i2(1:133),1);
test(69:102,:)=input(343+i2(134:167),:);     test_label(69:102,1)=output(343+i2(134:167),1);
%第4类有90
[i1 i2]=sort(rand(90,1));
train(409:480,:)=input(510+i2(1:72),:);    train_label(409:480,1)=output(510+i2(1:72),1);
test(103:120,:)=input(510+i2(73:90),:);     test_label(103:120,1)=output(510+i2(73:90),1); 
clear i1 i2 input output
%%打乱顺序
k=rand(480,1);[m n]=sort(k);
train=train(n(1:480),:);train_label=train_label(n(1:480),:);
k=rand(120,1);[m n]=sort(k);
test=test(n(1:120),:);test_label=test_label(n(1:120),:);
clear k m n
%no_dims = round(intrinsic_dim(train, 'MLE')); %round四舍五入
%disp(['MLE estimate of intrinsic dimensionality: ' num2str(no_dims)]);
numbatches=10;%数据分块数
numcases=48;%每块数据集的样本个数(不能太小)块数不能超过样本数
numdims=size(train,2);%单个样本的维数
% 训练数据
x=train;%将数据转换成DBN的数据格式
for i=1:numbatchestrain1=x((i-1)*numcases+1:i*numcases,:);batchdata(:,:,i)=train1;
end%将分好的10组数据都放在batchdata中
%% rbm参数
maxepoch=20;%训练rbm的次数
hid=4; %隐含层数
hmax=500;hmin=100; %各隐含层节点数取值区间
tic;
%%
h=PSO_dbnelm_cross(hid,hmax,hmin,batchdata,train,train_label); %PSO优化隐含层节点数
%%
t1=toc
tic;
numpen0=h(1,1); numpen1=h(1,2); numpen2=h(1,3);numpen3=h(1,4); %dbn最终隐含层的节点数
disp('构建一个num2str(H)层的置信网络');
clear i 
%% 训练第1层RBM
fprintf(1,'Pretraining Layer 1 with RBM: %d-%d \n',numdims,numpen0);%6400-500
numhid=numpen0;
restart=1;
rbm1;%使用cd-k训练rbm,注意此rbm的可视层不是二值的,而隐含层是二值的
vishid1=vishid;hidrecbiases=hidbiases;
%% 训练第2层RBM
fprintf(1,'\nPretraining Layer 2 with RBM: %d-%d \n',numpen0,numpen1);%500-200
batchdata=batchposhidprobs;%将第一个RBM的隐含层的输出作为第二个RBM 的输入
numhid=numpen1;%将numpen的值赋给numhid,作为第二个rbm隐含层的节点数
restart=1;
rbm1;
hidpen=vishid; penrecbiases=hidbiases; hidgenbiases=visbiases;
%% 训练第3层RBM
fprintf(1,'\nPretraining Layer 3 with RBM: %d-%d \n',numpen1,numpen2);%200-100
batchdata=batchposhidprobs;%显然,将第二哥RBM的输出作为第三个RBM的输入
numhid=numpen2;%第三个隐含层的节点数
restart=1;
rbm1;
hidpen2=vishid; penrecbiases2=hidbiases; hidgenbiases2=visbiases;
%% 训练第4层RBM
fprintf(1,'\nPretraining Layer 4 with RBM: %d-%d \n',numpen2,numpen3);%200-100
batchdata=batchposhidprobs;%显然,将第二哥RBM的输出作为第三个RBM的输入
numhid=numpen3;%第三个隐含层的节点数
restart=1;
rbm1;
hidpen3=vishid; penrecbiases3=hidbiases; hidgenbiases3=visbiases;%% 训练极限学习机% 训练集特征输出
w1=[vishid1; hidrecbiases]; 
w2=[hidpen; penrecbiases]; 
w3=[hidpen2; penrecbiases2];
w4=[hidpen3; penrecbiases3];
digitdata = [x ones(size(x,1),1)];%x表示train数据集
w1probs = 1./(1 + exp(-digitdata*w1));%w1probs = [w1probs  ones(size(x,1),1)];%
w2probs = 1./(1 + exp(-w1probs*w2));%w2probs = [w2probs ones(size(x,1),1)];%
w3probs = 1./(1 + exp(-w2probs*w3)); %w3probs = [w3probs ones(size(x,1),1)];%
w4probs = 1./(1 + exp(-w3probs*w4)); % 
H = w4probs';  %%第三个rbm的实际输出值,也是elm的输入值H
lamda=0.001;  %% 正则化系数在0.0007-0.00037之间时  测试精度最大81.667%
H=H+1/lamda;  %加入regularization factor
T =train_label';            %训练集标签
T1=ind2vec(T);              %做分类需要先将T转换成向量索引
OutputWeight=pinv(H') *T1'; 
Y=(H' * OutputWeight)';
temp_Y=zeros(1,size(Y,2));
for n=1:size(Y,2)[max_Y,index]=max(Y(:,n));temp_Y(n)=index;
end
Y_train=temp_Y;
%Y_train=vec2ind(temp_Y1);
% 训练集准确率
train_accuracy=sum(Y_train==T)/length(T)
% 训练集实际分类与预测分类对比
figure(1)
plot(Y_train,'bo');hold on 
plot(T,'r*');%% 测试极限学习机
N2 = size(test,1);
% 测试集特征输出
w1=[vishid1; hidrecbiases]; %(784+1*500)
w2=[hidpen; penrecbiases]; %(500+1*500)
w3=[hidpen2; penrecbiases2];%(500+1*2000)
w4=[hidpen3; penrecbiases3];
test1 = [test ones(N2,1)];
w1probs = 1./(1 + exp(-test1*w1));w1probs = [w1probs  ones(N2,1)];
w2probs = 1./(1 + exp(-w1probs*w2)); w2probs = [w2probs ones(N2,1)];
w3probs = 1./(1 + exp(-w2probs*w3)); w3probs = [w3probs ones(N2,1)];
w4probs = 1./(1 + exp(-w3probs*w4));  
H1=w4probs';
%加入正则化系数
H1=H1+1/lamda;
TY=(H1' * OutputWeight)';     %   TY: the actual output of the testing data
temp_Y=zeros(1,size(TY,2));
for n=1:size(TY,2)[max_Y,index]=max(TY(:,n));temp_Y(n)=index;
end
TY1=temp_Y;
% 加载输出
TV=test_label';
% 测试集分类准确率
test_accuracy = sum(TV==TY1) / length(TV)
% 测试集实际分类与预测分类对比
figure(2)
plot(TV,'r*');
hold on
plot(TY1,'bo');
xlabel('测试集样本数')
ylabel('标签种类')
title('测试阶段:实际输出与理想输出的差');
legend('真实值','预测值')
% 程序运行时间
t2=toc

工学博士,担任《Mechanical System and Signal Processing》审稿专家,担任
《中国电机工程学报》优秀审稿专家,《控制与决策》,《系统工程与电子技术》,《电力系统保护与控制》,《宇航学报》等EI期刊审稿专家。

擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/703933.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

servlet---->request.getHeader(“X-Requested-With“);有什么作用?

X-Requested-With 是一个自定义的HTTP请求头,主要用于在服务器端识别请求是由Ajax技术发起的还是由其他技术发起的。这个请求头是由浏览器或客户端应用程序添加的,因此其值可能取决于发送请求的客户端或者开发者的选择。 如果请求不是通过JavaScript的 …

微信小程序支付(前后端都包含)

Java中换取微信支付唯一订单号(用于换取支付窗口) /*** 微信小程序支付*/PostMapping(value "/xcxPay")ResponseBodypublic Map<String,Object> miniAppPay(RequestBody byte[] req) {HashMap<String, Object> objectObjectMap new HashMap<>();…

【软件测试】--功能测试1

一、测试介绍 什么是软件&#xff1f; 控制计算机硬件工作的工具。 什么是软件测试&#xff1f; 使用技术手段验证软件是否满足需求 软件测试的目的&#xff1f; 减少软件缺陷&#xff0c;保证软件质量。 测试主流技能 1、功能测试 2、自动化测试 3、接口测试 4、性能测试 ​…

MySQL-事务,properties文件解析,连接池

1.事务机制管理 1.1 Transaction事务机制管理 默认情况下是执行一条sql语句就保存一次&#xff0c;那么比如我们需要三条数据同时成功或同时失败就需要开启事务机制了。开启事务机制后执行过程中发生问题就会回滚到操作之前&#xff0c;相当于没有执行操作。 1.2 事务的特征 事…

【初始RabbitMQ】延迟队列的实现

延迟队列概念 延迟队列中的元素是希望在指定时间到了之后或之前取出和处理消息&#xff0c;并且队列内部是有序的。简单来说&#xff0c;延时队列就是用来存放需要在指定时间被处理的元素的队列 延迟队列使用场景 延迟队列经常使用的场景有以下几点&#xff1a; 订单在十分…

Anaconda下安装torch-geometric

主要流程参考&#xff1a;https://blog.csdn.net/weixin_45671036/article/details/130617637 https://blog.csdn.net/weixin_43756314/article/details/130225038?ops_request_misc&request_id&biz_id102&utm_term%E5%80%9F%E5%8A%A9anaconda%20%E5%AE%89%E8%A3%…

配置vscode,使其可以运行C++11特性的代码(如vector)

配置vscode&#xff0c;使其可以运行C11特性的代码 封面引用自配置教程的B站视频&#xff0c;非常详细的视频&#xff0c;感谢视频作者的贡献。 文章目录 配置vscode&#xff0c;使其可以运行C11特性的代码Step 1: 基础配置Step 2: 调整Code Runner的配置Step 3: 更改tasks.jso…

【Spring连载】使用Spring Data的Repositories----定义Repository接口

【Spring连载】使用Spring Data的Repositories----定义Repository接口 一、微调Repository定义二、使用多个Spring Data模块的Repositories 要定义repository接口&#xff0c;首先需要定义特定于域&#xff08;domain&#xff09;类的repository接口。接口必须继承Repository&a…

8.openEuler操作系统网络管理和防火墙(二)

openEuler OECA认证辅导,标红的文字为学习重点和考点。 如果需要做实验,建议安装麒麟信安、银河麒麟、统信等具有图形化的操作系统,其安装与openeuler基本一致。 3.通过IP命令配置网络 配置IP地址: 使用ip命令为接口配置地址,命令格式如下,其中 interface-name 为网卡名…

一文7个步骤教你搭建测试web测试项目实战环境

​今天小编&#xff0c;给大家总结下web 测试实战的相关内容&#xff0c;一起来学习下吧&#xff01; web项目实战可按顺序依次为&#xff1a;【搭建测试环境】、【需求评审】、【编写测试计划】、【分析测试点.编写测试用例】、【用例评审】、【执行用例提bug】、【测试报告】…

广东珠宝行业为什么要开展珠宝神秘顾客调查呢?

在竞争激烈的珠宝市场中&#xff0c;品牌形象、服务质量以及顾客满意度是决定一个企业成功与否的关键因素。为了更好地了解顾客需求&#xff0c;优化服务流程&#xff0c;提升顾客满意度&#xff0c;珠宝行业开展神秘顾客调查显得尤为重要。以下从几个方面详细阐述珠宝行业为何…

undo日志详解

一、undo日志介绍 上一节详细的说了redo日志&#xff0c;redo日志的功能就是把增删改操作都记录着&#xff0c;如果断电导致内存中的脏页丢失&#xff0c;可以根据磁盘中的redo日志文件进行恢复。redo日志被设计出来是为了保证数据库的持久性&#xff0c;undo日志设计出来是为…

AI 绘画:人工智能绘画之美

人工智能&#xff08;AI&#xff09;是当今科技领域的热门话题&#xff0c;它不仅可以帮助我们解决各种复杂的问题&#xff0c;还可以创造出令人惊叹的艺术作品。AI 绘画是一种利用 AI 技术生成图像的方法&#xff0c;它可以模仿不同的风格、主题和技巧&#xff0c;甚至可以创造…

Qt Linux下调用OpenGL的glu.h报错:error: GL/glu.h: No such file or directory

Qt Linux下调用OpenGL的glu.h报错&#xff1a;error: GL/glu.h: No such file or directory 引言一、问题描述二、解决方案三、解决过程记录3.1 定位问题3.2 尝试使用yum命令安装3.3 从网上下载到本地进行安装 引言 在Windows上正常运行的OpenGL程序&#xff0c;到Linux下突然…

cuda学习笔记(2)

一 专业名词 1 分支断定 2 一致性和同一性 3 常见名词汇总 4 加速比 二 GPU架构构述 GPU就是将cpu的数据存储单元去掉&#xff0c;也就是保留执行单元&#xff0c;GPU就是多个执行单元 1 GPU设计思路&#xff0c;指令流共享&#xff0c;同时执行&#xff0c;数据切分成小块 …

四种主流的prompt框架

省流版&#xff1a; 文章介绍了在使用GPT时的四种prompt框架&#xff0c;有利于使用者打磨提问风格&#xff0c;与GPT进行更好的交互以提高生产力&#xff0c;能帮助大家有效提高工作效率~ 创作不易&#xff0c;如果对你有帮助的话&#xff0c;还请三连支持~ 想要使用Prompt…

MySQL的21个SQL经验

1. 写完SQL先explain查看执行计划(SQL性能优化) 日常开发写SQL的时候,尽量养成这个好习惯呀:写完SQL后,用explain分析一下,尤其注意走不走索引。 explain select userid,name,age from user where userid =10086 or age =18;2、操作delete或者update语句,加个limit(S…

jQuery简介与解析 - 掌控网页互动的魔法工具

jQuery简介与解析 - 掌控网页互动的魔法工具 摘要&#xff1a;本文将带您了解jQuery这一强大且流行的JavaScript库&#xff0c;探讨其特点、优势以及如何在网页开发中发挥巨大作用。我们将从jQuery的基本概念入手&#xff0c;逐步深入解析其核心功能&#xff0c;助您轻松掌握这…

phpspreadsheet导出数据和图片到excel

仅作记录&#xff0c;废话不多说 前提是已经安装了phpspreadsheet &#xff08; composer require phpoffice/phpspreadsheet &#xff09; 一、 数据拼装&#xff0c;调用excel类 <?php /*** 电子台账* Date: 2023/4/20* Time: 17:28*/namespace app\store\controlle…

Android 面试问题 2024 版(其三)

Android 面试问题 2024 版&#xff08;其三&#xff09; 十一、版本控制十二、Play 商店和应用程序部署十三、无障碍十四、第三方库和 API十五、解决问题的能力十六、基于 JD 的非常高级别的问题 十一、版本控制 什么是版本控制&#xff0c;为什么它在软件开发中很重要&#x…