BP 网络的标准学习算法及其实现

BP 网络的标准学习算法及其实现

一、引言

BP(Back Propagation)神经网络是一种广泛应用于机器学习和人工智能领域的神经网络模型。它通过反向传播算法来调整网络的权重,以最小化预测输出和实际输出之间的误差。BP 网络的标准学习算法对于理解和实现神经网络的训练过程至关重要,本文将详细介绍这些算法并给出相应的代码示例。

二、BP 网络的基本结构

BP 网络通常由输入层、若干隐藏层和输出层组成。每层由多个神经元构成,神经元之间通过合适的权重连接。输入层接收外部数据,隐藏层对数据进行特征提取和转换,输出层输出最终的预测结果。

三、BP 网络的标准学习算法

(一)梯度下降算法

  1. 原理
    • 梯度下降算法是 BP 网络中最基本的优化算法。其目标是通过迭代地调整网络权重,使损失函数(如均方误差函数)的值最小化。损失函数关于权重的梯度表示了损失函数在当前权重下增长最快的方向,而梯度下降算法则沿着梯度的负方向更新权重,以期望找到损失函数的最小值。
    • 对于一个具有多个权重参数 w w w 的 BP 网络,损失函数 J ( w ) J(w) J(w) 的梯度 ∇ J ( w ) \nabla J(w) J(w) 可以通过链式法则计算。在每次迭代中,权重更新公式为: w = w − α ∇ J ( w ) w = w - \alpha \nabla J(w) w=wαJ(w),其中 α \alpha α 是学习率,它决定了每次更新的步长。如果学习率过大,可能会导致算法无法收敛甚至发散;如果学习率过小,算法收敛速度会很慢。
  2. 代码示例
    以下是一个简单的使用梯度下降算法训练单神经元的代码示例,用于拟合一个简单的线性函数 y = 2 x + 1 y = 2x + 1 y=2x+1
import numpy as np# 激活函数,这里使用恒等函数(在线性回归场景下)
def activation_function(x):return x# 损失函数,这里使用均方误差
def loss_function(y_pred, y_true):return np.mean((y_pred - y_true) ** 2)# 训练数据
x = np.array([[1], [2], [3], [4], [5]])
y = np.array([[3], [5], [7], [9], [11]])# 初始化权重和偏置
weight = np.random.rand()
bias = np.random.rand()
learning_rate = 0.01
epochs = 100for epoch in range(epochs):# 前向传播y_pred = activation_function(x * weight + bias)loss = loss_function(y_pred, y)# 计算梯度d_loss_d_weight = np.mean((y_pred - y) * x)d_loss_d_bias = np.mean(y_pred - y)# 更新权重和偏置weight = weight - learning_rate * d_loss_d_weightbias = bias - learning_rate * d_loss_d_biasif epoch % 10 == 0:print(f'Epoch {epoch}: Loss = {loss}')print(f'Final weight: {weight}, Final bias: {bias}')

(二)随机梯度下降算法(SGD)

  1. 原理
    • 梯度下降算法在每次迭代时都使用整个训练数据集来计算梯度,当数据集很大时,计算成本很高。随机梯度下降算法则每次从训练数据集中随机选择一个样本进行梯度计算和权重更新。这样可以大大加快训练速度,但由于每次只使用一个样本,梯度的估计可能会有较大的噪声,导致收敛路径可能会更加曲折。
    • 在 SGD 中,权重更新公式与梯度下降类似,但每次只针对一个样本计算梯度。例如,对于一个样本 ( x i , y i ) (x_i, y_i) (xi,yi),权重更新公式为: w = w − α ∇ J i ( w ) w = w - \alpha \nabla J_i(w) w=wαJi(w),其中 ∇ J i ( w ) \nabla J_i(w) Ji(w) 是损失函数关于权重 w w w 在样本 i i i 上的梯度。
  2. 代码示例
    以下是使用随机梯度下降算法训练上述线性回归问题的代码:
import numpy as np# 激活函数,这里使用恒等函数(在线性回归场景下)
def activation_function(x):return x# 损失函数,这里使用均方误差
def loss_function(y_pred, y_true):return np.mean((y_pred - y_true) ** 2)# 训练数据
x = np.array([[1], [2], [3], [4], [5]])
y = np.array([[3], [5], [7], [9], [11]])# 初始化权重和偏置
weight = np.random.rand()
bias = np.random.rand()
learning_rate = 0.01for epoch in range(100):for i in range(len(x)):# 随机选择一个样本sample_x = x[i]sample_y = y[i]# 前向传播y_pred = activation_function(sample_x * weight + bias)loss = loss_function(y_pred, sample_y)# 计算梯度d_loss_d_weight = (y_pred - sample_y) * sample_xd_loss_d_bias = y_pred - sample_y# 更新权重和偏置weight = weight - learning_rate * d_loss_d_weightbias = bias - learning_rate * d_loss_d_biasif epoch % 10 == 0:y_pred_all = activation_function(x * weight + bias)loss_all = loss_function(y_pred_all, y)print(f'Epoch {epoch}: Loss = {loss_all}')print(f'Final weight: {weight}, Final bias: {bias}')

(三)小批量梯度下降算法(Mini - Batch Gradient Descent)

  1. 原理
    • 小批量梯度下降算法是梯度下降和随机梯度下降的一种折衷。它每次从训练数据集中选取一小批(mini - batch)样本进行梯度计算和权重更新。这样既可以利用向量化计算的优势(相比随机梯度下降),又可以在一定程度上减少计算量(相比梯度下降),同时也能获得比随机梯度下降更稳定的梯度估计。
    • 假设小批量大小为 m m m,在每次迭代中,从训练数据集 D D D 中随机选取一个小批量样本 B = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , ⋯ , ( x m , y m ) } B = \{(x_1, y_1), (x_2, y_2), \cdots, (x_m, y_m)\} B={(x1,y1),(x2,y2),,(xm,ym)}。权重更新公式为: w = w − α ∇ J B ( w ) w = w - \alpha \nabla J_B(w) w=wαJB(w),其中 ∇ J B ( w ) \nabla J_B(w) JB(w) 是损失函数关于权重 w w w 在小批量样本 B B B 上的梯度。
  2. 代码示例
    以下是使用小批量梯度下降算法训练线性回归问题的代码,假设小批量大小为 2:
import numpy as np# 激活函数,这里使用恒等函数(在线性回归场景下)
def activation_function(x):return x# 损失函数,这里使用均方误差
def loss_function(y_pred, y_true):return np.mean((y_pred - y_true) ** 2)# 训练数据
x = np.array([[1], [2], [3], [4], [5]])
y = np.array([[3], [5], [7], [9], [11]])# 初始化权重和偏置
weight = np.random.rand()
bias = np.random.rand()
learning_rate = 0.01
batch_size = 2
epochs = 100for epoch in range(epochs):for i in range(0, len(x), batch_size):end_index = min(i + batch_size, len(x))batch_x = x[i:end_index]batch_y = y[i:end_index]# 前向传播y_pred = activation_function(batch_x * weight + bias)loss = loss_function(y_pred, batch_y)# 计算梯度d_loss_d_weight = np.mean((y_pred - batch_y) * batch_x)d_loss_d_bias = np.mean(y_pred - batch_y)# 更新权重和偏置weight = weight - learning_rate * d_loss_d_weightbias = bias - learning_rate * d_loss_d_biasif epoch % 10 == 0:y_pred_all = activation_function(x * weight + bias)loss_all = loss_function(y_pred_all, y)print(f'Epoch {epoch}: Loss = {loss_all}')print(f'Final weight: {weight}, Final bias: {bias}')

四、总结

