【PyTorch单点知识】自动求导机制的原理与实践

文章目录

      • 0. 前言
      • 1. 自动求导的基本原理
      • 2. PyTorch中的自动求导
        • 2.1 创建计算图
        • 2.2 反向传播
        • 2.3 反向传播详解
        • 2.4 梯度清零
        • 2.5 定制自动求导
      • 3. 代码实例:线性回归的自动求导
      • 4. 结论

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在深度学习中,自动求导(Automatic Differentiation, AD)是一项至关重要的技术,它使我们能够高效地计算神经网络的梯度,进而通过反向传播算法更新权重。

PyTorch作为一款动态计算图的深度学习框架,以其灵活性和易用性著称,其自动求导机制是其实现高效、灵活训练的核心。本文将深入探讨PyTorch中的自动求导机制,从原理到实践,通过代码示例来展示其工作流程。

如果对计算图不太了解,可以参考我的往期文章:基于TorchViz详解计算图(附代码)

1. 自动求导的基本原理

自动求导是一种数学方法,用于计算函数的导数。与数值微分相比,自动求导能够提供精确的导数计算结果,同时避免了符号微分中可能出现的手动求导错误。在深度学习中,我们通常关注的是反向模式backward的自动求导,即从输出向输入方向传播梯度的过程。

反向模式自动求导基于链式法则,它允许我们将复杂的复合函数的导数分解成多个简单函数的导数的乘积。在神经网络中,每一层都可以看作是一个简单的函数,通过链式法则,我们可以从前向传播的输出开始,逆向计算每个参数的梯度。

2. PyTorch中的自动求导

PyTorch通过其autograd模块实现了自动求导机制。autograd记录了所有的计算步骤,创建了一个计算图(Computational Graph),并在需要时执行反向传播,计算梯度。

2.1 创建计算图

在PyTorch中,当一个张量(Tensor)的requires_grad=True时,任何对该张量的操作都会被记录在计算图中。例如:

import torchx = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()print(y.grad_fn)  # 查看y的计算节点
print(z.grad_fn)  # 查看z的计算节点

输出为:

<AddBackward0 object at 0x000001CADEC6AB60>
<MulBackward0 object at 0x000001CADEC6AB60>

在上述代码中,z的计算节点显示了z是如何由y计算得来的,而y的计算节点则显示了y是如何由x计算得来的。这样就形成了一个计算图。

2.2 反向传播

一旦我们完成了前向传播并得到了最终的输出,就可以调用out.backward()来进行反向传播,计算梯度。例如:

import torchx = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()out.backward()
print(x.grad)

这里,x.grad就是out相对于x的梯度。输出为:

tensor([[4.5000, 4.5000],[4.5000, 4.5000]])
2.3 反向传播详解

下面我们来详细分析下1.2节的具体计算过程:

  1. 首先,创建了一个2x2的张量x,其值全为1,并且设置了requires_grad=True,这意味着PyTorch将会追踪这个张量上的所有操作,以便能够计算梯度。
x = torch.ones(2, 2, requires_grad=True)
  1. 然后,将x与2相加得到y
y = x + 2

此时y的值为:

tensor([[3., 3.],[3., 3.]])
  1. 接下来,将y的每个元素平方再乘以3得到z
z = y * y * 3

此时z的值为:

tensor([[27., 27.],[27., 27.]])
  1. 计算z的平均值作为输出out
out = z.mean()

此时out的值为:

tensor(27.)
  1. 使用backward()函数对out进行反向传播,计算梯度:
out.backward()
  1. 最后,打印x的梯度:
print(x.grad)

由于out是通过一系列操作从x得到的,我们可以根据链式法则计算出x的梯度。具体来说,out相对于x的梯度可以通过以下步骤计算得出:

  • out相对于z的梯度是1/z.size(0)(因为z.mean()是对z的所有元素取平均),这里z.size(0)等于4,所以out相对于z的梯度是1/4
  • z相对于y的梯度是y * 3 * 2(因为z = y^2 * 3,所以dz/dy = 2*y*3)。
  • y相对于x的梯度是1(因为y = x + 2,所以dy/dx = 1)。

