《动手学深度学习(PyTorch版)》笔记4.7

Chapter4 Multilayer Perceptron

4.7 Forward/Backward Propagation and Computational Graphs

本节将通过一些基本的数学和计算图,深入探讨反向传播的细节。首先,我们将重点放在带权重衰减( L 2 L_2 L2正则化)的单隐藏层多层感知机上。

4.7.1 Forward Propagation

前向传播(forward propagation或forward pass)指的是按顺序(从输入层到输出层)计算和存储神经网络中每层的结果。

我们将一步步研究单隐藏层神经网络的机制,为了简单起见,我们假设输入样本是 x ∈ R d \mathbf{x}\in \mathbb{R}^d xRd,并且我们的隐藏层不包括偏置项。这里的中间变量是:

z = W ( 1 ) x , \mathbf{z}= \mathbf{W}^{(1)} \mathbf{x}, z=W(1)x,

其中 W ( 1 ) ∈ R h × d \mathbf{W}^{(1)} \in \mathbb{R}^{h \times d} W(1)Rh×d是隐藏层的权重参数。将中间变量 z ∈ R h \mathbf{z}\in \mathbb{R}^h zRh通过激活函数 ϕ \phi ϕ后,我们得到长度为 h h h的隐藏激活向量:

h = ϕ ( z ) . \mathbf{h}= \phi (\mathbf{z}). h=ϕ(z).

隐藏变量 h \mathbf{h} h也是一个中间变量。假设输出层的参数只有权重 W ( 2 ) ∈ R q × h \mathbf{W}^{(2)} \in \mathbb{R}^{q \times h} W(2)Rq×h,我们可以得到输出层变量,它是一个长度为 q q q的向量:

o = W ( 2 ) h . \mathbf{o}= \mathbf{W}^{(2)} \mathbf{h}. o=W(2)h.

假设损失函数为 l l l,样本标签为 y y y,我们可以计算单个数据样本的损失项,

L = l ( o , y ) . L = l(\mathbf{o}, y). L=l(o,y).

根据 L 2 L_2 L2正则化的定义,给定超参数 λ \lambda λ,正则化项为

s = λ 2 ( ∥ W ( 1 ) ∥ F 2 + ∥ W ( 2 ) ∥ F 2 ) , s = \frac{\lambda}{2} \left(\|\mathbf{W}^{(1)}\|_F^2 + \|\mathbf{W}^{(2)}\|_F^2\right), s=2λ(W(1)F2+W(2)F2),

∥ X ∥ F \|\mathbf{X}\|_F XF表示矩阵的Frobenius范数:
∥ X ∥ F = ∑ i = 1 m ∑ j = 1 n x i j 2 . \|\mathbf{X}\|_F = \sqrt{\sum_{i=1}^m \sum_{j=1}^n x_{ij}^2}. XF=i=1mj=1nxij2 .
最后,模型在给定数据样本上的正则化损失为:

J = L + s . J = L + s. J=L+s.

在下面的讨论中,我们将 J J J称为目标函数(objective function)。

下图是与上述简单网络相对应的计算图,其中正方形表示变量,圆圈表示操作符。

在这里插入图片描述

4.7.2 Backward Propagation

反向传播(backward propagation或backpropagation)指的是计算神经网络参数梯度的方法,该方法根据链式规则,按相反的顺序从输出层到输入层遍历网络。该算法存储了计算某些参数梯度时所需的任何中间变量(偏导数)。
假设我们有函数 Y = f ( X ) \mathsf{Y}=f(\mathsf{X}) Y=f(X) Z = g ( Y ) \mathsf{Z}=g(\mathsf{Y}) Z=g(Y),其中输入和输出 X , Y , Z \mathsf{X}, \mathsf{Y}, \mathsf{Z} X,Y,Z是任意形状的张量。利用链式法则,我们可以计算 Z \mathsf{Z} Z关于 X \mathsf{X} X的导数:

∂ Z ∂ X = prod ( ∂ Z ∂ Y , ∂ Y ∂ X ) . \frac{\partial \mathsf{Z}}{\partial \mathsf{X}} = \text{prod}\left(\frac{\partial \mathsf{Z}}{\partial \mathsf{Y}}, \frac{\partial \mathsf{Y}}{\partial \mathsf{X}}\right). XZ=prod(YZ,XY).

在这里,我们使用 prod \text{prod} prod运算符在执行必要的操(如换位和交换输入位置)后将其参数相乘。对于高维张量,我们使用适当的对应项。

在上面的计算图中单隐藏层简单网络的参数是 W ( 1 ) \mathbf{W}^{(1)} W(1) W ( 2 ) \mathbf{W}^{(2)} W(2),反向传播的目的是计算梯度 ∂ J / ∂ W ( 1 ) \partial J/\partial \mathbf{W}^{(1)} J/W(1) ∂ J / ∂ W ( 2 ) \partial J/\partial \mathbf{W}^{(2)} J/W(2),计算的顺序与前向传播中执行的顺序相反,具体如下:

∂ J ∂ L = 1 and ∂ J ∂ s = 1. \frac{\partial J}{\partial L} = 1 \; \text{and} \; \frac{\partial J}{\partial s} = 1. LJ=1andsJ=1.

∂ J ∂ o = prod ( ∂ J ∂ L , ∂ L ∂ o ) = ∂ L ∂ o ∈ R q . \frac{\partial J}{\partial \mathbf{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \mathbf{o}}\right) = \frac{\partial L}{\partial \mathbf{o}} \in \mathbb{R}^q. oJ=prod(LJ,oL)=oLRq.

∂ s ∂ W ( 1 ) = λ W ( 1 ) , ∂ s ∂ W ( 2 ) = λ W ( 2 ) . \frac{\partial s}{\partial \mathbf{W}^{(1)}} = \lambda \mathbf{W}^{(1)} \; \text{,} \; \frac{\partial s}{\partial \mathbf{W}^{(2)}} = \lambda \mathbf{W}^{(2)}. W(1)s=λW(1),W(2)s=λW(2).

∂ J ∂ W ( 2 ) = prod ( ∂ J ∂ o , ∂ o ∂ W ( 2 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 2 ) ) = ∂ J ∂ o h ⊤ + λ W ( 2 ) ∈ R q × h . \frac{\partial J}{\partial \mathbf{W}^{(2)}}= \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(2)}}\right)= \frac{\partial J}{\partial \mathbf{o}} \mathbf{h}^\top + \lambda \mathbf{W}^{(2)}\in \mathbb{R}^{q \times h}. W(2)J=prod(oJ,W(2)o)+prod(sJ,W(2)s)=oJh+λW(2)Rq×h.

∂ J ∂ h = prod ( ∂ J ∂ o , ∂ o ∂ h ) = W ( 2 ) ⊤ ∂ J ∂ o ∈ R h . \frac{\partial J}{\partial \mathbf{h}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{h}}\right) = {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}\in \mathbb{R}^h. hJ=prod(oJ,ho)=W(2)oJRh.

由于激活函数 ϕ \phi ϕ是按元素计算的,计算中间变量 z \mathbf{z} z的梯度需要使用按元素乘法运算符,我们用 ⊙ \odot 表示:

∂ J ∂ z = prod ( ∂ J ∂ h , ∂ h ∂ z ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) ∈ R h . \frac{\partial J}{\partial \mathbf{z}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{h}}, \frac{\partial \mathbf{h}}{\partial \mathbf{z}}\right) = \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right)\in \mathbb{R}^h. zJ=prod(hJ,zh)=hJϕ(z)Rh.

