【深度学习笔记】3_14 正向传播、反向传播和计算图

3.14 正向传播、反向传播和计算图

前面几节里我们使用了小批量随机梯度下降的优化算法来训练模型。在实现中,我们只提供了模型的正向传播(forward propagation)的计算,即对输入计算模型输出,然后通过autograd模块来调用系统自动生成的backward函数计算梯度。基于反向传播(back-propagation)算法的自动求梯度极大简化了深度学习模型训练算法的实现。本节我们将使用数学和计算图(computational graph)两个方式来描述正向传播和反向传播。具体来说,我们将以带 L 2 L_2 L2范数正则化的含单隐藏层的多层感知机为样例模型解释正向传播和反向传播。

3.14.1 正向传播

正向传播是指对神经网络沿着从输入层到输出层的顺序,依次计算并存储模型的中间变量(包括输出)。为简单起见,假设输入是一个特征为 x ∈ R d \boldsymbol{x} \in \mathbb{R}^d xRd的样本,且不考虑偏差项,那么中间变量

z = W ( 1 ) x , \boldsymbol{z} = \boldsymbol{W}^{(1)} \boldsymbol{x}, z=W(1)x,

其中 W ( 1 ) ∈ R h × d \boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d} W(1)Rh×d是隐藏层的权重参数。把中间变量 z ∈ R h \boldsymbol{z} \in \mathbb{R}^h zRh输入按元素运算的激活函数 ϕ \phi ϕ后,将得到向量长度为 h h h的隐藏层变量

h = ϕ ( z ) . \boldsymbol{h} = \phi (\boldsymbol{z}). h=ϕ(z).

隐藏层变量 h \boldsymbol{h} h也是一个中间变量。假设输出层参数只有权重 W ( 2 ) ∈ R q × h \boldsymbol{W}^{(2)} \in \mathbb{R}^{q \times h} W(2)Rq×h,可以得到向量长度为 q q q的输出层变量

o = W ( 2 ) h . \boldsymbol{o} = \boldsymbol{W}^{(2)} \boldsymbol{h}. o=W(2)h.

假设损失函数为 ℓ \ell ,且样本标签为 y y y,可以计算出单个数据样本的损失项

L = ℓ ( o , y ) . L = \ell(\boldsymbol{o}, y). L=(o,y).

根据 L 2 L_2 L2范数正则化的定义,给定超参数 λ \lambda λ,正则化项即

s = λ 2 ( ∥ W ( 1 ) ∥ F 2 + ∥ W ( 2 ) ∥ F 2 ) , s = \frac{\lambda}{2} \left(\|\boldsymbol{W}^{(1)}\|_F^2 + \|\boldsymbol{W}^{(2)}\|_F^2\right), s=2λ(W(1)F2+W(2)F2),

其中矩阵的Frobenius范数等价于将矩阵变平为向量后计算 L 2 L_2 L2范数。最终,模型在给定的数据样本上带正则化的损失为

J = L + s . J = L + s. J=L+s.

我们将 J J J称为有关给定数据样本的目标函数,并在以下的讨论中简称目标函数。

3.14.2 正向传播的计算图

我们通常绘制计算图来可视化运算符和变量在计算中的依赖关系。图3.6绘制了本节中样例模型正向传播的计算图,其中左下角是输入,右上角是输出。可以看到,图中箭头方向大多是向右和向上,其中方框代表变量,圆圈代表运算符,箭头表示从输入到输出之间的依赖关系。

在这里插入图片描述

图3.6 正向传播的计算图

3.14.3 反向传播

反向传播指的是计算神经网络参数梯度的方法。总的来说,反向传播依据微积分中的链式法则,沿着从输出层到输入层的顺序,依次计算并存储目标函数有关神经网络各层的中间变量以及参数的梯度。对输入或输出 X , Y , Z \mathsf{X}, \mathsf{Y}, \mathsf{Z} X,Y,Z为任意形状张量的函数 Y = f ( X ) \mathsf{Y}=f(\mathsf{X}) Y=f(X) Z = g ( Y ) \mathsf{Z}=g(\mathsf{Y}) Z=g(Y),通过链式法则,我们有

