深度学习 精选笔记(7)前向传播、反向传播和计算图

学习参考:

  • 动手学深度学习2.0
  • Deep-Learning-with-TensorFlow-book
  • pytorchlightning

①如有冒犯、请联系侵删。
②已写完的笔记文章会不定时一直修订修改(删、改、增),以达到集多方教程的精华于一文的目的。
③非常推荐上面(学习参考)的前两个教程,在网上是开源免费的,写的很棒,不管是开始学还是复习巩固都很不错的。

深度学习回顾,专栏内容来源多个书籍笔记、在线笔记、以及自己的感想、想法,佛系更新。争取内容全面而不失重点。完结时间到了也会一直更新下去,已写完的笔记文章会不定时一直修订修改(删、改、增),以达到集多方教程的精华于一文的目的。所有文章涉及的教程都会写在开头、一起学习一起进步。

前向传播用于计算模型的预测输出,反向传播用于根据预测输出和真实标签之间的误差来更新模型参数。

前向传播和反向传播是神经网络训练中的核心步骤,通过这两个过程,神经网络能够学习如何更好地拟合数据,提高预测准确性。

一、计算图

计算图(Computational Graph)是一种图形化表示方法,用于描述数学表达式中各个变量之间的依赖关系和计算流程。在深度学习和机器学习领域,计算图常用于可视化复杂的数学运算和函数计算过程,尤其是在反向传播算法中的梯度计算过程中被广泛应用。

计算图通常包括两种节点:

  • 计算节点(Compute Nodes):这些节点表示数学运算,如加法、乘法等。计算节点接受输入,并产生输出。
  • 数据节点(Data Nodes):这些节点表示数据或变量,如输入数据、权重、偏置等。

通过连接计算节点和数据节点的边,构建了一个有向图,其中每个节点表示一个操作,边表示数据流向。计算图可以帮助理解复杂的计算过程,特别是在深度学习中涉及大量参数和运算的情况下。

二、前向传播

前向传播(forward propagation或forward pass) 指的是:按顺序(从输入层到输出层)计算和存储神经网络中每层的结果。

前向传播(Forward Propagation):

  • 定义:前向传播是指输入数据通过神经网络模型的各层,逐层进行计算并传递至输出层的过程。
  • 作用:在前向传播过程中,输入数据经过神经网络的权重和激活函数的计算,最终得到模型的预测输出。
  • 目的:前向传播的目的是计算模型对输入数据的预测值,为后续的损失函数计算和反向传播提供基础。

1.前向传播的计算图

假设单隐藏层神经网络中,输入样本是 𝐱∈ℝ d, 并且隐藏层不包括偏置项。 这里的中间变量是:
在这里插入图片描述
其中 𝐖(1)∈ℝℎ×𝑑 是隐藏层的权重参数。 将中间变量 𝐳∈ℝℎ 通过激活函数 𝜙 后, 得到长度为 ℎ 的隐藏激活向量是:
在这里插入图片描述
隐藏变量 𝐡也是一个中间变量。 假设输出层的参数只有权重 𝐖(2)∈ℝ𝑞×ℎ, 可以得到输出层变量,它是一个长度为 𝑞 的向量:
在这里插入图片描述
假设损失函数为 𝑙,样本标签为 𝑦,可以计算单个数据样本的损失项,
在这里插入图片描述
根据 𝐿2 正则化的定义,给定超参数 𝜆 ,正则化项为
在这里插入图片描述
其中矩阵的Frobenius范数是将矩阵展平为向量后应用的 𝐿2范数。 最后,模型在给定数据样本上的正则化损失为:
在这里插入图片描述
该函数J就是目标函数。

绘制计算图有助于可视化计算中操作符和变量的依赖关系。

与上述简单网络相对应的计算图, 其中正方形表示变量,圆圈表示操作符。 左下角表示输入,右上角表示输出。 注意显示数据流的箭头方向主要是向右和向上的。
在这里插入图片描述