∂ J ∂ W ( 1 ) = prod ( ∂ J ∂ z , ∂ z ∂ W ( 1 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 1 ) ) = ∂ J ∂ z x ⊤ + λ W ( 1 ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) x ⊤ + λ W ( 1 ) = ( W ( 2 ) ⊤ ∂ J ∂ o ) ⊙ ϕ ′ ( z ) x ⊤ + λ W ( 1 ) . \begin{align*} \frac{\partial J}{\partial \mathbf{W}^{(1)}} &= \text{prod}\left(\frac{\partial J}{\partial \mathbf{z}}, \frac{\partial \mathbf{z}}{\partial \mathbf{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(1)}}\right) \\ &= \frac{\partial J}{\partial \mathbf{z}} \mathbf{x}^\top + \lambda \mathbf{W}^{(1)} \\ &= \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right)\mathbf{x}^\top + \lambda \mathbf{W}^{(1)} \\ &= ({\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}})\odot \phi'\left(\mathbf{z}\right)\mathbf{x}^\top + \lambda \mathbf{W}^{(1)}. \end{align*} W(1)J=prod(zJ,W(1)z)+prod(sJ,W(1)s)=zJx+λW(1)=hJϕ(z)x+λW(1)=(W(2)oJ)ϕ(z)x+λW(1).

4.7.3 Training Neural Networks

在训练神经网络时,前向传播和反向传播相互依赖。以上述简单网络为例:一方面,在前向传播期间计算正则项取决于模型参数 W ( 1 ) \mathbf{W}^{(1)} W(1) W ( 2 ) \mathbf{W}^{(2)} W(2)的当前值。它们是由优化算法根据最近迭代的反向传播给出的。另一方面,反向传播期间参数的梯度计算,取决于由前向传播给出的隐藏变量 h \mathbf{h} h的当前值。

因此,在训练神经网络时,我们交替使用前向传播和反向传播,利用反向传播给出的梯度来更新模型参数。注意,反向传播重复利用前向传播中存储的中间值,以避免重复计算。这带来的影响之一是我们需要保留中间值,直到反向传播完成,这也是训练比单纯的预测需要更多的内存(显存)的原因之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/653496.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【教学类-44-04】20240128汉字字帖的字体(一)——文艺空心黑体

背景需求: 【教学类-XX -XX 】20240128名字字卡1.0(15CM正方形手工纸、黑体,说明是某个孩子的第几个名字)-CSDN博客文章浏览阅读254次,点赞4次,收藏2次。【教学类-XX -XX 】20240128名字字卡1.0&#xff0…

12.Elasticsearch应用(十二)

Elasticsearch应用(十二) 1.单机ES面临的问题 海量数据存储问题单点故障问题 2.ES集群如何解决上面的问题 海量数据存储解决问题: 将索引库从逻辑上拆分为N个分片(Shard),存储到多个节点单点故障问题&a…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之CheckboxGroup组件

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之CheckboxGroup组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、CheckboxGroup组件 提供多选框组件,通常用于某选项的打开或关…

【python】GtkWindow程序

一、多个GtkWindow 在GTK中,并不推荐使用多个GtkWindow来创建多文档界面(MDI),而是推荐使用单个GtkWindow内嵌入的小部件(如GtkNotebook)来实现类似的效果。然而,如果确实想要创建多个窗口的例…

教育能打破阶层固化吗

中式教育以应试为核心,强调知识的灌输和学生被动接受。随着社会的发展,中式教育的短板逐渐显现,创新能力的缺乏、对记忆的过度依赖、忽视个体差异等问题日益突出。 建议所有大学生都能去看看《上海交通大学生存手册》,它道出了中…

日常学习之:vue + django + docker + heroku 对后端项目 / 前后端整体项目进行部署

文章目录 使用 docker 在 heroku 上单独部署 vue 前端使用 docker 在 heroku 上单独部署 django 后端创建 heroku 项目构建 Dockerfile设置 settings.pydatabase静态文件管理安全设置applicaiton & 中间件配置 设置 requirements.txtheroku container 部署应用 前后端分别部…

(自用)learnOpenGL学习总结-高级OpenGL-模板测试

模板测试 模板测试简单来说就是一个mask,根据你的mask来保留或者丢弃片段。 那么可以用来显示什么功能呢?剪切,镂空、透明度等操作。 和深度缓冲的关系是: 先片段着色器,然后进入深度测试,最后加入模板测…

2023年中国工控自动化市场现状及竞争分析,美日占主角,国产品牌初崭头角

