pytorch的backward()的底层实现逻辑

自动微分是一种计算张量(tensors)的梯度(gradients)的技术,它在深度学习中非常有用。自动微分的基本思想是:

  • 自动微分会记录数据(张量)和所有执行的操作(以及产生的新张量)在一个由函数(Function)对象组成的有向无环图(DAG)中。在这个图中,叶子节点是输入张量,根节点是输出张量。通过从根节点到叶子节点追踪这个图,可以使用链式法则(chain rule)自动地计算梯度。
  • 在前向传播(forward pass)中,自动微分同时做两件事:
    • 运行请求的操作来计算一个结果张量,以及
    • 在 DAG 中保留操作的梯度函数。  
    • 在 DAG 中保留操作的梯度函数,这就是说,当你给自动微分一个张量和一个操作,它不仅会计算出结果张量,还会记住这个操作的梯度函数,也就是这个操作对输入张量的导数。例如,如果你给自动微分一个张量 x = [1, 2, 3] 和一个操作 y = x + 1,它不仅会计算出 y = [2, 3, 4],还会记住这个操作的梯度函数是 dy/dx = 1,也就是说,y 对 x 的导数是 1。这样,当你需要计算梯度时,自动微分就可以根据这个梯度函数来计算出结果张量对输入张量的梯度。
  • 在PyTorch中,DAG是动态的。需要注意的一点是,图是从头开始重新创建的;在每个 .backward() 调用之后,autograd开始填充一个新的图。
  • 后向传播开始于当在 DAG 的根节点上调用 .backward() 方法。这个方法会触发自动微分开始计算梯度。
  • 自动微分会从每个 .grad_fn 中计算梯度,这个 .grad_fn 是一个函数对象,它保存了操作的梯度函数。例如,如果一个操作是 y = x + 1,那么它的 .grad_fn 就是 dy/dx = 1。
  • 自动微分会将计算出的梯度累加到相应张量的 .grad 属性中,这个 .grad 属性是一个张量,它保存了结果张量对输入张量的梯度。例如,如果一个结果张量是 y = [2, 3, 4],那么它的 .grad 属性就是 [1, 1, 1],表示 y 对 x 的梯度是 1。
  • 使用链式法则(chain rule),自动微分会一直向后传播,直到到达叶子张量。链式法则是一种数学公式,它可以将复合函数的梯度分解为简单函数的梯度的乘积。例如,如果一个复合函数是 z = f(g(x)),那么它的梯度是 dz/dx = dz/dg * dg/dx。

import torch
import torch.nn as nn
M = nn.Linear(2, 2) # neural network module
M.eval() # set M to evaluation mode
with torch.no_grad(): # disable gradient computationfor param in M.parameters(): # loop over all parametersparam.fill_(1) # fill the parameter with 1
M.requires_grad_(False)a = torch.tensor([1., 2.], requires_grad=True) # leaf node
b = torch.tensor([13., 32.], requires_grad=True) # leaf node
c = M(a) # non-leaf node
c2 = M(b) # non-leaf node
d = c * 2  # non-leaf node
d.sum().backward() # compute gradients
print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(M.weight.grad) # None

构建计算图:当我们调用backward()方法时,PyTorch会自动构建从叶子节点a到损失值d.sum()的计算图,这是一个有向无环图,表示了各个张量之间的运算关系。计算图中还包含了两个中间变量c和d,它们是由a经过M模型的前向传播得到的。计算图的作用是记录反向传播的路径,以便于计算梯度。 计算梯度:在计算图中,每个张量都有一个属性grad,用于存储它的梯度值。当我们调用backward()方法时,PyTorch会沿着计算图按照链式法则计算并填充每个张量的grad属性。由于我们只对叶子节点a的梯度感兴趣,所以只有a的grad属性会被计算出来,而中间变量c和d的grad属性会被忽略。a的grad属性的值是损失值d.sum()对a的偏导数,表示了a的变化对损失值的影响。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/151813.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于梯度算法优化概率神经网络PNN的分类预测 - 附代码

基于梯度算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于梯度算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于梯度优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要:针对PNN神经网络的光滑…

细说MySQL数据类型

TOC 目录 MySQL数据类型 数据类型分类 数值类型 tinyint类型 有符号tinyint范围测试 无符号tinyint范围测试 bit类型 bit类型的显示方式 bit类型的范围测试 float类型 有符号float范围测试 无符号float范围测试 decimal类型 字符串类型 char类型 char类型测试 …

Ubuntu 18.04/20.04 LTS 操作系统设置静态DNS

1、nano /etc/systemd/resolved.conf 2、修改配置 使用DNS服务器:223.5.5.5 223.6.6.6 [Resolve] DNS223.5.5.5 223.6.6.6 3、重启服务 systemctl restart systemd-resolved.service 4、查看解析文件 cat /run/systemd/resolve/resolv.conf # This file is man…

Jmeter 如何监控目标服务的系统资源

下载Jmeter插件管理下载 perfmon 将这个插件管理放到Jmeter的\lib\ext目录下 然后重启Jmeter jmeter-plugins-manager-1.10.jar 下载 perfmon插件 添加 io 内存 磁盘的监听 并且添加监听 在宿主机中安装代理监听程序 并启动 ServerAgent.tar.gz

数据结构-插入排序

插入排序 插入排序的三种常见方法: 直接插入排序、折半插入排序、希尔排序。 数据存储结构 因为我们是用的是C语言来实现算法,因此我们需要创建一个结构体,用来存放初始数据。 结构体定义如下: #define MAX 100 typedef int…

