从0开始深度学习(6)——Pytorch动态图机制(前向传播、反向传播)

PyTorch 的动态计算图机制是其核心特性之一,它使得深度学习模型的开发更加灵活和高效。

0 计算图

计算图(Computation Graph)是一种用于表示数学表达式或程序流程的图形结构,可以将复杂的表达式分解成一系列简单的操作,并以节点和边的形式展示这些操作及其之间的关系,能够清晰地展示计算过程中的依赖关系

  • 节点(Nodes): 表示变量或常量,也可以表示操作(如加法、乘法等)。
  • 边(Edges) :表示数据流的方向,即一个操作的结果如何作为输入传递给下一个操作。

举例说明

假设我们有如下的表达式:
x = a + b x=a+b x=a+b
y = c ∗ d y=c*d y=cd
z = x + y z=x+y z=x+y
在这里插入图片描述

其中 a , b , c , d a,b,c,d a,b,c,d是输入变量, z z z是输出变量。

假设令 a = 1 , b = 2 , c = 3 , d = 4 a=1,b=2,c=3,d=4 a=1,b=2,c=3,d=4,则可以根据上面的计算图,计算出 z z z,计算过程如下:
x = a + b = 1 + 2 = 3 x=a+b=1+2=3 x=a+b=1+2=3
y = c ∗ d = 3 ∗ 4 = 12 y=c*d=3*4=12 y=cd=34=12
z = x + y = 12 + 3 = 5 z=x+y=12+3=5 z=x+y=12+3=5

1 前向传播、反向传播

但是在机器学习中,计算图不仅展示了计算路径,还为计算梯度提供了基础,这里涉及到前向传播和反向传播

1.1 前向传播(计算输出)

根据当前的模型参数计算网络的输出,前向传播的输出作用是:

  • 评估模型性能: 通过比较网络的预测值与实际值,可以计算损失函数的值,从而评估模型当前的性能。
  • 准备反向传播: 前向传播计算出的中间结果(例如每个隐藏层的输出)是反向传播中计算梯度所必需的。上面的计算过程就是前向传播。
    在这里插入图片描述

1.2 反向传播(计算梯度)

目的是计算损失函数关于每个模型参数的梯度。这些梯度信息用于更新模型参数,使模型在下一次迭代中表现得更好,关于梯度是什么,在从0开始深度学习(2)——自动微分解释了梯度的概念和计算过程。
在这里插入图片描述

现在依然以上面的例子举例,假设我么们有一个损失函数 L L L,并且 L L L关于 z z z的梯度已知为 ∂ L ∂ z = 1 \frac{\partial L}{\partial z}=1 zL=1(通常都这样设置,且适用于多种函数),目标是计算 a , b , c , d a,b,c,d a,b,c,d关于 L L L的梯度,反着计算。
在这里插入图片描述

1 静态计算图

静态图是通过先定义后运行的方式,先搭建图,然后再输入数据进行计算,典型代表是Tensorflow 1.0 版本,Tensorflow名字的来源就是因为张量Tensor在预先定义的图中流动(Flow)
在这里插入图片描述

2 动态计算图

动态图是指计算图的运算和搭建同时运行,也就是可以先计算前面的节点的值,再根据这些值搭建后面的图,如下图所示:
在这里插入图片描述

2.1核心概念

2.1.1 张量

参考从0开始深度学习(1)—— 代码实现线性代数中的概念解释

2.1.2 自动微分

pytorch框架中借助自动微分机制来实现动态计算图,意味着计算图是在运行时动态构建的,它使得计算复杂模型的梯度变得非常方便。

  • 计算图在每次前向传播时重新构建。这意味着每个操作都会记录下来,并在反向传播时按需计算梯度。
  • 每个张量(Tensor)都有一个 .grad 属性来存储梯度,以及一个 .grad_fn 属性来记录创建该张量的操作。

关键属性:

  1. data:包含tensor的实际数据
  2. grad:存储tensor的梯度。在开始训练前,我们会使用zero_grad()把梯度清零,第一次反向传播后, x.grad会包含当前损失函数关于x的梯度,如果第二次反向传播前不清零,则新的梯度会累加x.grad,即包含两次反向传播的梯度之和
  3. grad_fn:记录了创建tensor的操作,表明该tensor是通过何种前向传播计算出来的,例如x是通过加法计算出来的,则x.grad_fn指向一个加法操作的函数对象

