【pytorch】自动求导机制

基础概念汇总

Tensor是 torch.autograd中的数据类型,主要用于封装 Tensor,进行自动求导。

  • grad : data的梯度
  • grad_fn : 创建 Tensor的 Function,是自动求导的关键
  • requires_grad:指示是否需要梯度
  • is_leaf : 指示是否是叶子结点

PyTorch张量可以记住它们来自什么运算以及其起源的父张量,并且提供相对于输入的导数链。你无需手动对模型求导:不管如何嵌套,只要你给出前向传播表达式,PyTorch都会自动提供该表达式相对于其输入参数的梯度。

当设置.requires_grad = True之后,在其上进行的各种操作就会被记录下来,它将开始追踪在其上的所有操作,从而利用链式法则进行梯度传播。任何以tensor为祖先的张量都可以访问从tensor到该张量所调用的函数链。如果这些函数是可微的(大多数PyTorch张量运算都是可微的),则导数的值将自动存储在参数张量的grad属性中。

完成计算后,可以调用.backward()来完成所有梯度计算。沿着整个函数链(即计算图)计算损失的导数。此Tensor的梯度将累积到.grad属性中。调用backward会导致导数值在叶节点处累积。所以将其用于参数更新后,需要将梯度显式清零

if params.grad is not None:params.grad.zero_()

但是,如果中间加载了不支持梯度的操作,就会发生梯度断流。这在自己写模型时候时常发生,会导致模型无法求导。例如,在求loss时使用pil、cv2的库,导致无法反向传播。后面即使手动打开也没有用,梯度流不能被中断。或者自己写了transform函数,调用官方不支持grad_fn的函数,也会导致这样的问题。

如果不想要被继续追踪,可以调用.detach()将其从追踪记录中分离出来,可以防止将来的计算被追踪,这样梯度就传不过去了。此外,还可以用with torch.no_grad()将不想被追踪的操作代码块包裹起来,这种方法在评估模型的时候很常用,因为在评估模型时,我们并不需要计算可训练参数(requires_grad=True)的梯度。

深入

autograd 机制

Autograd 是一种反向自动微分系统。从概念上讲, autograd 记录了一个图表,记录了创建的所有操作 执行操作时的数据。提供有向无环图 其叶子是输入张量,根是输出张量。

在内部,autograd 将该图表示为 Function 对象(真正的表达式),可以是 apply() 编辑计算结果 评估图表。计算前向传播时,autograd 同时执行请求的计算并构建图表 表示计算梯度的函数(.grad_fn 每个 torch.Tensor 的属性都是该图的入口点)。 当前向传递完成后,我们在 向后传递以计算梯度。

需要注意的重要一点是,该图每次都会从头开始重新创建 迭代,这正是允许使用任意 Python 控件的原因 流语句,可以改变图形的整体形状和大小 每次迭代。您不必先对所有可能的路径进行编码 启动培训——你跑什么,你就与众不同。

拓展torch

https://pytorch.org/docs/stable/notes/extending.html

想在模型中执行计算,请实现自定义函数 不可微分或依赖于非 PyTorch 库(例如 NumPy)。如果想让操作能够与其他操作链接并使用 autograd 引擎,就得使用自定义函数。

自定义函数也可用于提高性能和 内存使用情况:如果您使用 C++ 扩展, 您可以将它们包装在 Function 中以与 autograd 交互 引擎。如果您想减少为向后传递保存的缓冲区数量, 自定义函数可用于将操作组合在一起。

第 1 步:子类化Function后,您需要定义 3 个方法

forward() 是执行该操作的代码。它可以需要 你想要多少个参数,其中一些是可选的,如果你 指定默认值。这里接受所有类型的 Python 对象。 Tensor 跟踪历史记录的参数(即, requires_grad=True)将被转换为不跟踪历史记录的内容 在调用之前,它们的使用将被记录在图表中。请注意,这 逻辑不会遍历列表/字典/任何其他数据结构,只会 考虑作为调用的直接参数的张量。你可以 返回单个 Tensor 输出,或 tuple 张量(如果有多个输出)。另外,请参阅 Function 的文档来查找有用方法的描述,这些方法可以 仅从 forward() 调用。

setup_context()(可选)。人们可以写一个“组合”forward() 接受一个 ctx 对象或(从 PyTorch 2.0 开始)一个单独的 forward() 不接受 ctx 和发生 修改的 setup_context() 方法。 应该具有计算能力, 应该具有 只负责修改(并且不进行任何计算)。 一般来说,单独的 和 更接近于如何 PyTorch 本机操作可以工作,因此更适合与各种 PyTorch 子系统组合。 请参阅组合或单独的forward() 和setup_context()了解更多详情。ctxforward()setup_context()ctxforward()setup_context()

