基于粒子群优化的BP神经网络算法

        大家好,我是带我去滑雪!

        基于粒子群优化的BP神经网络算法(Particle Swarm Optimization Backpropagation Neural Network,PSO-BPNN)是一种利用粒子群优化算法优化BP神经网络的算法。它将BP神经网络的权重和偏置值作为粒子群的位置,并利用PSO算法来搜索最优解。该算法具体许多优点,例如:1、能够在搜索过程中不陷入局部最优解,而是更有可能找到全局最优解或者接近全局最优解的解;2、可以通过利用粒子群的速度来调整BP神经网络的权重和偏置值,使得网络能够更快地优化;3、可以通过直接根据网络性能来更新权重和偏置值,避免了梯度消失问题;4、由于采用了随机性初始化的粒子群,可以减轻对初始值的依赖,使得算法更加稳定和鲁棒;5、PSO-BP算法的粒子群可以并行运行,每个粒子都可以独立进行计算和更新。这使得算法能够有效利用并行计算的优势,加速了训练过程。

       PSO-BP神经网络算法的基本步骤:

  1. 初始化粒子群的位置和速度。每个粒子的位置表示BP神经网络的权重和偏置值,速度表示更新的步长;

  2. 对于每个粒子,根据当前位置计算网络的输出,并计算误差(例如均方误差);

  3. 根据每个粒子的误差评估其适应度,适应度可以使用误差的倒数或其他评价指标来表示;

  4. 更新每个粒子的最优位置和最优适应度。如果当前适应度优于历史最优适应度,则更新最优位置;

  5. 更新每个粒子的速度和位置。速度的更新考虑了个体经验和群体共享经验,以及随机项,使得粒子能够在搜索空间中进行探索和利用;

  6. 重复步骤2到步骤5,直到满足停止条件(例如达到最大迭代次数或达到期望的网络性能)。

       下面开始代码实战

clc;

clear;

tic

close all;

load('basket.mat')#导入数据

P = trains(:,1:end-1) ;

T = trains(:,end) ;

P_test = tests(:,1:end-1) ;

T_test = tests(:,end) ;

cur_season = pred ;

inputnum=size(P,2);

hiddennum=2*inputnum+1;

outputnum=size(T,2);

w1num=inputnum*hiddennum;                               w2num=outputnum*hiddennum;

N=w1num+hiddennum+w2num+outputnum;

nVar=N;

VarSize=[1,nVar];

VarMin=-0.5;

VarMax=0.5;

MaxIt=200;

nPop=359;

w=1;

wdamp=0.99;

c1=1.5;

c2=2.0;

VelMax=0.1*(VarMax-VarMin);

VelMin=-VelMax;

empty_particle.Position=[];

empty_particle.Cost=[];

empty_particle.Velocity=[];

empty_particle.Best.Position=[];

empty_particle.Best.Cost=[];

particle=repmat(empty_particle,nPop,1);

GlobalBest.Cost=inf;

for i=1:nPop

    particle(i).Position=unifrnd(VarMin,VarMax,VarSize);

    particle(i).Velocity=zeros(VarSize);

    particle(i).Cost=BpFunction(particle(i).Position,P,T,hiddennum,P_test,T_test);

    particle(i).Best.Position=particle(i).Position;

    particle(i).Best.Cost=particle(i).Cost;

    if particle(i).Best.Cost<GlobalBest.Cost

        GlobalBest=particle(i).Best;

    end

end

BestCost=zeros(MaxIt,1);

