反向传播详解BP

误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。

BP本来只指损失函数对参数的梯度通过网络反向流动的过程,但现在也常被理解成神经网络整个的训练方法,由误差传播、参数更新两个环节循环迭代组成。

本文将以最基础的全连接深度前馈网络为例,详细展示Back-propagation的全过程,并以Numpy进行实现。

图1 神经元为层/权重为层
通常我们以神经元来计量“层”,但本文将权重抽象为“层”,个人认为这样更有助于反向传播的理解和代码的编写。如上图所示的网络就被抽象为两个中间层、一个输出层的结构。

简而言之,神经网络的训练过程中,前向传播和反向传播交替进行,如下图所示:前向传播通过训练数据和权重参数计算输出结果;反向传播通过导数链式法则计算损失函数对各参数的梯度,并根据梯度进行参数的更新,这一点是重点,会在后文详叙。

图2 前向传播&反向传播
前向传播
每层中前向传播的过程如下所示,很简单的矩阵运算。我们将权重作为层,中间层和输出层均可用Layer类来表示,只是对应的激活函数不同。如图2所示,每一层的输入和输出都是
,且前一层的输出是后一层的输入。

* 表示element-wise乘积,· 表示矩阵乘积

class Layer:
‘’‘中间层类’‘’
self.W # (input_dim, output_dim)
self.b # (1, output_dim)
self.activate(a) = sigmoid(a)/tanh(a)/ReLU(a)/Softmax(a)

def forward(self, input_data):       # input_data: (1, input_dim)'''单个样本的前向传播'''input_data · self.W + self.b = a  # a: (1, output_dim)h = self.activate(a)              # h: (1, output_dim)return h
  1. 反向传播

损失对参数梯度的反向传播可以被这样直观解释:由A到传播B,即由
得到
,由导数链式法则
实现。所以神经网络的BP就是通过链式法则求出
对所有参数梯度的过程。

如上图示例,输入
,经过网络的参数
,得到一系列中间结果

表示通过权重和偏置的结果,还未经过激活函数,
表示经过激活函数后的结果。灰色框内表示
对各中间计算结果的梯度,这些梯度的反向传播有两类:



,通过激活函数,如右上角



,通过权重,如橙线部分

可以看出梯度的传播和前向传播的模式是一致的,只是方向不同。

计算完了灰色框的部分(损失对中间结果
的梯度),损失对参数
的梯度也就显而易见了,以图中红色的

为例:

因此,我们可以如图2,将反向传播的表达式和代码如下。

注意代码和公式中
表示element-wise乘积,
表示矩阵乘积。

* 表示element-wise乘积,· 表示矩阵乘积

class Layer:
‘’‘中间层类’‘’
self.W # (input_dim, output_dim)
self.b # (1, output_dim)
self.activate(a) = sigmoid(a)/tanh(a)/ReLU(a)/Softmax(a)

def forward(self, input_data):       # input_data: (1, input_dim)'''单个样本的前向传播'''input_data · self.W + self.b = a  # a: (1, output_dim)h = self.activate(a)              # h: (1, output_dim)return hdef backward(input_grad):'''单个样本的反向传播'''a_grad = input_grad * activate’(a)  # (1, output_dim)b_grad = a_grad                     # (1, output_dim)W_grad = (input_data.T) · a_grad    # (input_dim, output_dim)self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T          # (1, input_dim)

输出层的反向传播略有不同,因为在分类任务中输出层若用到softmax激活函数,

不是逐个对应的,如下图所示,因此
中的element-wise相乘是失效的,需要用
乘以向量
到向量
的向量梯度(雅可比矩阵)。

但实际上,经过看上去复杂的计算后输出层
会计算出一个非常简洁的结果:

以分类任务为例(交叉熵损失、softmax、训练标签
为one-hot向量其中第
维为1):


以回归任务为例(二次损失、线性激活、训练标签
为实数向量):

因此输出层反向传播的公式和代码可以写成如下所示:

* 表示element-wise乘积,· 表示矩阵乘积

class Output_layer(Layer):
‘’‘属性和forward方法继承Layer类’‘’

def backward(input_grad):'''输出层backward方法''''''单个样本的反向传播'''a_grad = input_grad                 # (1, output_dim)b_grad = a_grad                     # (1, output_dim)W_grad = (input_data.T) · a_grad    # (input_dim, output_dim)self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T          # (1, input_dim)
  1. Batch 批量计算

除非用随机梯度下降,否则每次用以训练的样本都是整个batch计算的,损失函数
则是整个batch中样本得到损失的均值。

在计算中会以向量化的方式增加运算效率,用batch_size表示批的规模,代码可更改为:

* 表示element-wise乘积,· 表示矩阵乘积

class Layer:
‘’‘中间层类’‘’