∂ Z ∂ X = prod ( ∂ Z ∂ Y , ∂ Y ∂ X ) , \frac{\partial \mathsf{Z}}{\partial \mathsf{X}} = \text{prod}\left(\frac{\partial \mathsf{Z}}{\partial \mathsf{Y}}, \frac{\partial \mathsf{Y}}{\partial \mathsf{X}}\right), XZ=prod(YZ,XY),

其中 prod \text{prod} prod运算符将根据两个输入的形状,在必要的操作(如转置和互换输入位置)后对两个输入做乘法。

回顾一下本节中样例模型,它的参数是 W ( 1 ) \boldsymbol{W}^{(1)} W(1) W ( 2 ) \boldsymbol{W}^{(2)} W(2),因此反向传播的目标是计算 ∂ J / ∂ W ( 1 ) \partial J/\partial \boldsymbol{W}^{(1)} J/W(1) ∂ J / ∂ W ( 2 ) \partial J/\partial \boldsymbol{W}^{(2)} J/W(2)。我们将应用链式法则依次计算各中间变量和参数的梯度,其计算次序与前向传播中相应中间变量的计算次序恰恰相反。首先,分别计算目标函数 J = L + s J=L+s J=L+s有关损失项 L L L和正则项 s s s的梯度

∂ J ∂ L = 1 , ∂ J ∂ s = 1. \frac{\partial J}{\partial L} = 1, \quad \frac{\partial J}{\partial s} = 1. LJ=1,sJ=1.

其次,依据链式法则计算目标函数有关输出层变量的梯度 ∂ J / ∂ o ∈ R q \partial J/\partial \boldsymbol{o} \in \mathbb{R}^q J/oRq

∂ J ∂ o = prod ( ∂ J ∂ L , ∂ L ∂ o ) = ∂ L ∂ o . \frac{\partial J}{\partial \boldsymbol{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \boldsymbol{o}}\right) = \frac{\partial L}{\partial \boldsymbol{o}}. oJ=prod(LJ,oL)=oL.

接下来,计算正则项有关两个参数的梯度:

∂ s ∂ W ( 1 ) = λ W ( 1 ) , ∂ s ∂ W ( 2 ) = λ W ( 2 ) . \frac{\partial s}{\partial \boldsymbol{W}^{(1)}} = \lambda \boldsymbol{W}^{(1)},\quad\frac{\partial s}{\partial \boldsymbol{W}^{(2)}} = \lambda \boldsymbol{W}^{(2)}. W(1)s=λW(1),W(2)s=λW(2).

现在,我们可以计算最靠近输出层的模型参数的梯度 ∂ J / ∂ W ( 2 ) ∈ R q × h \partial J/\partial \boldsymbol{W}^{(2)} \in \mathbb{R}^{q \times h} J/W(2)Rq×h。依据链式法则,得到

∂ J ∂ W ( 2 ) = prod ( ∂ J ∂ o , ∂ o ∂ W ( 2 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 2 ) ) = ∂ J ∂ o h ⊤ + λ W ( 2 ) . \frac{\partial J}{\partial \boldsymbol{W}^{(2)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}\right) = \frac{\partial J}{\partial \boldsymbol{o}} \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)}. W(2)J=prod(oJ,W(2)o)+prod(sJ,W(2)s)=oJh+λW(2).

沿着输出层向隐藏层继续反向传播,隐藏层变量的梯度 ∂ J / ∂ h ∈ R h \partial J/\partial \boldsymbol{h} \in \mathbb{R}^h J/hRh可以这样计算:

∂ J ∂ h = prod ( ∂ J ∂ o , ∂ o ∂ h ) = W ( 2 ) ⊤ ∂ J ∂ o . \frac{\partial J}{\partial \boldsymbol{h}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{h}}\right) = {\boldsymbol{W}^{(2)}}^\top \frac{\partial J}{\partial \boldsymbol{o}}. hJ=prod(oJ,ho)=W(2)oJ.