综合以上,out相对于x的梯度是:

1/4 * (y * 3 * 2) * 1

由于y的值为[[3, 3], [3, 3]],那么上述梯度计算结果为:

1/4 * (3 * 3 * 2) * 1 = 9/2 = 4.5

因此,最终x.grad的值为:

tensor([[4.5000, 4.5000],[4.5000, 4.5000]])
2.4 梯度清零

在多次迭代中,梯度会累积在张量中,因此在每次迭代开始之前,我们需要调用optimizer.zero_grad()来清零梯度,防止梯度累积。(PyTorch为了训练方便,会默认梯度累积)

2.5 定制自动求导

PyTorch还允许我们定义自己的自动求导函数,通过继承torch.autograd.Function类并重写forwardbackward方法。这为实现更复杂的计算提供了可能。

3. 代码实例:线性回归的自动求导

接下来,我们将通过一个简单的线性回归问题,演示PyTorch自动求导机制的实际应用。

假设我们有一组数据点,我们想找到一条直线(y = wx + b),使得这条直线尽可能接近这些数据点。我们的目标是最小化损失函数(例如均方误差)。

import torch
import numpy as np# 准备数据
np.random.seed(0)
X = np.random.rand(100, 1)
Y = 2 + 3 * X + 0.1 * np.random.randn(100, 1)X = torch.from_numpy(X).float()
Y = torch.from_numpy(Y).float()# 初始化权重和偏置
w = torch.tensor([1.], requires_grad=True)
b = torch.tensor([1.], requires_grad=True)# 定义模型和损失函数
def forward(x):return w * x + bloss_fn = torch.nn.MSELoss()# 训练循环
learning_rate = 0.01
for epoch in range(1000):# 前向传播y_pred = forward(X)# 计算损失loss = loss_fn(y_pred, Y)# 反向传播loss.backward()# 更新权重with torch.no_grad():w -= learning_rate * w.gradb -= learning_rate * b.grad# 清零梯度w.grad.zero_()b.grad.zero_()if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/1000], Loss: {loss.item():.4f}')print('Final weights:', w.item(), 'bias:', b.item())

输出:

Epoch [100/1000], Loss: 0.1273
Epoch [200/1000], Loss: 0.0782
Epoch [300/1000], Loss: 0.0620
Epoch [400/1000], Loss: 0.0497
Epoch [500/1000], Loss: 0.0404
Epoch [600/1000], Loss: 0.0332
Epoch [700/1000], Loss: 0.0277
Epoch [800/1000], Loss: 0.0235
Epoch [900/1000], Loss: 0.0203
Epoch [1000/1000], Loss: 0.0179
Final weights: 2.68684983253479 bias: 2.17771577835083

在这个例子中,我们首先准备了一些随机生成的数据,然后初始化了权重w和偏置b。在训练循环中,我们通过前向传播计算预测值,使用均方误差损失函数计算损失,然后通过调用loss.backward()进行反向传播,最后更新权重和偏置。通过多次迭代,我们最终找到了使损失最小化的权重和偏置。

4. 结论

PyTorch的自动求导机制是其强大功能的关键所在。通过autograd模块,PyTorch能够自动跟踪计算图并高效地计算梯度,这大大简化了深度学习模型的开发过程。本文通过理论解释和代码示例,深入探讨了PyTorch中的自动求导机制,希望读者能够从中获得对这一重要概念的深刻理解,并在实际项目中灵活运用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/45854.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

游戏的无边框模式是什么?有啥用?

现在很多游戏的显示设置中&#xff0c;都有个比较特殊的选项“无边框”。小伙伴们如果尝试过&#xff0c;就会发现这个效果和全屏几乎一毛一样&#xff0c;于是就很欢快地用了起来&#xff0c;不过大家也许会发现&#xff0c;怎么和全屏比起来&#xff0c;似乎有点不够爽快&…