def forward(self, input_data):       # input_data: (batch_size, input_dim)'''batch_size个样本的前向传播'''input_data · self.W + self.b = a  # a: (1, output_dim)h = self.activate(a)              # h: (1, output_dim)return hdef backward(input_grad):             # input_grad: (batch_size, output_dim)'''batch_size个样本的反向传播'''a_grad = input_grad * activate’(a) # (batch_size, output_dim)b_grad = a_grad.mean(axis=0)       # (1, output_dim)W_grad = (a_grad.reshape(batch_size,1,output_dim) * input_data.reshape(batch_size,input_dim,1)).mean(axis=0)# (input_dim, output_dim) self.b -= lr * b_gradself.W -= lr * W_gradreturn a_grad · (self.W).T         # output_grad: (batch_size, input_dim)

class Output_layer(Layer):
‘’‘输出层类:属性和forward方法继承Layer类’‘’

def backward(input_grad):             # input_grad: (batch_size, output_dim)'''输出层backward方法''''''batch_size个样本的反向传播'''a_grad = input_grad                # (batch_size, output_dim)b_grad = a_grad.mean(axis=0)       # (1, output_dim)W_grad = (a_grad.reshape(batch_size,1,output_dim) * input_data.reshape(batch_size,input_dim,1)).mean(axis=0)# (input_dim, output_dim) self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T          # output_grad: (batch_size, input_dim)

这里比较易错的地方是什么时候求均值,对
求均值还是对
求均值:梯度在中间结果
上都不需要求均值,对参数
的梯度时才需要求均值。

  1. 代码

https://github.com/qcneverrepeat/ML01/blob/master/BP_DNN.ipynb
​github.com/qcneverrepeat/ML01/blob/master/BP_DNN.ipynb

模拟一个三层神经网络的训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/144062.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenAI暂停ChatGPT Plus新用户注册;迷宫与图神经网络

🦉 AI新闻 🚀 OpenAI暂停ChatGPT Plus新用户注册,考虑用户体验 摘要:OpenAI决定暂停ChatGPT Plus新用户注册,以应对开发日后使用量激增带来的压力,确保每个人都能享受良好的体验。根据调查机构Writerbudd…

下载huggingface预训练模型到本地并调用