由于激活函数 ϕ \phi ϕ是按元素运算的,中间变量 z \boldsymbol{z} z的梯度 ∂ J / ∂ z ∈ R h \partial J/\partial \boldsymbol{z} \in \mathbb{R}^h J/zRh的计算需要使用按元素乘法符 ⊙ \odot

∂ J ∂ z = prod ( ∂ J ∂ h , ∂ h ∂ z ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) . \frac{\partial J}{\partial \boldsymbol{z}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{h}}, \frac{\partial \boldsymbol{h}}{\partial \boldsymbol{z}}\right) = \frac{\partial J}{\partial \boldsymbol{h}} \odot \phi'\left(\boldsymbol{z}\right). zJ=prod(hJ,zh)=hJϕ(z).

最终,我们可以得到最靠近输入层的模型参数的梯度 ∂ J / ∂ W ( 1 ) ∈ R h × d \partial J/\partial \boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d} J/W(1)Rh×d。依据链式法则,得到

∂ J ∂ W ( 1 ) = prod ( ∂ J ∂ z , ∂ z ∂ W ( 1 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 1 ) ) = ∂ J ∂ z x ⊤ + λ W ( 1 ) . \frac{\partial J}{\partial \boldsymbol{W}^{(1)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{z}}, \frac{\partial \boldsymbol{z}}{\partial \boldsymbol{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(1)}}\right) = \frac{\partial J}{\partial \boldsymbol{z}} \boldsymbol{x}^\top + \lambda \boldsymbol{W}^{(1)}. W(1)J=prod(zJ,W(1)z)+prod(sJ,W(1)s)=zJx+λW(1).

3.14.4 训练深度学习模型

在训练深度学习模型时,正向传播和反向传播之间相互依赖。下面我们仍然以本节中的样例模型分别阐述它们之间的依赖关系。

一方面,正向传播的计算可能依赖于模型参数的当前值,而这些模型参数是在反向传播的梯度计算后通过优化算法迭代的。例如,计算正则化项 s = ( λ / 2 ) ( ∥ W ( 1 ) ∥ F 2 + ∥ W ( 2 ) ∥ F 2 ) s = (\lambda/2) \left(\|\boldsymbol{W}^{(1)}\|_F^2 + \|\boldsymbol{W}^{(2)}\|_F^2\right) s=(λ/2)(W(1)F2+W(2)F2)依赖模型参数 W ( 1 ) \boldsymbol{W}^{(1)} W(1) W ( 2 ) \boldsymbol{W}^{(2)} W(2)的当前值,而这些当前值是优化算法最近一次根据反向传播算出梯度后迭代得到的。

另一方面,反向传播的梯度计算可能依赖于各变量的当前值,而这些变量的当前值是通过正向传播计算得到的。举例来说,参数梯度 ∂ J / ∂ W ( 2 ) = ( ∂ J / ∂ o ) h ⊤ + λ W ( 2 ) \partial J/\partial \boldsymbol{W}^{(2)} = (\partial J / \partial \boldsymbol{o}) \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)} J/W(2)=(J/o)h+λW(2)的计算需要依赖隐藏层变量的当前值 h \boldsymbol{h} h。这个当前值是通过从输入层到输出层的正向传播计算并存储得到的。

因此,在模型参数初始化完成后,我们交替地进行正向传播和反向传播,并根据反向传播计算的梯度迭代模型参数。既然我们在反向传播中使用了正向传播中计算得到的中间变量来避免重复计算,那么这个复用也导致正向传播结束后不能立即释放中间变量内存。这也是训练要比预测占用更多内存的一个重要原因。另外需要指出的是,这些中间变量的个数大体上与网络层数线性相关,每个变量的大小跟批量大小和输入个数也是线性相关的,它们是导致较深的神经网络使用较大批量训练时更容易超内存的主要原因。