for it=1:MaxIt

    for i=1:nPop

        particle(i).Velocity = w*particle(i).Velocity ...

            +c1*rand(VarSize).*(particle(i).Best.Position-particle(i).Position) ...

            +c2*rand(VarSize).*(GlobalBest.Position-particle(i).Position);

        particle(i).Velocity = max(particle(i).Velocity,VelMin);

        particle(i).Velocity = min(particle(i).Velocity,VelMax);

        particle(i).Position = particle(i).Position + particle(i).Velocity;

        IsOutside=(particle(i).Position<VarMin | particle(i).Position>VarMax);

        particle(i).Velocity(IsOutside)=-particle(i).Velocity(IsOutside);

        particle(i).Position = max(particle(i).Position,VarMin);

        particle(i).Position = min(particle(i).Position,VarMax);

        particle(i).Cost=BpFunction(particle(i).Position,P,T,hiddennum,P_test,T_test);

        if particle(i).Cost<particle(i).Best.Cost

            particle(i).Best.Position=particle(i).Position;

            particle(i).Best.Cost=particle(i).Cost;

            if particle(i).Best.Cost<GlobalBest.Cost 

                GlobalBest=particle(i).Best; 

            end

        end

    end

    BestCost(it)=GlobalBest.Cost;

    disp(['Iteration ' num2str(it) ': Best Cost = ' num2str(BestCost(it))]);

    w=w*wdamp;

end

BestSol=GlobalBest;

%% Results

figure;

%plot(BestCost,'LineWidth',7);

semilogy(BestCost,'LineWidth',7);

xlabel('Number of iterations')

ylabel('The variation of the error')

title('Evolution')

grid on;

fprintf([' Optimal initial weight and threshold:\n=',num2str(BestSol.Position),'\n Minimum error=',num2str(BestSol.Cost),'\n'])

cur_test=zeros(size(cur_season,1),1);

[~,bestCur_sim]=BpFunction(BestSol.Position,P,T,hiddennum,cur_season,cur_test);

prob=softmax(bestCur_sim);                                 

disp(['1次尝试通过游戏的概率为 ',num2str(prob(1))]);

disp(['2次尝试通过游戏的概率',num2str(prob(2))]);

disp(['3次尝试通过游戏的概率',num2str(prob(3))]);

disp(['4次尝试通过游戏的概率',num2str(prob(4))]);

disp(['5次尝试通过游戏的概率',num2str(prob(5))]);

disp(['6次尝试通过游戏的概率',num2str(prob(6))]);

disp(['7次以上尝试通过游戏的概率',num2str(prob(7))]);

toc

function [err,T_sim]=BpFunction(x,P,T,hiddennum,P_test,T_test)

inputnum=size(P,7);

outputnum=size(T,7);