写在前面 在大模型横行的时代,无法在服务器上连接外网的研究僧真的是太苦逼了,每次想尝试类似于CLIP,BLIP之类的大模型都会得到“requests.exceptions.ConnectionError: (MaxRetryError("HTTPSConnectionPool(host‘huggingface.co’, …

LeetCode-2760. 最长奇偶子数组-滑动窗口暴力

Problem: 2760. 最长奇偶子数组 每日一题。实习第10天记录。 文章目录 思路Code 思路 注意用条件找r。 Code class Solution {public int longestAlternatingSubarray(int[] nums, int threshold) {int len nums.length;int l, r;int res 0;for (l 0; l < len; l) {// 定…

Datawhale智能汽车AI挑战赛

1.赛题解析 赛题地址&#xff1a;https://tianchi.aliyun.com/competition/entrance/532155 任务&#xff1a; 输入&#xff1a;元宇宙仿真平台生成的前视摄像头虚拟视频数据&#xff08;8-10秒左右&#xff09;&#xff1b;输出&#xff1a;对视频中的信息进行综合理解&…

【解决方案】危化品厂区安防系统EasyCVR+AI智能监控

危化品属于危险、易燃易爆、易中毒行类&#xff0c;一旦在生产运输过程中发生泄漏后果不堪想象&#xff0c;所以危化品的生产储存更需要严密、精细的监控&#xff0c;来保障危化品的安全。EasyCVRTSINGSEE青犀AI智能分析网关搭建的危化品智能监控方案就能很好的为危化品监管保驾…

基于STC12C5A60S2系列1T 8051单片机的数模芯片DAC0832实现数模转换应用

基于STC12C5A60S2系列1T 8051单片的数模芯片DAC0832实现数模转换应用 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式及配置STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式介绍数模芯片DAC0832介绍通过按键调节数模芯片DAC0832…

微信小程序微信用户授权登录怎么在小程序上和钉钉相关联

要在微信小程序上实现微信用户授权登录并与钉钉相关联&#xff0c;你需要执行以下步骤&#xff1a; 钉钉开放平台注册应用&#xff1a;在钉钉开放平台上注册一个应用&#xff0c;获取到相应的AppID和AppSecret。 微信小程序授权登录&#xff1a;在微信小程序中使用wx.login接口…

WPF下实现拖动任意地方都可以拖动窗口

首先在xaml中添加事件 <Window PreviewMouseLeftButtonDown"Window_PreviewMouseLeftButtonDown"PreviewMouseMove"Window_PreviewMouseMove"PreviewMouseLeftButtonUp"Window_PreviewMouseLeftButtonUp"/>然后脚本输入 Point _pressedP…

【Spring进阶系列丨第二篇】Spring中的两大核心技术IoC(控制反转)与DI(依赖注入)

前言 我们都知道Spring 框架主要的优势是在 简化开发 和 框架整合 上&#xff0c;至于如何实现就是我们要学习Spring 框架的主要内容&#xff0c;今天我们就来一起学习Spring中的两大核心技术IoC&#xff08;控制反转&#xff09;与DI&#xff08;依赖注入&#xff09;。 文章目…

【数据结构】别跟我讲你不会冒泡排序

&#x1f466;个人主页&#xff1a;Weraphael ✍&#x1f3fb;作者简介&#xff1a;目前正在学习c和算法 ✈️专栏&#xff1a;数据结构 &#x1f40b; 希望大家多多支持&#xff0c;咱一起进步&#xff01;&#x1f601; 如果文章有啥瑕疵 希望大佬指点一二 如果文章对你有帮助…

【6】Spring Boot 3 集成组件:knift4j+springdoc+swagger3

目录 【6】Spring Boot 3 集成组件&#xff1a;knift4jspringdocswagger3OpenApi规范SpringFox Swagger3SpringFox工具&#xff08;不推荐&#xff09; Springdoc&#xff08;推荐&#xff09;从SpringFox迁移引入依赖配置jAVA Config 配置扩展配置&#xff1a;spring securit…

NumLevels

NumLevels&#xff1a;输入参数&#xff0c;最大的金字塔层数。默认auto&#xff0c;范围【0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, auto】。 AngleStart&#xff1a;输入参数&#xff0c;输入匹配时的起始角度。默认-0.39&#xff0c;建议值【 -3.14, -1.57, -0.79, -0.39, -0.20,…

CodeWhisperer 使用经验分享

今天给大家分享一下 Amazon CodeWhisperer 编程工具&#xff08;免费哦&#xff09;&#xff0c;使用这个软件后我的编码质量提升不少&#xff0c;给大家分享一下我的经验。希望大家支持哦。 Amazon CodeWhisperer 是亚⻢逊出品的一款基于机器学习的 AI 编程助手&#xff0c;可…

Java 开发常用的 Linux 命令知识积累

写在前面 虽然平时大部分工作都是和Java相关的开发, 但是每天都会接触Linux系统, 尤其是使用了Mac之后, 每天都是工作在黑色背景的命令行环境中. 自己记忆力不好, 很多有用的Linux命令不能很好的记忆, 现在逐渐总结一下, 以便后续查看. 基本操作 Linux关机,重启 # 关机 shu…

设计模式——建造者模式(Builder Pattern)+ Spring相关源码

文章目录 一、建造者模式定义二、例子2.1 自定义例子2.2 JDK源码——DateTimeFormatterBuilder2.3 Spring源码——BeanDefinitionBuilder 三、其他设计模式 一、建造者模式定义 类型&#xff1a; 创建型模式 介绍&#xff1a; 使用Builder类将多个简单的对象一步一步构建成一个…

C语言--每日五道练习题-- Day15

第一题 1、以下程序段的输出结果是&#xff08; &#xff09; #include<stdio.h> int main() {char s[] "\\123456\123456\t";printf("%d\n", strlen(s));return 0; } A: 12 B: 13 C: 16 D: 以上都不对 答案及解析 A 本题考查的是转义字符 \ 占…

uniapp 实现微信小程序手机号一键登录

app 和 h5 手机号一键登录&#xff0c;参考文档&#xff1a;uni-app官网 以下是uniapp 实现微信小程序手机号一键登录 1、布局 <template><view class"mainContent"><image class"closeImg" click"onCloseClick"src"quic…

上位机模块之通用重写相机类

在常用的视觉上位机中&#xff0c;我们通常会使用单个上位机匹配多个相机或者多品牌相机&#xff0c;所以在此记录一个可重写的通用相机类&#xff0c;用于后续长期维护开发。 先上代码。 using HalconDotNet; using System.Collections.Generic;namespace WeldingInspection.M…

SQL学习(CTFhub)整数型注入,字符型注入,报错注入 -----手工注入+ sqlmap注入

目录 整数型注入 手工注入 为什么要将1设置为-1呢&#xff1f; sqlmap注入 sqlmap注入步骤&#xff1a; 字符型注入 手工注入 sqlmap注入 报错注入 手工注入 sqlmap注入 整数型注入 手工注入 先输入1 接着尝试2&#xff0c;3&#xff0c;2有回显&#xff0c;而3没有回显…

MySQL中外键的使用及外键约束策略

一、外键约束的概念 外键约束&#xff08;FOREIGN KEY,缩写FK是数据库设计的一个概念&#xff0c;它确保在两个表之间的关系保持数据的一致性和完整性。 外键是指表中的某个字段的依赖于另一张表中某个字段的值&#xff0c;而被依赖的字段必须具有主键约束或者唯一约束&#…