小结

  • 正向传播沿着从输入层到输出层的顺序,依次计算并存储神经网络的中间变量。
  • 反向传播沿着从输出层到输入层的顺序,依次计算并存储神经网络中间变量和参数的梯度。
  • 在训练深度学习模型时,正向传播和反向传播相互依赖。

注:本节与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/704332.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

哪只基金更值得买入?学会这套BI基金分析逻辑,稳赚不赔

投资基金是一种出色的理财方式,对于初次涉足基金领域的投资者而言,首先需要解决两个关键问题:一是基金是否值得投资?二是如何选择适合自己的基金? 以往盲目跟随成功的基金经理,或者仅仅依赖历史涨跌经验的…

消息中间件之RocketMQ源码分析(十七)

Broker CommitLog索引机制的数据结构 ConsumerQueue消费队列 主要用于消费拉取消息、更新消费位点等所用的索引。源代码参考org.apache.rocketmq.store.ConsumerQueue.该文件内保存了消息的物理位点、消息体大小、消息Tag的Hash值 物理位点:消息在CommitLog中的位点值消息体…

Android 水波纹扩散效果实现

人生只是一种体验,不必用来演绎完美。 效果图 View源码 package com.android.circlescalebar.view;import android.animation.Animator; import android.animation.AnimatorListenerAdapter; import android.animation.ObjectAnimator; import android.animation.…

el-tab-pane标签页如何加图标

效果如下 主要修改 <el-tab-pane name"tab6" v-if"subOrderType 10 && urlname ! wgSalesTerminationOrder"><span slot"label"> 售后判责<span class"el-icon-warning" style"color:#F66B6C;"&…

TensorFlow训练大模型做AI绘图,需要多少的GPU算力支撑

TensorFlow训练大模型做AI绘图&#xff0c;需要多少的GPU算力支撑&#xff01;这个问题就涉及到了资金投资的额度了。众所周知&#xff0c;现在京东里面一个英伟达的显卡&#xff0c;按照RTX3090(24G显存-涡轮风扇&#xff09;版本报价是7000-7500之间。如果你买一张这样的单卡…

【MySQL面试复习】谈一谈你对SQL的优化经验

系列文章目录 在MySQL中&#xff0c;如何定位慢查询&#xff1f; 发现了某个SQL语句执行很慢&#xff0c;如何进行分析&#xff1f; 了解过索引吗&#xff1f;(索引的底层原理)/B 树和B树的区别是什么&#xff1f; 什么是聚簇索引&#xff08;聚集索引&#xff09;和非聚簇索引…

原型设计工具Axure RP

Axure RP是一款专业的快速原型设计工具。Axure&#xff08;发音&#xff1a;Ack-sure&#xff09;&#xff0c;代表美国Axure公司&#xff1b;RP则是Rapid Prototyping&#xff08;快速原型&#xff09;的缩写。 下载链接&#xff1a;https://www.axure.com/ 下载 可以免费试用…

一个Post请求入门NestJS的路由与控制器

​ NestJS的控制器 控制器负责处理传入请求并向客户端返回响应。 控制器的目的是接收应用的特定请求。路由机制控制哪个控制器接收哪些请求。 通常&#xff0c;每个控制器都有不止一条路由&#xff0c;不同的路由可以执行不同的操作。 在使用了脚手架的项目中&#xff0c;我…

【Java程序员面试专栏 算法思维】四 高频面试算法题:回溯算法

一轮的算法训练完成后,对相关的题目有了一个初步理解了,接下来进行专题训练,以下这些题目就是汇总的高频题目,本篇主要聊聊回溯算法,主要就是排列组合问题,所以放到一篇Blog中集中练习 题目关键字解题思路时间空间岛屿数量网格搜索分别向上下左右四个方向探索,遇到海洋…

生成式 AI - Diffusion 模型的数学原理(5)

来自 论文《 Denoising Diffusion Probabilistic Model》&#xff08;DDPM&#xff09; 论文链接&#xff1a; https://arxiv.org/abs/2006.11239 Hung-yi Lee 课件整理 讲到这里还没有解决的问题是&#xff0c;为什么这里还要多加一个噪声。Denoise模型算出来的是高斯分布的均…