渲染引擎实践 - OSG引擎渲染一帧的过程

一&#xff1a;概述 经过前面两节的介绍&#xff0c;我们已经创建了窗口(OSG引擎窗口)和启动了渲染线程(OSG渲染线程)。当应用程序加载好模型数据以后&#xff0c;就开始正式的渲染了&#xff0c;那么本节分析下渲染一帧的过程&#xff0c;本文尽量做到简单&#xff0c;清晰&am…

uniapp编译成h5后接口请求参数变成[object object]

问题&#xff1a;uniapp编译成h5后接口请求参数变成[object object] 但是运行在开发者工具上没有一点问题 排查&#xff1a; 1&#xff1a;请求参数&#xff1a;看是否是在请求前就已经变成了[object object]了 结果&#xff1a; 一切正常 2&#xff1a;请求头&#xff1a;看…

AST反混淆实战:提升JavaScript代码的可读性与调试便利性

博客标题&#xff1a;AST反混淆&#xff1a;提升JavaScript代码的可读性与调试便利性 引言 JavaScript代码混淆是一种常见的保护源码的方法&#xff0c;但这也给代码的维护和调试带来了不小的挑战。抽象语法树&#xff08;AST&#xff09;提供了一种结构化的方式来分析和转换…

C语言实现数据结构B树

B树&#xff08;B-Tree&#xff09;是一种自平衡的树数据结构&#xff0c;它维护着数据的有序性&#xff0c;并允许搜索、顺序访问、插入、删除等操作都在对数时间内完成。B树广泛用于数据库和操作系统的文件系统中。 B树的基本特性 根节点&#xff1a;根节点至少有两个子节点…

平安好车主:“保”你车平安,“养”出好生活~

“小朋友 你是否有很多问号,为什么......”从出生到长大,不论我们身居何处,年岁几何,妈妈似乎总有嘱咐不完的话。小时候,总不能理解妈妈的话,只想摆脱唠叨,期盼快快长大。 如今,我们羽翼渐丰,已能驾驭人生,肩负起家庭的重任,但妈妈的话却依然从未落下。不过,此刻的我们,不仅能…

Gitea 仓库事件触发Jenkins远程构建

文章目录 引言I Gitea 仓库事件触发Jenkins远程构建1.1 Jenkins配置1.2 Gitea 配置引言 应用场景:项目部署 I Gitea 仓库事件触发Jenkins远程构建 Gitea支持用于仓库事件的Webhooks 1.1 Jenkins配置 高版本Jenkins需要关闭跨域限制和开启匿名用户访问 在Jenkins启动前加入…

Windows 32 汇编笔记(二):使用 MASM

一、Win32 汇编源程序的结构 ; Hello.asm ; 使用 Win32 ASM 写的 Hello, world 程序 ;>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>…

STM32入门开发操作记录(一)——新建工程

目录 一、课程准备1. 课程资料2. 配件清单3. 根目录 二、环境搭建三、新建工程1. 载入器件支持包2. 添加模块3. ST配置4. 外观设置5. 主函数文件 一、课程准备 1. 课程资料 本记录操作流程参考自b站视频BV1th411z7snSTM32入门教程-2023版 细致讲解 中文字幕&#xff0c;课程资…

柯桥韩语培训韩语学习力职场口语韩语中的职场黑话你知道几个?

生活中比较常用的&#xff0c;与职场生活有关的新造词有상사병, 직장살이, 무두절(無頭節)等。一起来看下他们的意思吧... 상사병 상사병是指因为上司多变不定的指示而火大的意思。 직장살이 직장살이用来比喻职场生活也需要看上司的脸色&#xff0c;就像在婆家看婆婆脸色一样…

gorm只查询某一些字段字段的方法Select, 和只查询某一字段方法 Pluck

