详解反向传播(BP)算法

文章目录

      • what(是什么)
      • where(用在哪)
      • How(原理&&怎么用)
          • 原理以及推导过程
          • pytorch中的反向传播

what(是什么)

反向传播算法(Backpropagation)是一种用于训练人工神经网络的常见方法。它通过计算网络预测与实际结果之间的误差,然后反向传播这个误差来调整网络中每个权重的值,从而逐步优化网络的学习过程

在这里插入图片描述

where(用在哪)

绝大多数的神经网络都会使用反向传播算法进行网络权重以及阈值的更新,简单列举部分典型的使用场景如下

反向传播算法
前馈神经网络
多层感知机
卷积神经网络
循环神经网络
深度神经网络

How(原理&&怎么用)

原理以及推导过程

下面重点介绍反向传播算法的推导流程

在这里插入图片描述

假设有以上简单的神经网路模型,分为输入层、隐藏层、输出层。其中隐藏层包括4个神经元、输出层包括2个神经元。
假设输出层的两个神经元为 y 1 y_1 y1 y 2 y_2 y2,其激活阈值分别为 β \beta β γ \gamma γ,两个神经元的输入分别为 y 1 i n y_{1in} y1in y 2 i n y_{2in} y2in,输出分别为 y 1 ^ \hat{y_1} y1^ y 2 ^ \hat{y_2} y2^
假设隐藏层四个神经元为 h 1 h_1 h1 h 2 h_2 h2 h 3 h_3 h3 h 4 h_4 h4,其中 h 1 h_1 h1的激活阈值为 δ \delta δ,神经元 h 1 h_1 h1的输入值为 h i n h_{in} hin,输出值为 h o u t h_{out} hout
假设输入层两个神经元为 x 1 x_1 x1 x 2 x_2 x2,其中神经元 x 1 x_1 x1的输出为 x o u t x_{out} xout
假设神经元 x 1 x_1 x1到神经元 h 1 h_1 h1的连接权重为 W 11 W_{11} W11,神经元 h 1 h_1 h1到神经元 y 1 y_1 y1 y 2 y_2 y2的连接权重分别为 W 21 W_{21} W21 W 22 W_{22} W22
假设神经元的激活函数为sigmoid函数,sigmoid激活函数的表达式:
f ( x ) = 1 1 − e − x f(x)=\frac{1}{1-e^{-x}} f(x)=1ex1
该激活函数有一个非常好的性质:
f ′ ( x ) = f ( x ) ( 1 − f ( x ) ) f'(x)=f(x)(1-f(x)) f(x)=f(x)(1f(x))
下面,详细介绍连接权重 W W W以及激活阈值的更新过程。
首先,给出 W 21 W_{21} W21以及 β \beta β的更新公式,其中, W 21 W_{21} W21更新公式为:
W 21 = W 21 + η ∗ Δ W 21 W_{21}=W_{21}+\eta*\Delta W_{21} W21=W21+ηΔW21
同理, β \beta β更新公式为:
β = β + η ∗ Δ β \beta=\beta+\eta*\Delta \beta β=β+ηΔβ

在以上公式中,只有 Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ未知,需要计算。而已知的是样本,也就是 ( x , y ) (x,y) (x,y),那么我们将通过样本数据来表达出上述 Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ
根据反向传播算法, Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ分别为最终的误差对 W 21 W_{21} W21以及 β \beta β的偏导数。假设采用的损失函数为:
L o s s = 1 2 ( y 1 − y 1 ^ ) 2 + 1 2 ( y 2 − y 2 ^ ) 2 Loss=\frac{1}{2}(y_1-\hat{y_1})^2+\frac{1}{2}(y_2-\hat{y_2})^2 Loss=21(y1y1^)2+21(y2y2^)2
扩展到输出层有k个神经元的情况:
L o s s = 1 2 Σ 1 k ( y i − y i ^ ) 2 Loss=\frac{1}{2}\Sigma_1^k(y_i-\hat{y_i})^2 Loss=21Σ1k(yiyi^)2
而从输出端看,能得到以下表达式:
y 1 ^ = f ( y 1 i n − β ) = f ( W 21 h o u t − β ) \hat{y_1}=f(y_{1in}-\beta)=f(W_{21}h_{out}-\beta) y1^=f(y1inβ)=f(W21houtβ)
y 1 ^ \hat{y_1} y1^带入到损失函数中,也就是:
L o s s = 1 2 ( y 1 − f ( W 21 h o u t − β ) ) 2 + 1 2 ( y 2 − f ( W 22 h o u t − γ ) ) 2 Loss = \frac{1}{2}(y_1-f(W_{21}h_{out}-\beta))^2+\frac{1}{2}(y_2-f(W_{22}h_{out}-\gamma))^2 Loss=21(y1f(W21houtβ))2+21(y2f(W22houtγ))2
如此,便得出损失和 W 21 W_{21} W21之间的代数关系式,接下来只需要对该表达式求导即可得到 Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ

首先, ∂ L o s s ∂ W 21 \frac{\partial Loss}{\partial W_{21}} W21Loss的计算公式为:
∂ L o s s ∂ W 21 = [ y 1 − f ( W 21 h o u t − β ) ] ∗ [ − f ′ ( W 21 h o u t − β ) ] ∗ h o u t = − [ y 1 − f ( W 21 h o u t − β ) ] ∗ f ( W 21 h o u t − β ) [ 1 − ( f ( W 21 h o u t − β ) ) ] ∗ h o u t = − ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) ∗ h o u t \begin{aligned} \frac{\partial Loss}{\partial W_{21}} & = [y_1-f(W_{21}h_{out}-\beta)]*[-f'(W_{21}h_{out}-\beta)]*h_{out} \\ & =- [y_1-f(W_{21}h_{out}-\beta)]*f(W_{21}h_{out}-\beta)[1-(f(W_{21}h_{out}-\beta))]*h_{out} \\ & = -(y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1})*h_{out} \end{aligned} W21Loss=[y1f(W21houtβ)][f(W21houtβ)]hout=[y1f(W21houtβ)]f(W21houtβ)[1(f(W21houtβ))]hout=(y1y1^)y1^(1y1^)hout
同样地, ∂ L o s s ∂ β \frac{\partial Loss}{\partial \beta} βLoss的计算公式为:
∂ L o s s ∂ β = [ y 1 − f ( W 21 h o u t − β ) ] ∗ [ − f ′ ( W 21 h o u t − β ) ] ∗ ( − 1 ) = [ y 1 − f ( W 21 h o u t − β ) ] ∗ f ( W 21 h o u t − β ) [ 1 − ( f ( W 21 h o u t − β ) ) ] = ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) \begin{aligned} \frac{\partial Loss}{\partial \beta} & = [y_1-f(W_{21}h_{out}-\beta)]*[-f'(W_{21}h_{out}-\beta)]*(-1) \\ & = [y_1-f(W_{21}h_{out}-\beta)]*f(W_{21}h_{out}-\beta)[1-(f(W_{21}h_{out}-\beta))] \\ & = (y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1}) \end{aligned} βLoss=[y1f(W21houtβ)][f(W21houtβ)](1)=[y1f(W21houtβ)]f(W21houtβ)[1(f(W21houtβ))]=(y1y1^)y1^(1y1^)
由于梯度下降法,需要沿着负梯度方向,所以, Δ W 21 = − ∂ L o s s ∂ W 21 \Delta W_{21}=-\frac{\partial Loss}{\partial W_{21}} ΔW21=W21Loss Δ β = − ∂ L o s s ∂ β \Delta \beta=-\frac{\partial Loss}{\partial \beta} Δβ=βLoss,从而得出 W 21 , β W_{21},\beta W21,β的更新公式为:
W 21 = W 21 + η ∗ Δ W 21 = W 21 − η ∗ ∂ L o s s ∂ W 21 = W 21 + η ∗ ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) ∗ h o u t \begin{aligned} W_{21} &= W_{21} + \eta*\Delta W_{21} \\ & = W_{21}-\eta * \frac{\partial Loss}{\partial W_{21}} \\ & =W_{21}+\eta *(y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1})*h_{out} \end{aligned} W21=W21+ηΔW21=W21ηW21Loss=W21+η(y1y1^)y1^(1y1^)hout

β = β + η ∗ Δ β = β − η ∗ ∂ L o s s ∂ β = β − η ∗ ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) \begin{aligned} \beta & = \beta+\eta*\Delta \beta \\ & = \beta-\eta* \frac{\partial Loss}{\partial \beta} \\ & = \beta-\eta *(y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1}) \end{aligned} β=β+ηΔβ=βηβLoss=βη(y1y1^)y1^(1y1^)

使用同样的方式,可以对 W 11 , δ W_{11},\delta W11,δ的梯度公式进行计算和更新。

pytorch中的反向传播