【VTKExamples::PolyData】第三十八期 Outline

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 前言 本文分享VTK样例Outline,并解析接口vtkOutlineFilter,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO 1. Outline // Create…

Sora的潜力与影响:对视频制作、广告、娱乐等行业的深度解析

随着技术的飞速发展&#xff0c;OpenAI推出的Sora模型已经引起了广泛关注。作为一款强大的视频生成工具&#xff0c;Sora不仅改变了视频制作的传统模式&#xff0c;更对广告、娱乐等多个行业产生了深远影响。本文将深度解析Sora的潜力与影响&#xff0c;探讨其在视频制作、广告…

Python自定义logger模块(附Demo)

目录 1. 内置logger2. 自定义logger 1. 内置logger Python标准库中的logging模块提供了日志记录的功能 允许开发者通过创建日志记录器、处理程序和格式化器来控制日志的生成和输出 以下是logging模块的一些主要组件和概念&#xff1a; 日志记录器 (Logger)&#xff1a;整个…

Canvas动画之豌豆射手

&#x1f339;作者主页&#xff1a;青花锁 &#x1f339;简介&#xff1a;Java领域优质创作者&#x1f3c6;、Java微服务架构公号作者&#x1f604; &#x1f339;简历模板、学习资料、面试题库、技术互助 &#x1f339;文末获取联系方式 &#x1f4dd; 往期热门专栏回顾 专栏…

Fl Studio 20.9.2.2963 中文破解版2024永久版下载(含Keygen)

FL Studio20.9是一款流行的图像线软件制作和编辑音频文件。作为一款领先的创新产品&#xff0c;该软件能够满足在创作音乐方面的需求。有了这个产品&#xff0c;可以完成制作音乐的整个过程。可以使用这个软件进行写作&#xff0c;编辑&#xff0c;录音&#xff0c;编辑和混合和…

DP读书:《工程热力学(第二版)》(一)绪论——能量及其利用

DP读书&#xff1a;《工程热力学&#xff08;第二版&#xff09;》绪论 0.1 能量及其利用 热力学——研究对象&#xff1a;能量 能量 物质能量传递 普遍规律 能源&#xff1a;直接提供能量的物质资源 一次能源&#xff1a;热能占比85% 直接利用——>冶金、采暖、炊煮 …

2024Node.js零基础教程(小白友好型),nodejs新手到高手,(九)NodeJS入门——http模块

060_http模块_网页URL之绝对路径 hello&#xff0c;大家好&#xff0c;这一个小题的话我们来补充一个之前学习过的内容&#xff0c;就是网页当中的URL&#xff0c;咱们这个小题的话主要是来说一下绝对路径&#xff0c;有同学可能会说&#xff0c;这这这&#xff0c;不对劲&…

抽象的后端

Connection refused: no further information 出现这条代码的核心是你使用redis&#xff0c;但是本地没有开启redis服务 如何启动redis服务 第一步&#xff1a;确定你安装了对应的框架 以spring为例 <dependency><groupId>org.springframework.boot</group…

架构设计实践:熟悉架构设计方法论,并动手绘制架构设计图

文章目录 一、架构设计要素1、架构设计目标2、架构设计模式&#xff08;1&#xff09;分而治之&#xff08;2&#xff09;迭代式设计 3、架构设计的输入&#xff08;1&#xff09;概览&#xff08;2&#xff09;功能需求 - WH分析法&#xff08;3&#xff09;质量 - “怎么”分…

卡玛网● 46. 携带研究材料 ● 01背包问题,你该了解这些! 滚动数组 力扣● 416. 分割等和子集

开始背包问题&#xff0c;掌握0-1背包和完全背包即可&#xff0c;注&#xff1a;0-1背包是完全背包的基础。 0-1背包问题&#xff1a;有n件物品和一个最多能背重量为w 的背包。第i件物品的重量是weight[i]&#xff0c;得到的价值是value[i] 。每件物品只能用一次&#xff0c;求…