2.2 举例说明

依然以上述例子为例,我先举一个静态图的例子,是在TensorFlow 1.x 中,需要提前定义好整个计算图,然后在会话中执行:

import tensorflow as tf# 定义输入和参数
a = tf.placeholder(tf.float32, shape=[])
b = tf.placeholder(tf.float32, shape=[])
c = tf.placeholder(tf.float32, shape=[])
d = tf.placeholder(tf.float32, shape=[])# 前向传播,此处无法查看,只能在Session中查看
x = a + b
y = c * d
z = x + y# 定义损失函数
L = z# 定义梯度
grad_a, grad_b, grad_c, grad_d = tf.gradients(L, [a, b, c, d])# 创建会话
with tf.Session() as sess:# 前向传播print("Starting forward pass...")x_val, y_val, z_val = sess.run([x, y, z], feed_dict={a: 1.0, b: 2.0, c: 3.0, d: 4.0})print(f"x: {x_val}, grad_fn: N/A")print(f"y: {y_val}, grad_fn: N/A")print(f"z: {z_val}, grad_fn: N/A")# 打印前向传播的结果print(f"x: {x_val}")  # 3print(f"y: {y_val}")  # 12print(f"z: {z_val}")  # 15# 反向传播print("Starting backward pass...")grad_a_val, grad_b_val, grad_c_val, grad_d_val = sess.run([grad_a, grad_b, grad_c, grad_d], feed_dict={a: 1.0, b: 2.0, c: 3.0, d: 4.0})# 打印梯度print(f"a.grad: {grad_a_val}")  # 1print(f"b.grad: {grad_b_val}")  # 1print(f"c.grad: {grad_c_val}")  # 4print(f"d.grad: {grad_d_val}")  # 3

接下来举一个动态图的例子
pytorch动态图例子:

import torch# 定义输入和参数
a = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([2.0], requires_grad=True)
c = torch.tensor([3.0], requires_grad=True)
d = torch.tensor([4.0], requires_grad=True)# 前向传播,在定义操作中,就可以查看梯度和操作,不像静态图中,必须在Session中才能查看
print("Starting forward pass...")
x = a + b
print(f"x: {x.item()}, grad_fn: {x.grad_fn}")
# 输出结果:x: 3.0, grad_fn: <AddBackward0 object at 0x000002113E68B790>y = c * d
print(f"y: {y.item()}, grad_fn: {y.grad_fn}")
# 输出结果:y: 12.0, grad_fn: <MulBackward0 object at 0x000002113D43C070>z = x + y
print(f"z: {z.item()}, grad_fn: {z.grad_fn}")
# z: 15.0, grad_fn: <AddBackward0 object at 0x000002113E68B790># 打印前向传播的结果
print(f"x: {x.item()}")  # 3
print(f"y: {y.item()}")  # 12
print(f"z: {z.item()}")  # 15
'''
输出结果:
x: 3.0
y: 12.0
z: 15.0
'''# 假设损失函数 L = z
L = z# 反向传播
print("Starting backward pass...")
L.backward()# 打印梯度
print(f"a.grad: {a.grad.item()}")  # 1
print(f"b.grad: {b.grad.item()}")  # 1
print(f"c.grad: {c.grad.item()}")  # 4
print(f"d.grad: {d.grad.item()}")  # 3
'''
输出结果:.
a.grad: 1.0
b.grad: 1.0
c.grad: 4.0
d.grad: 3.0
'''

3 利用动态图机制,修改流程

我们这里希望根据 a a a的值来决定计算流程,详情看注释

import torch# 定义输入和参数
a = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([2.0], requires_grad=True)# 采用简洁的代码,可以动态的修改计算流程,控制更灵活
if a.item() > 0:x = a + bprint(f"x: {x.item()}")  # 3
else:x = a - b# 前向传播
y = x * 2
print(f"y: {y.item()}")  # 6# 反向传播
y.backward()# 打印梯度
print(f"a.grad: {a.grad.item()}")  # 2
print(f"b.grad: {b.grad.item()}")  # 2

如果是在静态图中需要通过控制流来决定,与动态图相比,操作更复杂,不直观

