PyTorch 的自动求导与计算图

在深度学习中,模型的训练过程本质上是通过梯度下降算法不断优化损失函数。为了高效地计算梯度,PyTorch 提供了强大的自动求导机制,这一机制依赖于“计算图”(Computational Graph)的概念。

1. 什么是计算图?

计算图是一种有向无环图(DAG),其中每个节点表示操作或变量,边表示数据的流动。简单来说,计算图是一个将复杂计算分解为一系列基本操作的图表。每个节点(通常称为“张量”)是一个数据单元,而边表示这些数据单元之间的计算关系。

例如,假设你有一个简单的函数 y = 2x + 1,这个函数可以表示为一个非常简单的计算图:

    x  ----->  2x  ----->  2x + 1

在这个图中,x 是一个输入张量,2x 是第一个操作节点,2x + 1 是第二个操作节点。PyTorch 会自动构建这个计算图,随着你对张量进行操作,图会动态扩展。

2. PyTorch 中的计算图

在 PyTorch 中,计算图是动态构建的。这意味着每次运行前向传播时,PyTorch 都会根据实际的操作构建计算图。这与其他静态图框架(如 TensorFlow 的早期版本)不同,后者需要先定义完整的图,然后再运行计算。

动态计算图的优点在于它灵活且易于调试。你可以在代码中使用 Python 的控制流(如条件语句、循环等),计算图会根据运行时的实际路径生成。

来看一个实际的例子:

import torch# 创建一个张量,并指定需要计算梯度
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)# 对张量进行操作
y = 2 * x + 1

在这段代码中,我们创建了一个名为 x 的张量,并通过 requires_grad=True 指定它是需要计算梯度的变量。这一步非常重要,因为只有 requires_grad 设置为 True 的张量,PyTorch 才会在计算图中跟踪它们的操作。

接下来,y = 2 * x + 1 执行了两个操作:先将 x 乘以 2,再加 1。每个操作都会在计算图中创建一个节点,表示计算的过程。这个计算图的结构可以描述为:

    x  ----->  2x  ----->  2x + 1

每个操作都被记录在计算图中,为反向传播过程做好准备。

3. 反向传播与梯度计算

当我们执行完前向计算后,接下来要做的就是通过反向传播计算梯度。梯度是指损失函数相对于输入变量的导数,用于指示在给定点处损失函数如何变化。

假设我们想计算 yx 的梯度。在 PyTorch 中,我们通过调用 backward() 方法来实现:

# 对 y 求和,然后执行反向传播
y.sum().backward()

y.sum() 是一个标量函数,将 y 的所有元素相加。这一步非常重要,因为在反向传播中,只有标量的梯度才能正确地传递。如果 y 不是标量,PyTorch 会对其进行求和,以确保反向传播的正确性。

执行 backward() 后,PyTorch 会自动计算 yx 的梯度,并将结果存储在 x.grad 中:

print(x.grad)  # 输出 [2.0, 2.0, 2.0, 2.0]

在这个例子中,dy/dx = 2,所以 x.grad 中的每个元素的值都是 2

4. 自动求导背后的数学原理

要理解自动求导,首先需要理解基本的微积分概念。导数反映了函数的变化率,是梯度下降算法的核心。

4.1 导数的概念

导数表示一个函数在某个点的瞬时变化率。如果你有一个简单的线性函数 y = 2x + 1,其导数是 2。这意味着,无论 x 的值是多少,y 的变化率都是常数 2

4.2 链式法则

链式法则是反向传播算法的基础。它告诉我们如何计算复合函数的导数。假设我们有两个函数 u = g(x)y = f(u),那么 yx 的导数可以通过链式法则计算:

dy/dx = (dy/du) * (du/dx)

在计算图中,链式法则对应于从输出节点到输入节点的梯度传递。每一步都遵循链式法则,将梯度从一层传递到下一层,最终计算出输入变量的梯度。

5. 复杂操作与控制流中的自动求导

PyTorch 的动态计算图不仅支持简单的操作,还可以处理更加复杂的操作和控制流。

5.1 非线性操作

非线性操作,如平方、指数运算等,使得计算图更加复杂。考虑下面的例子:

z = y ** 2  # z = (2x + 1) ^ 2

在这个例子中,计算图变为:

    x  ----->  2x  ----->  2x + 1  ----->  (2x + 1) ^ 2

此时,如果你对 z 进行反向传播,PyTorch 会首先计算 dz/dy,然后利用链式法则乘以 dy/dx,最终得到 dz/dx。通过调用 z.sum().backward(),你可以得到 zx 的梯度。

z.sum().backward()
print(x.grad)  # 输出 [12.0, 16.0, 20.0, 24.0]

在这里,x.grad 的值为 4x + 2,这就是 z = (2x + 1)^2x 的导数。

5.2 控制流中的求导

PyTorch 的自动求导机制同样可以处理控制流,比如条件语句和循环。对于动态计算图,控制流可以使得每次前向计算的图结构不同,但 PyTorch 依然能够正确计算梯度。