设计原则 | 开放封闭原则

一、开放封闭原则(OCP:Open-Closed Principle) 1、原理 软件实体(类、模块、函数等等)应该是可以扩展的,但是不可修改的。如果程序中的一处改动就会引发连锁反应,导致一些列相关模块的修改&…

GPT实战系列-P-Tuning本地化训练ChatGLM2等LLM模型,到底做了什么?(二)

GPT实战系列-如何使用P-Tuning本地化训练ChatGLM2等LLM模型?(二) 文章目录 GPT实战系列-1.训练参数配置传递2.训练前准备3.训练参数配置4.训练对象,seq2seq训练5.执行训练6.训练模型评估依赖数据集的预处理 P-Tuning v2 将 ChatGLM2-6B 模型需要微调的参…

MATLAB 嵌套switch语句||MATLAB while循环

MATLAB 嵌套switch语句 在 MATLAB 中嵌套 switch 语句是可能的,可以在 switch 一部分外嵌套 switch 语句序列。即使 case 常量的内部和外部的 switch 含有共同的值,也不算冲突出现。 MATLAB嵌套switch语句语法 嵌套switch语句的语法如下: s…

012 C++ AVL_tree

前言 本文将会向你介绍AVL平衡二叉搜索树的实现 引入AVL树 二叉搜索树虽可以缩短查找的效率,但如果数据有序或接近有序普通的二叉搜索树将退化为单支树,查找元素相当于在顺序表中搜索元素,效率低下。因此,两位俄罗斯的数学家G.M…

学习模拟简明教程【Learning to simulate】

深度神经网络是一项令人惊叹的技术。 有了足够的标记数据,他们可以学习为图像和声音等高维输入生成非常准确的分类器。 近年来,机器学习社区已经能够成功解决诸如对象分类、图像中对象检测和图像分割等问题。 上述声明中的加黑字体警告是有足够的标记数…

OpenHarmony源码下载

OpenHarmony源码下载 现在的 OpenHarmony 4.0 源码已经有了,在 https://gitee.com/openharmony 地址中,描述了源码获取的方式,但那是基于 ubuntu 或者说是 Linux 的下载方式。在 windows 平台下的下载方式没有做出介绍。 我自己尝试了 wind…

PCIe协议加持,SD卡9.1规范达到媲美SSD的速度4GB/s

近日,SD协会(SDA)宣布了最新的SD Express存储卡的进化,将microSD Express存储卡的速度提高了一倍,达到2GB/s,并引入了4个新的SD Express速度等级,以确保新的SD 9.1规范中最低的顺序性能水平。这…

【Qt开发流程】之HelloWorld程序

【Qt开发流程】之HelloWorld程序 目的编写程序新建项目文件说明及界面设计 程序运行及发布程序运行程序发布手动构建使用windeployqt进行构建 设置应用程序图标修改快捷键类型列表命令行编译程序命令行编译.ui文件自定义类项目模式及项目文件介绍项目模式项目文件 目的 从Hell…

【Linux 源码阅读记录】设备树解析 of 相关代码

前言 最近移植接触 Linux 的设备树解析相关的代码,对 Linux of (open firmware)设备树解析代码比较感兴趣。 可以通过阅读Linux 大量的优秀代码,增强一些编程与编码的技巧与经验 切入点 of_device_is_available :设…

通过bat脚本控制Oracle服务启动停止

1、将Oracle服务全部设置为手动启动 初始安装Oracle之后服务启动状态: 2、服务功能介绍 3、构建服务启动/停止bat脚本 注意:编码选择ANSI(如果编码不是ANSI运行脚本会显示乱码) echo off :main cls echo 注:请保证该脚本是使用管理员权限…

Iceberg学习笔记(1)—— 基础知识

Iceberg是一个面向海量数据分析场景的开放表格式(Table Format),其设计的目的是解决数据存储和计算引擎之间的适配的问题 表格式(Table Format)可以理解为元数据以及数据文件的一种组织方式,处于计算框架&…

c++ binary_semaphore 使用详解

c binary_semaphore 使用详解 std::binary_semaphore c20 头文件 #include <semaphore>。作用&#xff1a;二进制信号量&#xff0c;轻量同步元件&#xff0c;能控制对共享资源的互斥访问。 std::binary_semaphore 成员函数 release&#xff1a;增加内部计数器并解除…

Java —— 抽象类和接口

目录 1. 抽象类 1.1 抽象类概念 1.2 抽象类语法与特性 1.3 抽象类的作用 2. 接口 2.1 接口的概念 2.2 接口的语法规则与特性 2.3 实现多个接口(解决多继承的问题) 2.4 接口间的继承 2.5 抽象类和接口的区别 2.6 接口的使用实例 2.7 Clonable 接口和深拷贝 2.7.1 Cloneable接口 …

探索arkui(2)--- 布局(列表)--- 1(列表数据的展示)

前端开发布局是指前端开发人员宣布他们开发的新网站或应用程序正式上线的活动。在前端开发布局中&#xff0c;开发人员通常会展示新网站或应用程序的设计、功能和用户体验&#xff0c;并向公众宣传新产品的特点和优势。前端开发布局通常是前端开发领域的重要事件&#xff0c;吸…

Apache Airflow (八) :DAG任务依赖设置

&#x1f3e1; 个人主页&#xff1a;IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 &#x1f6a9; 私聊博主&#xff1a;加入大数据技术讨论群聊&#xff0c;获取更多大数据资料。 &#x1f514; 博主个人B栈地址&#xff1a;豹哥教你大数据的个人空间-豹…