三、反向传播

反向传播(Backpropagation):

  • 定义:反向传播是指通过计算损失函数对模型参数的梯度(梯度是一个由偏导数组成的向量,表示函数在某一点处的变化率或者斜率方向、也就是在每个自变量方向上的偏导数),从输出层向输入层传播梯度的过程。
  • 作用:在反向传播过程中,根据损失函数计算模型参数的梯度,然后利用梯度下降等优化算法更新模型参数,以减小损失函数的值。
  • 目的:反向传播的目的是根据模型预测与真实标签的误差,调整神经网络中每个参数的值,使模型能够更好地拟合训练数据,并提高在新数据上的泛化能力。

反向传播(backward propagation或backpropagation)指的是计算神经网络参数梯度的方法。 简言之,该方法根据微积分中的链式规则,按相反的顺序从输出层到输入层遍历网络。 该算法存储了计算某些参数梯度时所需的任何中间变量(偏导数)。 假设有函数 𝖸=𝑓(𝖷) 和 𝖹=𝑔(𝖸) , 其中输入和输出 𝖷,𝖸,𝖹 是任意形状的张量。 利用链式法则,可以计算 𝖹 关于 𝖷 的导数:

在这里插入图片描述
使用 prod 运算符在执行必要的操作(如换位和交换输入位置)后将其参数相乘。 对于向量,这很简单,它只是矩阵-矩阵乘法。

在前向传播的计算图中,单隐藏层简单网络的参数是 𝐖(1) 和 𝐖(2) 。 反向传播的目的是计算梯度 ∂𝐽/∂𝐖(1)∂𝐽/∂𝐖(2) 。为此,应用链式法则,依次计算每个中间变量和参数的梯度。 计算的顺序与前向传播中执行的顺序相反,因为需要从计算图的结果开始,并朝着参数的方向努力。第一步是计算目标函数 𝐽=𝐿+𝑠 相对于损失项 𝐿 和正则项 𝑠 的梯度。

这里为什么等于1?因为单隐藏层简单网络的最后一层上面是
在这里插入图片描述
根据链式法则计算目标函数关于输出层变量 𝐨 的梯度:
在这里插入图片描述
计算正则化项相对于两个参数的梯度:

在这里插入图片描述
计算最接近输出层的模型参数的梯度 ∂𝐽/∂𝐖(2)∈ℝ𝑞×ℎ 。 使用链式法则得出:

在这里插入图片描述
为了获得关于 𝐖(1)的梯度,需要继续沿着输出层到隐藏层反向传播。 关于隐藏层输出的梯度 ∂𝐽/∂𝐡∈ℝℎ 由下式给出:
在这里插入图片描述
由于激活函数 𝜙 是按元素计算的, 计算中间变量 𝐳的梯度 ∂𝐽/∂𝐳∈ℝℎ 需要使用按元素乘法运算符,用 表示:
在这里插入图片描述
最后,可以得到最接近输入层的模型参数的梯度 ∂𝐽/∂𝐖(1)∈ℝℎ×𝑑 。 根据链式法则,我们得到:
在这里插入图片描述

四、训练神经网络

在训练神经网络时,前向传播和反向传播相互依赖。

对于前向传播,沿着依赖的方向遍历计算图并计算其路径上的所有变量。 然后将这些用于反向传播,其中计算顺序与计算图的相反。

以上述简单网络为例:
正则项:

在这里插入图片描述
反向传播中计算J对W(2)的梯度公式:
在这里插入图片描述
反向传播中计算J对W(1)的梯度公式:
在这里插入图片描述
一方面,在前向传播期间计算正则项取决于模型参数𝐖(1)和 𝐖(2)的当前值。 它们是由优化算法根据最近迭代的反向传播给出的。 另一方面,反向传播期间参数的梯度计算, 取决于由前向传播给出的隐藏变量𝐡的当前值。