[p_train,ps_train]=mapminmax(P',0,1);

p_test=mapminmax('apply',P_test',ps_train);

[t_train,ps_output]=mapminmax(T',0,1);

net=newff(p_train,t_train,hiddennum);                               

net.trainParam.epochs=1000;

net.trainParam.goal=1e-3;

net.trainParam.lr=0.01;

net.trainParam.showwindow=false;                                    w1num=inputnum*hiddennum;                                           w2num=outputnum*hiddennum;                                        W1=x(1:w1num);                                                      B1=x(w1num+1:w1num+hiddennum);                                      W2=x(w1num+hiddennum+1:w1num+hiddennum+w2num);                    B2=x(w1num+hiddennum+w2num+1:w1num+hiddennum+w2num+outputnum);      net.iw{1,1}=reshape(W1,hiddennum,inputnum);                        net.lw{2,1}=reshape(W2,outputnum,hiddennum);                       net.b{1}=reshape(B1,hiddennum,1);                                  net.b{2}=reshape(B2,outputnum,1);

net = train(net,p_train,t_train);

t_sim = sim(net,p_test);

T_sim1 = mapminmax('reverse',t_sim,ps_output);

T_sim=T_sim1';

err=norm(T_sim-T_test);

index0= T_sim<0;

index1= T_sim>1;

 penalty=1000*abs(sum(T_sim(index0)))+1000*sum(T_sim(index1)-1);

err=err+penalty;

end

需要数据集的家人们可以去百度网盘(永久有效)获取:

链接:https://pan.baidu.com/s/1E59qYZuGhwlrx6gn4JJZTg?pwd=2138
提取码:2138 


更多优质内容持续发布中,请移步主页查看。

   点赞+关注,下次不迷路!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/77487.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++vector模拟实现

vector模拟实现 1.构造函数2.拷贝构造3.析构赋值运算符重载4.iterator5.modifiers5.1push_back5.2pop_back5.3empty5.4insert5.5erase5.6swap 6.Capacity6.1size6.2capacity6.3reserve6.4resize6.5empty 7.Element access7.1operator[]7.2at 8.在谈reserve vector官方库实现的是…

SQL11 高级操作符练习(1)

描述 题目&#xff1a;现在运营想要找到男性且GPA在3.5以上(不包括3.5)的用户进行调研&#xff0c;请你取出相关数据。 示例&#xff1a;user_profile iddevice_idgenderageuniversitygpa12138male21北京大学3.423214male复旦大学4.036543female20北京大学3.242315female23浙…

向量范数及其Python代码

【向量范数】 向量由于既有大小又有方向&#xff0c;所以不能直接比较大小。 向量范数通过将向量转化为实数&#xff0c;然后进行向量的大小比较。 所以&#xff0c;向量范数是用于度量“向量大小”的量。 设向量 &#xff0c;则有&#xff1a; ● 向量的 范数&#xff1a; ●…

Python计算机Python二级知识点整理

1. 此时我们这里首先解析一下这个d[A]N,根据ASCII表&#xff0c;我们可以看出字符A对应的十进制数字是65&#xff0c;ord()函数是把字符转换为相对应的ASCII码&#xff0c;chr()函数是ord()函数的逆运算&#xff0c;所以ord("A")65 ,chr(65)A,题目中首先定义了d为一…

性能测试包含哪些内容?

性能测试是对软件产品在特定条件下的性能进行测试和评估的过程。性能测试的内容可以包括以下几个方面&#xff1a; 1、负载测试&#xff1a;负载测试是指在特定条件下&#xff0c;对软件产品的性能进行测试和评估。测试人员可以通过模拟不同的用户数量、并发请求、访问频率等…

el-popover 通过js手动控制弹出框显示、隐藏

el-popover 通过js手动控制弹出框显示、隐藏 说明 element ui 2.x中&#xff0c;el-popover的显示隐藏有4种触发方式&#xff1a;click/focus/hover/manual&#xff0c;分别是点击/聚焦/悬浮/手动&#xff0c;正常情况这几个触发方式已经能满足大部分需求&#xff0c;但有些业…

C++毕业设计基于QT实现的超市收银管理系统源代码+数据库

C毕业设计基于QT实现的超市收银管理系统源代码数据库 编译使用 编译完成后&#xff0c;需要拷贝 file目录下的数据库 POP.db文件到可执行程序目录下 登录界面 主界面 会员管理 完整代码下载地址&#xff1a;基于QT实现的超市收银管理系统源代码数据库

笔记本多拓展出一个屏幕

一、首先要知道&#xff0c;自己的电脑有没有Type-c接口&#xff0c;支持不支持VGA 推荐&#xff1a; 自己不清楚&#xff0c;问客服&#xff0c;勤问。 二、显示屏与笔记本相连&#xff0c;通过VGA 三、连接好了&#xff0c;需要去配置 网址&#xff1a;凑合着看&#xff…

LLM 02-大模型的能力

LLM 02-大模型的能力 我们将深入探讨GPT-3——这个具有代表性的大型语言模型的能力。我们的研究主要基于GPT-3论文中的基准测试&#xff0c;这些测试包括&#xff1a; 标准的自然语言处理&#xff08;NLP&#xff09;基准测试&#xff0c;例如问题回答&#xff1b;一些特殊的一…

【OpenCV • c++】图像噪音 | 椒盐噪音 | 高斯噪音

文章目录 一、什么是图像噪音二、椒盐噪声三、高斯噪声 一、什么是图像噪音 图像噪声是图像在获取或是传输过程中受到随机信号干扰&#xff0c;妨碍人们对图像理解及分析处理的信号。很多时候将图像噪声看做多维随机过程&#xff0c;因而描述噪声的方法完全可以借用随机过程的描…

aruco码DICT几乘几是啥含义,aruco.getPredefinedDictionary

dictionary aruco.getPredefinedDictionary(aruco.DICT_5X5_100) aruco.DICT_5X5_100中的5X5和100表示: - 5X5:表示ArUco标记是5x5像素大小的正方形格子组成。 - 100:表示这个字典包含100个不同的ArUco标记。aruco代码字典中包含多个不同的二进制marker,每个marker由一系列…

PyCharm中使用matplotlib.pyplot.show()报错MatplotlibDeprecationWarning的解决方案

其实这只是一个警告&#xff0c;忽略也可。 一、控制台输出 MatplotlibDeprecationWarning: Support for FigureCanvases without a required_interactive_framework attribute was deprecated in Matplotlib 3.6 and will be removed two minor releases later. MatplotlibD…

AttributeError: module ‘OpenSSL.SSL’ has no attribute ‘SSLv3_METHOD

这个错误是由于在OpenSSL.SSL模块中找不到SSLv3_METHOD属性导致的。解决这个问题的方法如下&#xff1a; 首先&#xff0c;确保你已经安装了最新版本的cryptography和pyOpenSSL。你可以使用以下命令卸载并重新安装它们&#xff1a; 卸载cryptography&#xff1a;pip uninstall …

Java“牵手”微店商品列表页数据采集+微店商品价格数据排序,微店API接口申请指南

微店平台创立于2011年5月&#xff0c;是北京口袋时尚科技开发的应用&#xff0c;2014年1月"微店"APP正式上线。微店已经从小微店主首选的开店工具转型为助力创业者发展兴趣、创立品牌、玩成事业的系统及基础设施。 微店商品列表数据包含商品名称、价格、销量、详情、…

微信小程序 通过 pageScrollTo 滚动到界面指定位置

我们可以先创建一个page 注意 一定要在page中使用 因为pageScrollTo控制的是页面滚动 你在组件里用 他就失效了 我们先来看一个案例 wxml 代码如下 <view><button bindtap"handleTap">回到指定位置</button><view class "ControlHeight…

js 小数相乘后,精度缺失问题,记录四舍五入,向下取整

在做项目的时候&#xff0c;有一个计算金额的&#xff0c;结果发现计算的金额总是缺失0.01&#xff0c;发现相乘的时候&#xff0c;会失去精度&#xff0c;如图所示。被这整的吐血&#xff0c;由于计算逻辑由前端计算&#xff0c;所以传值后端总出错(尽量后端计算)。 还发现to…

9月12日作业

作业代码 #include <iostream>using namespace std;class Shape { protected:double cir;double area; public://无参构造Shape() {cout<<"无参构造"<<endl;}//有参构造Shape(double c, double a):cir(c), area(a){cout<<"有参构造&quo…

IDEFICS 简介: 最先进视觉语言模型的开源复现

我们很高兴发布 IDEFICS ( Image-aware Decoder Enhanced la Flamingo with Ininterleaved Cross-attention S ) 这一开放视觉语言模型。IDEFICS 基于 Flamingo&#xff0c;Flamingo 作为最先进的视觉语言模型&#xff0c;最初由 DeepMind 开发&#xff0c;但目前尚未公开发布…

极简B站直播录制工具 录播姬 2.9.0,支持自动批量录制、弹幕录制等

录播姬 是一个简单好用免费开源的直播录制工具&#xff0c;支持自动批量录制、弹幕录制、实时监控直播间状态&#xff0c;直接获取直播流&#xff0c;非录制屏幕&#xff0c;没有二次压制 软件特点 使用简单&#xff1a;粘贴房间号或房间链接即可开录 自动录制&#xff1a;主…

LeetCode 28. 找出字符串中第一个匹配项的下标

文章目录 一、题目二、C# 题解 一、题目 给你两个字符串 haystack 和 needle &#xff0c;请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标&#xff08;下标从 0 开始&#xff09;。如果 needle 不是 haystack 的一部分&#xff0c;则返回 -1 。 点击此处跳转…