def my_func(a):if a.item() > 1:return a ** 2else:return a * 3x = torch.tensor(2.0, requires_grad=True)
y = my_func(x)
y.backward()
print(x.grad)  # 输出 4.0

在这个例子中,my_func 根据 a 的值执行不同的操作。如果 a > 1,则返回 a 的平方;否则,返回 a 的三倍。由于 x 的值为 2.0,所以计算的结果是 y = 4.0,而 yx 的导数为 4.0

6. 多变量函数的自动求导

在实际应用中,许多函数是多变量的。这时,PyTorch 同样可以计算每个变量的梯度。

x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y = x1 ** 2 + x2 ** 3
y.backward()

在这个例子中,yx1x2 的函数。调用 backward() 后,x1.gradx2.grad 将分别存储 yx1x2 的导数。

print(x1.grad)  # 输出 2.0
print(x2.grad)  # 输出 12.0

x1.grad 的值为 2 * x1 = 2.0,而 x2.grad 的值为 3 * x2^2 = 12.0

7. detach() 的用途与计算图的修改

在某些情况下,你可能不希望某个张量参与计算图的反向传播。detach() 函数可以从计算图中分离出一个张量,使得它在反向传播时不影响梯度的计算。

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2
z = y.detach() * 2  # z 与 y 无关,反向传播时不计算 z 的梯度
z.backward()
print(x.grad)  # 输出 None

在这里,由于 z 是从 y 中分离出来的,反向传播时 x.grad 不会受到 z 的影响。

此外,with torch.no_grad() 也可以用于临时停止计算图的构建,通常用于模型推理阶段。

8. 实际应用:深度学习中的梯度更新

自动求导在深度学习中的一个典型应用是梯度更新。在训练过程中,模型的参数会通过反向传播计算梯度,并使用优化器(如 SGD、Adam 等)更新这些参数。PyTorch 的 torch.optim 模块提供了多种优化器,可以自动利用计算出的梯度进行参数更新。

import torch.optim as optim# 创建一个简单的线性模型
model = torch.nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)# 输入数据和目标
x = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
y_true = torch.tensor([[2.0], [4.0], [6.0]])# 前向传播
y_pred = model(x)# 计算损失
loss = torch.nn.functional.mse_loss(y_pred, y_true)# 反向传播
loss.backward()# 更新参数
optimizer.step()

在这段代码中,我们创建了一个简单的线性模型,并使用 MSE 作为损失函数。通过反向传播计算梯度后,优化器会自动更新模型的参数,使损失逐渐减小。

9. 总结

PyTorch 的自动求导机制是深度学习中非常重要且强大的工具。它基于计算图自动计算梯度,极大地简化了模型训练中的梯度计算过程。无论是简单的线性函数还是复杂的神经网络,PyTorch 都能通过动态计算图和自动求导机制高效地进行梯度计算和参数优化。在实际应用中,掌握这些基础知识可以帮助我们更好地理解和优化深度学习模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/51751.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前胡基因组与伞形科香豆素的进化-文献精读42

The gradual establishment of complex coumarin biosynthetic pathway in Apiaceae 伞形科中复杂香豆素生物合成途径的逐步建立 羌活基因组--文献精读-36 摘要:复杂香豆素(CCs)是伞形科植物中的特征性代谢产物,具有重要的药用价…

深度学习与大模型第1课环境搭建

深度学习与大模型第1课 环境搭建 1. 安装 Anaconda 首先,您需要安装 Anaconda,这是一个开源的 Python 发行版,能够简化包管理和环境管理。以下是下载链接及提取码: 链接:https://pan.baidu.com/s/1Na2xOFpBXQMgzXA…

网络准入控制系统

当我们谈论网络准入控制系统时,我们谈论的并不是网络准入控制系统,而是安全,我们不能只囿于它表面的浮华而忘掉它的本质,记住,不管讨论什么,我们必须要有直达本质的能力。网络的本质就是安全。 网络准入控制…

TDesign 微信小程序组件库配置

文章目录 1.安装 npm 包2. 构建 npm3. 构建完成后即可使用 npm 包。4.修改 app.json5.修改 tsconfig.json6.使用组件 1.安装 npm 包 在小程序 package.json 所在的目录中执行命令安装 npm 包: npm install结果报错 PS C:\WeChatProjects\miniprogram-1> npm i…

vscode和edge浏览器等鼠标输入光标变透明