因此,在训练神经网络时,在初始化模型参数后, 交替使用前向传播和反向传播,利用反向传播给出的梯度来更新模型参数。

注意,反向传播重复利用前向传播中存储的中间值,以避免重复计算。 带来的影响之一是需要保留中间值,直到反向传播完成。 这也是训练比单纯的预测需要更多的内存(显存)的原因之一。 此外,这些中间值的大小与网络层的数量和批量的大小大致成正比。 因此,使用更大的批量来训练更深层次的网络更容易导致内存不足(out of memory)错误

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/718267.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

U盘无法读取?轻松掌握正确解决方法!

“为什么我的u盘插入电脑后会显示无法读取呢?想查看一些比较重要的文件,但就是无法读取U盘,想问问大家,我应该怎么操作呢?” U盘作为一种便捷的数据存储设备,广泛应用于我们的日常生活和工作中。然而&#…

独立游戏《星尘异变》UE5 C++程序开发日志2——创建并编写一个C++类

在本篇日志中,我们将要用一个C类来实现一个游戏内的物品,同时介绍UCLASS、USTRUCT、UPROPERTY的使用 一、创建一个C类 我们在UE5的"内容侧滑菜单"中,在右侧空白中右键选择"新建C类",然后可以选择一个想要的…

Spring AOP(Aspect-Oriented Programming,面向切面编程)介绍

Spring AOP(Aspect-Oriented Programming,面向切面编程)是Spring框架的一个重要模块,它提供了一种强大的方式来帮助开发者实现横切关注点(cross-cutting concerns)的模块化。横切关注点是指那些影响多个模块…

Linux设备模型(十一) - platform设备

一,platform device概述 在Linux2.6以后的设备驱动模型中,需关心总线、设备和驱动这3个实体,总线将设备和驱动绑定。在系统每注册一个设备的时候, 会寻找与之匹配的驱动;相反的,在系统每注册一个设备的时…

可让照片人物“开口说话”阿里图生视频模型EMO,高启强普法

3 月 1 日消息,阿里巴巴研究团队近日发布了一款名为“EMO(Emote Portrait Alive)”的 AI 框架,该框架号称可以用于“对口型”,只需要输入人物照片及音频,模型就能够让照片中的人物开口说出相关音频&#xf…

PDN分析及应用系列二-简单5V电源分配-Altium Designer仿真分析-AD

PDN分析及应用系列二 —— 案例1:简单5V电源分配 预模拟DC网络识别 当最初为PCB设计打开PDN分析仪时,它将尝试根据公共电源网络命名法从设计中识别所有直流电源网络。 正确的DC网络识别对于获得最准确的模拟结果非常重要。 在示例项目中已经识别出主DC网络以简化该过程。 …

Vulnhub靶机:Bellatrix

一、介绍 运行环境:Virtualbox 攻击机:kali(10.0.2.4) 靶机:Bellatrix(10.0.2.9) 目标:获取靶机root权限和flag 靶机下载地址:https://www.vulnhub.com/entry/hogwa…

网络学习:MPLS技术基础知识

目录 一、MPLS技术产生背景 二、MPLS网络组成(基本概念) 1、MPLS技术简介:Multiprotocol Lable Switching,多协议标签交换技术 2、MPLS网络组成 三、MPLS的优势 四、MPLS的实际应用 一、MPLS技术产生背景 1、IP采用最长掩码…

Power BI vs Superset BI 调研报告

调研结论 SupersetPower BI价格开源①. Power BI Pro 每人 $10/月($120/年/人) ②. Power BI Premium 每人 $20/月($240/年/人) ③. Power BI Embedded:4C10G $11W/年 权限基于角色的访问控制,支持细粒度的访问: 表级别、库级别、图表级别,看板级别,用户级别 基于角色…

【推荐】免费AI论文写作神器-「智元兔 AI」