# 条件分支
x = tf.cond(a > 0, lambda: a + b, lambda: a - b)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/55147.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

详解代理模式-【静态代理与JDK动态代理】(非常的斯国一)

目录 静态代理 什么是静态代理: ​ 特点: 例子&#xff1a; JDK动态代理&#xff08;主要讲点&#xff09; 大纲&#xff1a; 1、与静态代码的联系 2、JDK动态代理的主流程 3、Proxy的源码 整体概述&#xff1a; 重要点的翻译 &#xff1a; newProxyInstance源码&am…

深信服2025届全球校招研发笔试-C卷(AK)

前面14个填空题 T1 已知 子数组 定义为原数组中的一个连续子序列。现给定一个正整数数组 arr&#xff0c;请计算该数组内所有可能的奇数长度子数组的数值之和。 输入描述 输入一个正整数数组arr 输出描述 所有可能的奇数长度子数组的和 示例 1 输入 1,4,2,5,3 输出 58 说明 …

使用 Light Chaser 进行大屏数据可视化

引言 在当今数据驱动的世界中&#xff0c;数据可视化变得越来越重要。Light Chaser 是一款基于 React 技术栈的大屏数据可视化设计工具&#xff0c;通过简单的拖拽操作&#xff0c;你可以快速生成漂亮、美观的数据可视化大屏和看板。本文将介绍如何使用 Light Chaser 进行数据…

828华为云征文|部署在线文档应用程序 CodeX Docs

828华为云征文&#xff5c;部署在线文档应用程序 CodeX Docs 一、Flexus云服务器X实例介绍二、Flexus云服务器X实例配置2.1 重置密码2.2 服务器连接2.3 安全组配置2.4 Docker 环境搭建 三、Flexus云服务器X实例部署 CodeX Docs3.1 CodeX Docs 介绍3.2 CodeX Docs 部署3.3 CodeX…

RabbitMQ应用

RabbitMQ 共提供了7种⼯作模式, 进⾏消息传递 一、七种模式的概述 1、Simple(简单模式) P:生产者,就是发送消息的程序 C:消费者,就是接收消息的程序 Queue:消息队列,类似⼀个邮箱, 可以缓存消息; ⽣产者向其中投递消息, 消费者从其中取出消息 特点: ⼀个⽣产者P,⼀…

小米2025届软件开发工程师(C/C++/Java)(编程题AK)

选择题好像也是25来个 编程题 T1 题目描述 小明喜欢解决各种数学难题。一天&#xff0c;他遇到了一道有趣的题目:他需要帮助他的朋友们完成一个排序任务。小明得到两个长度为 n 的数组a[]和b[]。他可以在两个数组对应位置进行交换&#xff0c;即选定一个位置 i &#xff0c…

《pyqt+open3d》open3d可视化界面集成到qt中

《pyqtopen3d》open3d可视化界面集成到qt中 一、效果显示二、代码三、资源下载 一、效果显示 二、代码 参考链接 main.py import sys import open3d as o3d from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget from PyQt5.QtGui import QWindow from PyQt5.Qt…

App模拟下载场景的demo

摘要 目的&#xff1a;提供一个稳定的下载场景&#xff0c;可以手动触发和定时触发下载&#xff0c;每次下载相同大小文件&#xff0c;研究下载场景的功耗影响 原理&#xff1a;把电脑当做服务器&#xff0c;手机测试App固定下载电脑存放的某个XXXMB的大文件&#xff0c;基于…

C语言进阶版第14课—内存函数

文章目录 1. memcpy函数的使用和模拟实现1.1 memcpy函数的使用1.2 模拟实现memcpy函数 2. memmove函数的使用和模拟实现2.1 memmove函数的使用2.2 memmove函数的模拟实现 3. memset函数4. memcmp函数 1. memcpy函数的使用和模拟实现 1.1 memcpy函数的使用 memcpy函数的原形voi…

英语音标与重弱读

英语中&#xff0c;比较重要的是音标。但事实上&#xff0c;我们对音标的学习还是比较少的&#xff0c;对它的理解也是比较少的。 一、音标 2个半元音 [w][j] 5个长元音&#xff1a;[i:] [ə:] [ɔ:] [u:] [ɑ:] 7个短元音&#xff1a;[i] [ə] [ɔ] [u] [] [e] [ʌ] 8个双元音…