gorm中默认是查询所有字段的&#xff0c; 如果我们只需要获取某些字段的值&#xff0c;可以通过使用 Select方法来指定要查询的字段来实现&#xff0c; 也可以通过定义一个需要字段的结构体来实现&#xff1b; 而如果我们只需要查询某一个字段的值就可以使用 Pluck方法来获取(这…

如何将已有的docker服务迁移至Kubernetes集群中

如何将已有的docker服务迁移至Kubernetes集群中 问题描述迁移思路准备工作迁移gitlab过程1. 创建namespace2. 创建gitlab需要的pv跟pvc3.创建gitlab的deployment4.创建gitlab的service5.创建gitlab的ingress6.将原来的gitlab容器中的数据打包7.恢复配置文件数据8.正式恢复数据 …

100个C++面试题

面试题1&#xff1a;变量的声明和定义有什么区别 为变量分配地址和存储空间的称为定义&#xff0c;不分配地址的称为声明。一个变量可以在多个地方声明&#xff0c;但是只在一个地方定义。加入extern修饰的是变量的声明&#xff0c;说明此变量将在文件以外或在文件后面部分定…

【刷题汇总 -- 删除公共字符、两个链表的第一个公共结点、mari和shiny】

C日常刷题积累 今日刷题汇总 - day0121、删除公共字符1.1、题目1.2、思路1.3、程序实现 -- 蛮力法1.4、程序实现 -- 哈希 2、两个链表的第一个公共结点2.1、题目2.2、思路2.3、程序实现 -- 对齐比对法2.4、程序实现 -- 公共端点路程法 3、mari和shiny3.1、题目3.2、思路3.3、程…

简述编辑 编译 和运行 java application的全过程

简述编辑 编译 和运行 java application的全过程 编辑、编译和运行Java应用程序通常涉及以下几个步骤&#xff1a; 1. 编辑 首先&#xff0c;你需要一个文本编辑器来编写Java源代码。常用的编辑器包括Eclipse、IntelliJ IDEA、Visual Studio Code、Notepad等。你可以创建一个新…

[python]基于yolov10+gradio目标检测演示系统设计

【设计介绍】 YOLOv10结合Gradio实现目标检测系统设计是一个结合了最新目标检测技术和快速部署框架的项目。下面将详细介绍这一系统的设计和实现过程。 一、YOLOv10介绍 YOLOv10是YOLO&#xff08;You Only Look Once&#xff09;系列的最新版本&#xff0c;由清华大学的研究…

RabbitMQ的工作模式

RabbitMQ的工作模式 Hello World 模式 #mermaid-svg-sbc2QNYZFRQYbEib {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-sbc2QNYZFRQYbEib .error-icon{fill:#552222;}#mermaid-svg-sbc2QNYZFRQYbEib .error-text{fi…

vivado GATED_CLOCK

门控时钟 使用GATED_CLOCK属性启用Vivado合成以执行门控转换 时钟。转换时钟门控逻辑&#xff0c;以便在可用时利用触发器启用引脚。这 优化可以消除时钟树上的逻辑&#xff0c;简化网表。 此RTL属性指示工具门控逻辑中的哪个信号是时钟。 该属性放置在作为时钟的信号或端口上。…

AI工具网站

AI网站&#xff1a; https://flowus.cn/kuhehe/share/05ff2af6-cccd-451a-99f3-e3334f8b405e https://chat.mynanian.top/list https://chat5.aiyunos.top https://share.wendabao.net https://sharedchat.cn/shared.html https://chat.tinycms.xyz:3002 https://chatnio.li…

vienna整流器的矢量分析

Vienna整流器使用六个二极管和六个IGBT&#xff08;或MOSFET&#xff09;组成&#xff0c;提供三个电平&#xff1a;正极电平&#xff08;P&#xff09;、中性点电平&#xff08;O&#xff09;和负极电平&#xff08;N&#xff09;。通过对功率管的控制&#xff0c;Vienna整流器…