下面举例说明在pytorch中,如何使用反向传播算法来更新权重以及阈值。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F# 定义一个复杂的神经网络
class ComplexNet(nn.Module):def __init__(self):super(ComplexNet, self).__init__()self.fc1 = nn.Linear(10, 50)  # 输入大小为10,输出大小为50self.fc2 = nn.Linear(50, 20)  # 输入大小为50,输出大小为20self.fc3 = nn.Linear(20, 1)   # 输入大小为20,输出大小为1def forward(self, x):x = F.relu(self.fc1(x))  # 使用ReLU作为激活函数x = F.relu(self.fc2(x))x = self.fc3(x)return x# 创建网络实例
model = ComplexNet()# 定义损失函数
criterion = nn.MSELoss()# 随机生成一些输入和目标输出数据
input_data = torch.randn((32, 10))  # 32个样本,每个样本特征数为10
target_output = torch.randn((32, 1))  # 对应的32个目标输出# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练模型
model.train()	# 设置模型为训练模式
epochs = 1000
for epoch in range(epochs):# 梯度清零optimizer.zero_grad()# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_output)# 反向传播loss.backward()# 更新模型参数optimizer.step()# 每隔一段时间输出一下损失值if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')# 打印模型结构
print(model)

pythrch中,输入在流经每一个神经元时,会构建一个动态计算图(与tensorflow不同,tensorflow为静态计算图),记录了每个神经元的输入输出信息。在反向传播时, loss.backward()会根据已知的样本数据以及神经元的输入输出信息,计算连接权重以及阈值的梯度,然后optimizer.step()来实现对权重和阈值的更新。需要注意的是,在每一个mini-batch开始前,需要使用optimizer.zero_grad()对梯度置零。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/865043.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

自动驾驶水泥搅拌车在梁场的应用(下)

自动驾驶水泥搅拌车在梁场的应用(下) 北京渡众机器人科技有限公司的自动驾驶水泥搅拌车在梁场(也称为预制梁场)的应用可以带来多方面的优势和效益: 1. 自动化搅拌和运输 在梁场中,通常需要大量的混凝土搅…

自动化一些操作

下拉选择框 from selenium import webdriver from time import sleep # 导包 from selenium.webdriver.support.select import Select driver webdriver.Edge() driver.get(r"D:\WORK\ww\web自动化_day01_课件笔记资料代码\web自动化_day01_课件笔记资料代码\02_其他资料…

调试支付分回调下载平台证书

之前的原生代码放到webman里面,死活跑不通 没办法,只能用esayWeChat6.7 (自行下载) 它里面配置要用到平台证书 平台证书又要用到 composer require wechatpay/wechatpay 但是请求接口之前,你先要用到一个临时的平台…

配置atuin记录

https://atuin.sh/ 运行 curl --proto https --tlsv1.2 -LsSf https://setup.atuin.sh | sh报错 $ curl --proto https --tlsv1.2 -LsSf https://setup.atuin.sh | sh curl: (77) error setting certificate verify locations:CAfile: /etc/ssl/certs/ca-certificates.crtCAp…

同时安装JDK8和JDK17+环境变量默认无法修改

一、问题描述 当在windows系统中,同时安装JDK8和JDK17,环境变量默认就为jdk17,且从jdk17切换为jdk8后不生效,使用"java -version"命令查看后还是17版本。 解决方法 首先,产生的原因是,在安装…

2024最新源代码加密软件丨五款企业级软件评测

程序源代码作为企业的核心成果,一旦泄密将产生重大的损失,加密源代码至关重要。 可以防止他人未经授权使用、复制或修改源代码,保护开发者的劳动成果。 可以防止源代码被黑客或竞争对手获取和分析,减少漏洞被发现和利用的风险。…

JAVA极简图书管理系统,初识springboot后端项目

前提条件: 具备基础的springboot 知识 Java基础 废话不多说! 创建项目 配置所需环境 将application.properties>application.yml 配置以下环境 数据库连接MySQL 自己创建的数据库名称为book_test server:port: 8080 spring:datasource:url:…

ShareSDK HarmonyOS NEXT集成指南

集成前准备 注册账号 使用MobSDK之前,需要先在MobTech官网注册开发者账号,并获取MobTech提供的AppKey和AppSecret,详情可以点击查看注册流程 ShareSDK流程图 集成配置 添加依赖 在Terminal窗口中,执行如下命令进行安装 ohpm …

【Python】MacBook M系列芯片Anaconda下载Pytorch,并开发一个简单的数字识别代码(附带踩坑记录)