BP 网络的标准学习算法包括梯度下降、随机梯度下降和小批量梯度下降等。梯度下降算法在处理大规模数据集时计算成本高,随机梯度下降算法收敛路径可能不稳定,而小批量梯度下降算法在两者之间取得了较好的平衡。在实际应用中,需要根据数据集的大小、计算资源和模型的复杂度等因素选择合适的学习算法,以有效地训练 BP 网络并获得良好的预测性能。同时,这些算法还可以进一步改进和优化,例如使用自适应学习率等方法来提高训练效率和收敛速度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/885325.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java爬虫 爬取某招聘网站招聘信息

Java爬虫 爬取某招聘网站招聘信息 一、系统介绍二、功能展示1.需求爬取的网站内容2.实现流程2.1数据采集2.2页面解析2.3数据存储 三、其它1.其他系统实现 一、系统介绍 系统主要功能:本项目爬取的XX招聘网站 二、功能展示 1.需求爬取的网站内容 2.实现流程 爬虫…

stm32不小心把SWD和JTAG都给关了,程序下载不进去,怎么办?

因为想用STM32F103的PA15引脚,调试程序的时候不小心把SWD和JTAD接口都给关了,先看下罪魁祸首 GPIO_PinRemapConfig(GPIO_Remap_SWJ_JTAGDisable,ENABLE);//关掉JTAG,不关SWGPIO_PinRemapConfig(GPIO_Remap_SWJ_Disable, ENABLE);//关掉SW&am…

雷军-2022.8小米创业思考-11-新零售:用电商思维做新零售,极致的效率+极致的体验。也有弯路,重回极致效率的轨道上。

第十一章 新零售 当我们说到小米模式的时候,其实我们说的是两件东西: 一是小米模式的本质,即高效率的商业模式; 另一件是小米这家公司具象的商业模式,这是小米在实践中摸索、建立的一整套业务模型。 从2015年到202…

C语言实现数据结构之堆

文章目录 堆一. 树概念及结构1. 树的概念2. 树的相关概念3. 树的表示4. 树在实际中的运用(表示文件系统的目录树结构) 二. 二叉树概念及结构1. 概念2. 特殊的二叉树3. 二叉树的性质4. 二叉树的存储结构 三. 二叉树的顺序结构及实现1. 二叉树的顺序结构2.…

DocuBurst——基于java实现

DocuBurst 文档散(DocuBurst)也是基于关键词的文本可视化,不过它还通过径向布局体现了词的语义等级。如下图所示,外层的词是内层词的下义祠,颜色饱和度的深浅用来体现词频的高低。 DocuBurst 是第一个利用词法数据库中人工创建的结构的文档内容可视化。我们使用公认的设…

知识课堂之域名系统中实现动态代理

怎么在域名系统中解析动态ip,这一直是一个需要解决的问题,人们对与网络的稳定连接与灵活运用已经成为生活和工作中不可或缺的一部分,因此这样的问题的解决迫在眉睫。 大家对于动态ip是什么,应该都有所了解了,所谓的动…

5G周边知识笔记

这里写目录标题 3GPP 5G标准路径图5G协议规范5G新空口关键指标4G LTE和5G NR新空口技术对比5G新频段FR1FR2 信道带宽上下行解耦新频点规划,信道栅格FR1各频段实际信道栅格和NR-ARFCN范围定义 同步栅格大规模天线阵列新型调制编码技术大规模载波聚合设备到设备直接通…

报名开启|开放原子大赛“Rust数据结构与算法学习赛”

开放原子大赛“Rust数据结构与算法学习赛”报名进行中,报名截止时间为11月17日。 为了进一步促进开源技术的发展,提升国内开源社区的创新能力和国际影响力,开放原子开源基金会与清华大学开源操作系统训练营等单位,共同举办本次Rus…

Flutter自定义矩形进度条实现详解

在Flutter应用开发中,进度条是一个常见的UI组件,用于展示任务的完成进度。本文将详细介绍如何实现一个支持动画效果的自定义矩形进度条。 功能特点 支持圆角矩形外观平滑的动画过渡效果可自定义渐变色可配置边框宽度和颜色支持进度更新动画 实现原理 …