backward()(或vjp())定义渐变公式。 它将给出与输出一样多的 Tensor 参数,每个参数 其中代表梯度 w.r.t.那个输出。重要的是永远不要修改 这些就地。它应该返回尽可能多的张量 是输入,每个输入都包含梯度 w.r.t.它是 相应的输入。如果您的输入不需要梯度 (needs_input_grad 是一个布尔值元组,表示 每个输入是否需要梯度计算),或者是非Tensor 对象,您可以返回python:None。另外,如果您有可选的 forward() 的参数你可以返回比那里更多的梯度 都是输入,只要它们都是 None。

第 2 步:使用 ctx 中的功能 正确地确保新的 Function 能够正常工作 autograd 引擎。

save_for_backward() 必须是 用于保存向后传递中使用的任何张量。非张量应该 直接存储在ctx上。如果张量既不是输入也不是输出 保存为向后,您的 Function 可能不支持双向后 (参见步骤 3)。

mark_dirty()必须习惯于 标记由转发函数就地修改的任何输入。

mark_non_differentiable()必须 用于告诉引擎输出是否不可微。经过 默认所有可微分类型的输出张量都会被设置 要求梯度。不可微类型的张量(即整数类型) 从未被标记为需要渐变。

set_materialize_grads()可 用于告诉 autograd 引擎在以下情况下优化梯度计算 通过不具体化给予向后的梯度张量,输出不依赖于输入 功能。也就是说,如果设置为 False,则 Python 中的 None 对象或“未定义张量”(张量 x 为 C++ 中的 x.define() 为 False) 不会转换为先用零填充的张量 向后调用,因此您的代码将需要处理此类对象,就好像它们是 张量用零填充。此设置的默认值为 True。

Step 3:

If your Function does not support double backward you should explicitly declare this by decorating backward with the once_differentiable(). With this decorator, attempts to perform double backward through your function will produce an error. See our double backward tutorial for more information on double backward.

验证

使用torch.autograd.gradcheck() 检查你的后向函数是否正确计算了 通过使用后向函数计算雅可比矩阵来向前推进 将值按元素与使用数值计算的雅可比行列式进行比较 有限差分。

reference

https://github.com/ShusenTang/Deep-Learning-with-PyTorch-Chinese/blob/master/docs/chapter4/4.2.mdhttps://tianchi.aliyun.com/forum/post/336073
https://pytorch.org/docs/stable/notes/extending.html
https://pytorch.org/tutorials/advanced/cpp_extension.html
https://pytorch.org/docs/stable/notes/autograd.html
https://pytorch.org/docs/stable/generated/torch.autograd.Function.backward.html#torch.autograd.Function.backward

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/234327.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL主从复制与切换

1.主从架构及基本原理 常见主从部署架构:一主多从、一丛多主、双主复制、主从级联复制。 1.1主从复制原理 1.从节点开启start slave,开启主从复制,从节点IO线程与主节点建立连接,请求数据同步; 2.主节点接收到从节点…

将smiles转为图片都有什么包?rdkit、Open Babel、Indigo

除了 RDKit,还有几个其他的Python库可以用来将SMILES(Simplified Molecular Input Line Entry System)字符串转换为分子图片。这些库通常用于化学信息学和分子建模领域。一些常见的选项包括: rdkit Open Babel Open Babel 命令行…

【NeRF】内容准备

论文地址 项目主页 参考文章: NeRF入门之体渲染 NeRF:火爆科研圈的三维重建技术大揭秘 【三维重建】NeRF原理代码讲解 精选17篇神经辐射场(NeRF)高分论文分享!最新技术成果一次看完!2023/9/29 计划看一下NeRF相关的热…

Ansible的playbook脚本使用

本章注意介绍如何在ansible中写脚本 playbook的语法在写playbook时如何进行错误处理 ansible的许多模块都是在命令行中执行的,每次只能执行一个模块。如果需要执行多个模块,且要写判断语句,判断模块是否执行成功了,如果没成功会…

SpringBoot基于gRPC进行RPC调用

SpringBoot基于gRPC进行RPC调用 一、gRPC1.1 什么是gRPC?1.2 如何编写proto1.3 数据类型及对应关系1.4 枚举1.5 数组1.6 map类型1.7 嵌套对象 二、SpringBoot gRPC2.1 工程目录2.2 jrpc-api2.2.1 引入gRPC依赖2.2.2 编写 .proto 文件2.2.3 使用插件机制生产proto相关…