工控自动化是一种运用控制理论、仪器仪表理论、计算机和信息技术,对工业生产过程实现检测、控制、优化、调度、管理和决策,达到增加产量、提高质量、降低消耗、确保安全等目的综合性技术。产品应用领域广泛,可分为OEM型行业和项目型行业。 近…

2024年最新 MySQL的下载、安装、启动与停止

一、MySQL的下载 MySQL最常用的2个版本: 社区版:免费开源,自由下载,不提供官方技术支持,大多数普通用户选择这个即可。企业版:需要付费,不能在线下载,可以使用30天,提供…

aio-max-nr达到上限导致数据库性能问题

问题说明: rac数据库节点一表面上看由于归档等待事件导致业务性能问题。 问题分析过程: 查看awr报告top事件,等待事件主要为归档切换问题: 查看事件,归档等待达到20多分钟 检查节点alert日志发现,最…

Linux第37步_解决“Boot interface 6 not supported”之问题

在使用USB OTG将“自己移植的固件”烧写到eMMC中时,串口会输出“Boot interface 6 not supported”,发现很多人踩坑,我也一样。 见下图: 解决办法: 1、打开终端 输入“ls回车”,列出当前目录下所有的文件…

SpringCloud-高级篇(十六)

前面学习了Lua的语法,就可以在nginx去做编程,去实现nginx类里面的业务,查询Redis,查询tomcat等 ,业务逻辑的编写依赖于其他组件,这些组件会用到OpenResty的工具去实现 (1)安装OpenRe…

C++(16)——vector的模拟实现

前面的文章中,给出了对于的模拟实现,本篇文章将给出关于的模拟实现。 目录 1.基本框架: 2. 返回值与迭代器: 2.1 返回值capacity与size: 2.2 两种迭代器iterator和const_iterator: 3. 扩容与push_back与pop_back&#xff1a…

后端学习:数据库MySQL学习

数据库简介 数据库:英文为 DataBase,简称DB,它是存储和管理数据的仓库。   接下来,我们来学习Mysql的数据模型,数据库是如何来存储和管理数据的。在介绍 Mysql的数据模型之前,需要先了解一个概念&#xf…

log4j2 配置入门介绍

配置 将日志请求插入到应用程序代码中需要进行大量的计划和工作。 观察表明,大约4%的代码专门用于日志记录。因此,即使是中等规模的应用程序也会在其代码中嵌入数千条日志记录语句。 考虑到它们的数量,必须管理这些日志语句,而…

【QT+QGIS跨平台编译】之十三:【giflib+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

文章目录 一、giflib介绍二、文件下载三、文件分析四、pro文件五、编译实践一、giflib介绍 GIFlib(又称为Libgif)是一个开源的C语言库,用于处理GIF图像格式。它提供了一组函数和工具,使得开发者可以读取、写入和操作GIF图像文件。 GIFlib支持GIF87a和GIF89a两种版本的GIF格…

【项目管理】CMMI-管理性能与度量

管理性能与度量 (Managing Performance and Measurement, MPM)的目的在于开发和维持度量能力来管理开发过程性能,以实现公司业务目标,更直接来说,将管理和改进工作集中在成本、进度表和质量性能上,最大限度地提高业务投资回报。 1…

linux(进程概念)

目录 前言: 正文 冯诺依曼体系结构 操作系统 (Operator System) 概念 目的 定位 如何理解“管理” 进程组织 基本概念 内核数据结构 代码和数据 查看进程 ps指令 top指令 父子进程 fork创建进程 小结: 前…

【Redis】Redis有哪些适合的场景

🍎个人博客:个人主页 🏆个人专栏:Redis ⛳️ 功不唐捐,玉汝于成 目录 前言 正文 (1)会话缓存(Session Cache) (2)全页缓存(FPC…

深入理解TCP网络协议(1)

目录 1.TCP协议的段格式 2.TCP原理 2.1确认应答 2.2超时重传 3.三次握手(重点) 4.四次挥手 1.TCP协议的段格式 我们先来观察一下TCP协议的段格式图解: 源/目的端口号:标识数据从哪个进程来,到哪个进程去 32位序号/32位确认号:TCP会话的每一端都包含一个32位&#xff08…