还在为写论文焦虑?免费AI写作大师来帮你三步搞定! 智元兔AI是ChatGPT的人工智能助手,并且具有出色的论文写作能力。它能够根据用户提供的题目或要求,自动生成高质量的论文。 不论是论文、毕业论文、散文、科普文章、新闻稿件&…

#WEB前端(浮动与定位)

1.实验&#xff1a; 2.IDE&#xff1a;VSCODE 3.记录&#xff1a; float、position 没有应用浮动前 应用左浮动和右浮动后 应用定位 4.代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><me…

pyqt5怎么返回错误信息给页面(警告窗口)

在软件设计中&#xff0c;我们可能会遇到对异常的处理&#xff0c;有些异常是用户需要看到的&#xff0c;比如说&#xff0c;当我们登录出错的时候&#xff0c;后端需要给我们返回响应的错误信息&#xff0c;就像下图实现的这样。 类似这种效果&#xff0c;我们该如何实现&…

javaWebssh题库管理系统myeclipse开发mysql数据库MVC模式java编程计算机网页设计

一、源码特点 java ssh题库管理系统是一套完善的web设计系统&#xff08;系统采用ssh框架进行设计开发&#xff09;&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Mye…

「MySQL」基本操作类型

&#x1f387;个人主页&#xff1a;Ice_Sugar_7 &#x1f387;所属专栏&#xff1a;数据库 &#x1f387;欢迎点赞收藏加关注哦&#xff01; 数据库的操作 创建、显示数据库 使用 create 创建一个数据库 create database goods;然后可以用 show databases 来查看已经创建的数…

javaWebssh网上超市销售管理系统myeclipse开发mysql数据库MVC模式java编程计算机网页设计

一、源码特点 java ssh网上超市销售管理系统是一套完善的web设计系统&#xff08;系统采用ssh框架进行设计开发&#xff09;&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCA…

指针深刻理解

指针深刻理解 看完鹏哥讲的c语言进阶视频后&#xff0c;又找来C语言深度剖析这本书仔细看了一遍&#xff0c;来进一步巩固和理解指针这个重点。 1&#xff1a;数组 如上图所示&#xff0c;当我们定义一个数组 a 时&#xff0c;编译器根据指定的元素个数和元素的类型分配确定大…

身份证识别系统(安卓)

设计内容与要求&#xff1a; 通过手机摄像头捕获身份证信息&#xff0c;将身份证上的姓名、性别、出生年月、身份证号码保存在数据库中。1&#xff09;所开发Apps软件至少需由3-5个以上功能性界面组成。要求&#xff1a;界面美观整洁、方便应用&#xff1b;可以使用Android原生…

JS 对象数组排序方法测试

输出 一.Array.prototype.sort() 1.默认排序 sort() sort() 方法就地对数组的元素进行排序&#xff0c;并返回对相同数组的引用。默认排序是将元素转换为字符串&#xff0c;然后按照它们的 UTF-16 码元值升序排序。 由于它取决于具体实现&#xff0c;因此无法保证排序的时…

数据可视化基础与应用-02-基于powerbi实现医院数据集的指标体系的仪表盘制作

总结 本系列是数据可视化基础与应用的第02篇&#xff0c;主要介绍基于powerbi实现医院数据集的指标体系的仪表盘制作。 数据集描述 医生数据集doctor 医生编号是唯一的&#xff0c;名称会存在重复 医疗项目数据projects 病例编号是唯一的&#xff0c;注意这个日期编号不是真…

面试时如何回答接口测试怎么进行

一、什么是接口测试 接口测试顾名思义就是对测试系统组件间接口的一种测试&#xff0c;接口测试主要用于检测外部系统与系统之间以及内部各个子系统之间的交互点。测试的重点是要检查数据的交换&#xff0c;传递和控制管理过程&#xff0c;以及系统间的相互逻辑依赖关系等。 …