本人是AMD的APU会出现这种情况。它的gpu加速有些问题。 都是要关闭gpu硬件加速功能。edge浏览器好找。vscode是通过以下方法。 要关闭VSCode的硬件加速功能, ‌通过配置文件调整‌: 打开VSCode的设置(通过按下CtrlShiftP或CmdShiftP打开命令…

【Qt】窗口概述

Qt 窗口概述 Qt窗口是由QMianWindow类来实现的。 QMainWindow 是⼀个为⽤⼾提供主窗⼝程序的类,继承⾃ QWidget 类,并且提供了⼀个预定义的布局。QMainWindow 包含 ⼀个菜单栏(menu bar)、多个⼯具栏(tool bars)、多个浮动窗⼝&a…

安全入门day.03

一、知识点 1、抓包技术应用意义 在渗透安全方面,通过抓包分析,安全人员可以模拟黑客的攻击行为,对系统进行渗透测试。这种测试有助于发现系统中存在的安全漏洞和弱点。一旦发现漏洞,可以立即采取措施进行修复,从而增…

分享8个Python自动化实战脚本!

1. Python自动化实战脚本 1.1 网络自动化 网络上有丰富的信息资源,Python可以帮我们自动化获取这些信息。 爬虫简介:爬虫是一种自动提取网页信息的程序。Python有许多优秀的爬虫库,如requests和BeautifulSoup。 案例:使用Pytho…

8.26 T4 日记和编辑器(fhq维护kmp——kmp本身含有的单射与可合并性)

http://cplusoj.com/d/senior/p/NOD2301D 前4个操作拿fhq treap是很好维护的。 对于最后一个操作,我们可以这么思考,从kmp的匹配思路出发: 如果我们知道一个串进入的指针 j j j(也就是kmp匹配到的位置)&#xff0c…

IT 行业的就业情况

当前,IT 行业的就业情况呈现出以下特点: 1. 需求持续增长:随着数字化转型的加速,各个行业对信息技术的依赖程度不断提高,推动了对 IT 人才的持续需求。特别是在云计算、大数据、人工智能、物联网等新兴领域&#xff…

MySQL:复合查询

MySQL:复合查询 聚合统计分组聚合统计group byhaving 多表查询自连接子查询单行子查询多行子查询多列子查询from子查询 合并查询unionunion all 内连接外连接左外连接右外连接全外连接 视图 MySQL 复合查询是数据分析和统计的强大工具,本博客将介绍如何使…

【WiFi主要技术学习2】

WiFi协议学习2 WiFi SPEC理解频段信道带宽协商速率安全与加密WiFi主要技术理解BP直接序列扩频(Direct Sequence Spread Spectrum,DSSS)BPSKQPSK正交幅度调制(Quadrature Amplitude Modulation,QAM)互补码键控(Complementary Code Keying,CCK)正交频分复用(Orthogonal…

Global Illumination_LPV Deep Optimizations

接上回,RSM优化技术介绍后,我们本部分主要看一下,光栅GI三部曲中的LPV,这个算法算是很巧妙了,算法思路基于RSM上拓展到世界空间,可以说很具学习和思考价值,之前也简单实现过Global Illumination…

利用session.upload_progress执行文件包含

1.session.upload_progress的作用: session.upload_progress最初是PHP为上传进度条设计的一个功能,在上传文件较大的情况下,PHP将进行流式上传,并将进度信息放在Session中(包含用户可控的值),即…

Go 语言版本管理——Goenv

Go 语言版本管理——Goenv 命令安装 goenv安装和切换 Go 版本 goenv 是一个专门管理 Go 语言版本的工具。 命令 安装 goenv github-goenv git clone https://github.com/go-nv/goenv.git ~/.goenv echo export GOENV_ROOT"$HOME/.goenv" >> ~/.bash_profile…

CSAPP全书学习总结

CSAPP( 1.计算机系统漫游)学习笔记-CSDN博客 CSAPP(第二章 信息的表示和处理,附上datalab解析_datalab调整数据位置-CSDN博客 CSAPP (第三章:程序的机器级表示-CSDN博客

STM32嵌套向量中断控制器—NVIC

NVIC简介: NVIC,即Nested Vectored Interrupt Controller(嵌套向量中断控制器),是STM32中的中断控制器。它负责管理和协调处理器的中断请求,是STM32中处理异步事件的重要机制。 NVIC提供了灵活、高效、可扩…

基于ssm的实习课程管理系统/在线课程系统

实习课程管理系统 摘 要 互联网的快速发展,给各行各业带来不同程度的影响,悄然改变人们的生活、工作方式,也倒逼很多行业创新和变革,以适应社会发展的变化。人们为了能够更加方便地管理项目任务,实习课程管理系统被人们…

python-变量声明、数据类型、标识符

一.变量 1.什么是变量 为什么需要变量呢? 一个程序就是一个世界,不论使用哪种高级程序语言编写代码,变量都是其程序的基本组成单位。如下图所示的sum和sub都是变量。 变量的定义: 变量相当于内存中一个数据存储空间的表示&#…

C语言刷题日记(附详解)(3)

一、选填部分 第一题: 以下的变量定义语句中,合法的是( ) A. byte a 128; B. boolean b null; C. long c 123L; D. float d 0.9239; 思路提示:观察选项时不要马虎,思考一下各种类型变量的取值范围,以及其初始化的形式是…