C#中的Math类

在 C# 中,Math 类提供了许多数学运算的静态方法,涵盖了各种常见的数学函数和操作。以下是 Math 类中的常用方法及其用法(持续更新中…) 方法说明示例Abs()返回指定数值的绝对值int absValue Math.Abs(-10); 结果为 10Acos()返回…

uniapp配置h5路由模式为history时404

为了不让URL中出现#,让uniapp项目配置h5路由模式为hisory 然而本地好好的,放到服务器上却404了。 解决方法是给nginx配置一个伪静态: location /xxx-html/ {alias /home/nginx_web/xxx_new_html/;try_files $uri $uri/ /xxx-html/index.ht…

python画图|灵活的subplot_mosaic()函数-初逢

【1】引言 前述学习进程中,对hist()函数画直方图已经有一定的探索。然而学无止境,在继续学习的进程中,我发现了一个显得函数subplot_mosaic(),它几乎支持我们随心所欲地排布多个子图。 经过自我探索,我有一些收获&am…

单体架构的 IM 系统设计

先直接抛出业务背景! 有一款游戏,日活跃量(DAU)在两千左右,虽然 DAU 不高,但这两千用户的忠诚度非常高,而且会持续为游戏充值;为了进一步提高用户体验,继续增强用户的忠…

vue实现天地图电子围栏

一、文档 vue3 javascript WGS84、GCj02相互转换 天地图官方文档 注册登录然后申请应用key&#xff0c;通过CDN引入 <script src"http://api.tianditu.gov.cn/api?v4.0&tk您的密钥" type"text/javascript"></script>二、分析 所谓电子围…

基于SSM(Spring + Spring MVC + MyBatis)框架的汽车租赁共享平台系统

基于SSM&#xff08;Spring Spring MVC MyBatis&#xff09;框架的汽车租赁共享平台系统是一个复杂的Web应用程序&#xff0c;用于管理和调度汽车租赁服务。下面我将提供一个详细的案例程序概述&#xff0c;包括主要的功能模块和技术栈介绍。 项目概述 功能需求 用户管理&…

Python函数专题:默认参数与关键字参数

在Python编程中,函数是一个非常重要的概念。它们不仅用于组织代码,还能够提高代码的重用性和可读性。在本文中,我们将深入探讨Python的默认参数和关键字参数这两个特性。这些特性可以让函数的调用更加灵活和强大。 一、什么是默认参数? 默认参数是指在定义函数时,为某些…

前端将后端返回的文件下载到本地

vue 将后端返回的文件地址下载到本地 在 template 拿到后端返回的文件路径 <el-button link type"success" icon"Download" click"handleDownload(file)"> 附件下载 </el-button>在 script 里面写方法 function handleDownload(v…

【C++前缀和 单调栈】1124. 表现良好的最长时间段|1908

本文涉及的基础知识点 C算法&#xff1a;前缀和、前缀乘积、前缀异或的原理、源码及测试用例 包括课程视频 C单调栈 LeetCode 1124. 表现良好的最长时间段 给你一份工作时间表 hours&#xff0c;上面记录着某一位员工每天的工作小时数。 我们认为当员工一天中的工作小时数大…

qt5将程序打包并使用

一、封装程序 (1)、点击创建项目->库->clibrary &#xff08;2&#xff09;、填写自己想要封装成库的名称&#xff0c;这里我填写的名称为mydll1 &#xff08;3&#xff09;、如果没有特殊的要求&#xff0c;则一路下一步&#xff0c;最终会出现如下文件列表。 (4)、删…

PICO+Unity MR空间锚点

官方链接&#xff1a;空间锚点 | PICO 开发者平台 注意&#xff1a;该功能只能打包成APK在PICO 4 Ultra上真机运行&#xff0c;无法通过串流或PICO developer center在PC上运行。使用之前要开启视频透视。 在 Inspector 窗口中的 PXR_Manager (Script) 面板上&#xff0c;勾选…