第100+26步 ChatGPT学习:概率校准 Bayesian Binning into Quantiles

基于Python 3.9版本演示 一、写在前面 最近看了一篇在Lancet子刊《eClinicalMedicine》上发表的机器学习分类的文章&#xff1a;《Development of a novel dementia risk prediction model in the general population: A large, longitudinal, population-based machine-learn…

【2】图像视频的加载和显示

文章目录 【2】图像视频的加载和显示一、代码在哪写二、创建和显示窗口&#xff08;一&#xff09;导入OpenCV的包cv2&#xff08;二&#xff09;创建窗口&#xff08;三&#xff09;更改窗口大小 & 显示窗口&#xff08;四&#xff09;等待用户输入补充&#xff1a;ord()函…

【Unity踩坑】Unity更新Google Play结算库

一、问题描述&#xff1a; 在Google Play上提交了app bundle后&#xff0c;提示如下错误。 我使用的是Unity 2022.01.20f1&#xff0c;看来用的Play结算库版本是4.0 查了一下文档&#xff0c;Google Play结算库的维护周期是两年。现在需要更新到至少6.0。 二、更新过程 1. 下…

【视频目标分割-2024CVPR】Putting the Object Back into Video Object Segmentation

Cutie 系列文章目录1 摘要2 引言2.1背景和难点2.2 解决方案2.3 成果 3 相关方法3.1 基于记忆的VOS3.2对象级推理3.3 自动视频分割 4 工作方法4.1 overview4.2 对象变换器4.2.1 overview4.2.2 Foreground-Background Masked Attention4.2.3 Positional Embeddings 4.3 Object Me…

cpp,git,unity学习

c#中的? 1. 空值类型&#xff08;Nullable Types&#xff09; ? 可以用于值类型&#xff08;例如 int、bool 等&#xff09;&#xff0c;使它们可以接受 null。通常&#xff0c;值类型不能为 null&#xff0c;但是通过 ? 可以表示它们是可空的。 int? number null; // …

使用 SSH 连接 Docker 服务器:IntelliJ IDEA 高效配置与操作指南

使用 SSH 连接 Docker 服务器&#xff1a;IntelliJ IDEA 高效配置与操作指南 本文详细介绍了如何在 2375 端口未开放的情况下&#xff0c;通过 SSH 连接 Docker 服务器并在 Idea 中进行开发。通过修改用户权限、生成密钥对以及配置 SSH 访问&#xff0c;用户可以安全地远程操作…

NASA数据集:ATLAS/ICESat-2 L3B 每日和每月网格化海冰自由面高度,第 4 版

目录 简介 摘要 代码 引用 网址推荐 0代码在线构建地图应用 机器学习 ATLAS/ICESat-2 L3B Daily and Monthly Gridded Sea Ice Freeboard, Version 4 简介 ATLAS/ICESat-2 L3B Daily and Monthly Gridded Sea Ice Freeboard, Version 4数据是由NASA的ATLAS&#xff08…

Ubuntu 系统崩了,如何把数据拷下来

问题描述&#xff1a; Linux系统中安装输入法后&#xff0c;重启后&#xff0c;导致系统无法进入&#xff0c;进入 recovery mode下的resume 也启动不了&#xff0c;所以决定将需要的东西复制到U盘 解决方案&#xff1a; 1.重启ubuntu&#xff0c;随即点按Esc进入grub菜单&am…

OpenStack Yoga版安装笔记(十五)Horizon安装

1、官方文档 OpenStack Installation Guidehttps://docs.openstack.org/install-guide/ 本次安装是在Ubuntu 22.04上进行&#xff0c;基本按照OpenStack Installation Guide顺序执行&#xff0c;主要内容包括&#xff1a; 环境安装 &#xff08;已完成&#xff09;OpenStack…

HarmonyOS/OpenHarmony Audio 实现音频录制及播放功能

关键词&#xff1a;audio、音频录制、音频播放、权限申请、文件管理 在app的开发过程中时常会遇见一些需要播放一段音频或进行语音录制的场景&#xff0c;那么本期将介绍如何利用鸿蒙 audio 模块实现音频写入和播放的功能。本次依赖的是 ohos.multimedia.audio 音频管理模块&am…