文章目录 配置镜像源下载Pytorch验证使用Pytorch进行数字识别 配置镜像源 Anaconda下载完毕之后,有两种方式下载pytorch,一种是用页面可视化的方式去下载,另一种方式就是直接用命令行工具去下载。 但是由于默认的Anaconda走的是外网&#x…

主干网络篇 | YOLOv8改进之引入YOLOv10的主干网络 | 全网最新改进

前言:Hello大家好,我是小哥谈。YOLOv10是由清华大学研究人员利用Ultralytics Python软件包开发的,它通过改进模型架构并消除非极大值抑制(NMS)提供了一种新颖的实时目标检测方法。这些优化使得模型在保持先进性能的同时,降低了计算需求。与以往的YOLO版本不同,YOLOv10的…

突发!Runway的Gen-3向所有人开放,媲美Sora!

7月2日凌晨,著名生成式AI平台Runway在官网宣布,其文生视频模型Gen-3 Alpha向所有用户开放使用。 上周日Runway只向部分用户提供了Gen-3的使用权限,「AIGC开放社区」也为大家解读了10个非常有代表性的视频案例。(点击查看&#xf…

晚上睡觉要不要关路由器?一语中的

前言 前几天小白去了一个朋友家,有朋友说:路由器不关机的话会影响睡眠吗? 这个影响睡眠嘛,确实是会的。毕竟一时冲浪一时爽,一直冲浪一直爽……刷剧刷抖音刷到根本停不下来,肯定影响睡眠。 所以晚上睡觉要…

昇思MindSpore学习笔记2-04 LLM原理和实践--文本解码原理--以MindNLP为例

摘要: 介绍了昇思MindSpore AI框架采用贪心搜索、集束搜索计算高概率词生成文本的方法、步骤,并为解决重复等问题所作的多种尝试。 这一节完全看不懂,猜测是如何用一定范围的词造句。 一、概念 自回归语言模型 文本序列概率分布 分解为每…

多模态融合 + 慢病精准预测

多模态融合 慢病精准预测 慢病预测算法拆解子解法1:多模态数据集成子解法2:实时数据处理与更新子解法3:采用大型语言多模态模型(LLMMs)进行深度学习分析 慢病预测更多模态 论文:https://arxiv.org/pdf/2406…

发电机保护屏组成都有哪些,如何选择

发电机保护屏组成都有哪些,如何选择 发电机是电力系统中最常用的一种电力设备。例如水力发电机,柴油发电机,风力发电机,火力发电等等。发电机保护是保证发电机安全、稳定运行的重要手段之一。对于一些小型机组的发电机&#xff0c…

探囊取物之多形式注册页面(基于BootStrap4)

基于BootStrap4的注册页面,支持手机验证码注册、账号密码注册 低配置云服务器,首次加载速度较慢,请耐心等候;演练页面可点击查看源码 预览页面:http://www.daelui.com/#/tigerlair/saas/preview/ly4gax38ub9j 演练页…

RTSP协议在视频监控系统中的典型应用、以及视频监控设备的rtsp地址格式介绍

目录 一、协议概述 1、定义 2、提交者 3、位置 二、主要特点 1、实时性 2、可扩展性 3、控制功能 4、回放支持 5、网络适应性 三、RTSP的工作原理 1、会话准备 2、会话建立 3、媒体流控制 4、会话终止 5、媒体数据传输 四、协议功能 1、双向性 2、带外协议 …

趣玩双色球APP-PyQt5实现

开发环境及软件主要功能说明 开发环境 win10 Vscode Python10.5-64_bit 使用的python库 requests,bs4,pandas,PyQt5 主要功能说明: 数据库更新,保存,另存为功能过滤显示,根据期数,开奖日期,开间期号过…

AndroidStudio activity-1.8.0.aar依赖报错

在使用Androidstudio自帶的創建activity及配套 xml時,構建項目失敗,報錯内容: Null extracted folder for artifact: ResolvedArtifact(componentIdentifierandroidx.activity:activity:1.8.0, variantNamenull, artifactFileC:\Users\hhhh\.…

Golang 开发实战day15 - Input info

🏆个人专栏 🤺 leetcode 🧗 Leetcode Prime 🏇 Golang20天教程 🚴‍♂️ Java问题收集园地 🌴 成长感悟 欢迎大家观看,不执着于追求顶峰,只享受探索过程 Golang 开发实战day15 - 用户…