CWE-611

CWE-611,也称为“URL Redirection to Untrusted Site (‘Open Redirect’)”,是一种常见的Web应用程序安全漏洞。这种漏洞出现在应用程序接受用户提供的URL作为重定向参数,并在未经充分验证的情况下直接将用户重定向到该URL时。攻击者可以利用…

Java 基础学习(十四)Map集合与Set集合

1 Map集合 1.1 Map接口 1.1.1 Map接口概述 Map接口是一种双列集合。Map的每个元素都包含一个键对象Key和一个值对象Value ,键对象和值对象之间存在对应关系,这种关系称为映射(Mapping)。 Map接口中的元素,可以通过…

DC-6靶场

DC-6靶场下载: https://www.five86.com/downloads/DC-6.zip 下载后解压会有一个DC-3.ova文件,直接在vm虚拟机点击左上角打开-->文件-->选中这个.ova文件就能创建靶场,kali和靶机都调整至NAT模式,即可开始渗透 首先进行主…

【Transformer框架代码实现】

Transformer Transformer框架注意力机制框架导入必要的库Input Embedding / Out EmbeddingPositional EmbeddingTransformer EmbeddingScaleDotProductAttention(self-attention)MultiHeadAttention 多头注意力机制EncoderLayer 编码层Encoder多层编码块/前馈网络层…

【机器学习】密度聚类:从底层手写实现DBSCAN

【机器学习】Building-DBSCAN-from-Scratch 概念代码数据导入实现DBSCAN使用样例及其可视化 补充资料 概念 DBSCAN(Density-Based Spatial Clustering of Applications with Noise,具有噪声的基于密度的聚类方法)是一种基于密度的空间聚类算…

新手做抖店应该怎么做?应该注意些什么?踩坑避雷!

我是电商珠珠 新手做抖店,对于办理营业执照、选类目确定品,或是找达人这些,往往会在这上面吃很多亏。 我做抖店也已经三年了,关于抖店的玩法和规则这块也非常熟悉,这就来给大家讲讲我所踩的那些坑。 第一个&#xf…

自动化边坡监测设备是什么?

随着科技的不断进步,我们的生活和环境也在不断地发生变化。然而,自然灾害仍然是我们无法完全避免的风险。其中,边坡滑坡就是一种常见的自然灾害。为了保护人民的生命财产安全,科学家们研发出了自动化边坡监测设备。 WX-WY1 自动化…

C++基础-内存模型详解

目录 一、概述 二、内存分区模型分类 三、代码区 四、全局区 五、栈区

亚信安慧AntDB数据库引领中文信息处理标准化创新

近期,亚信科技旗下的AntDB数据库再获殊荣,成功通过GB 18030-2022《信息技术中文编码字符集》最高实现级别(级别3)的检测认证,成为首批达到该认证标准的数据库产品之一。这一认证不仅是对AntDB数据库卓越技术实力的肯定…

算法02哈希法

算法01之哈希法 1.哈希法理论基础1.1哈希表(1)哈希表(2)哈希函数(3)哈希碰撞 1.2哈希法基本思想1.3哈希法适用场景与最常用的哈希结构 2.LeetCode242:有效的字母异位词(1&#xff09…

《每天一分钟学习C语言·三》

1、 scanf的返回值由后面的参数决定scanf(“%d%d”,& a, &b); 如果a和b都被成功读入,那么scanf的返回值就是2如果只有a被成功读入,返回值为1如果a和b都未被成功读入,返回值为0 如果遇到错误或遇到end of file,返回值为EOF…

罗列一下js reduce 的能做的事情?

JavaScript 的 reduce 方法是一个非常强大的工具,可以用于处理数组数据。 以下是一些 reduce 可以做的事情: 1. 累加器:reduce 最常见的用途是将数组的所有元素累加到一个值中。例如,计算数组中所有数字的总和 const numbers …

ACE Tools环境配置指导

简介 ACE Tools是一套为ArkUI-X应用开发者提供的命令行工具,支持在Windows/Ubuntu/macOS平台运行,用于构建OpenHarmony、HarmonyOS、Android和iOS平台的应用程序, 其功能包括开发环境检查,新建项目,编译打包&#xff…

Debian系统设置SSH密钥登陆

如果没有安装ssh,root权限运行apt install openssh-server进行安装。 ssh-keygen -t rsa # 生成配对密钥,后续一路enter即可会在用户目录(即~这个)下生成.ssh文件夹,里面的id_rsa是私钥,id_rsa.pub是公钥…