pytorch的backward()的底层实现逻辑

自动微分是一种计算张量(tensors)的梯度(gradients)的技术,它在深度学习中非常有用。自动微分的基本思想是:

  • 自动微分会记录数据(张量)和所有执行的操作(以及产生的新张量)在一个由函数(Function)对象组成的有向无环图(DAG)中。在这个图中,叶子节点是输入张量,根节点是输出张量。通过从根节点到叶子节点追踪这个图,可以使用链式法则(chain rule)自动地计算梯度。
  • 在前向传播(forward pass)中,自动微分同时做两件事:
    • 运行请求的操作来计算一个结果张量,以及
    • 在 DAG 中保留操作的梯度函数。  
    • 在 DAG 中保留操作的梯度函数,这就是说,当你给自动微分一个张量和一个操作,它不仅会计算出结果张量,还会记住这个操作的梯度函数,也就是这个操作对输入张量的导数。例如,如果你给自动微分一个张量 x = [1, 2, 3] 和一个操作 y = x + 1,它不仅会计算出 y = [2, 3, 4],还会记住这个操作的梯度函数是 dy/dx = 1,也就是说,y 对 x 的导数是 1。这样,当你需要计算梯度时,自动微分就可以根据这个梯度函数来计算出结果张量对输入张量的梯度。
  • 在PyTorch中,DAG是动态的。需要注意的一点是,图是从头开始重新创建的;在每个 .backward() 调用之后,autograd开始填充一个新的图。
  • 后向传播开始于当在 DAG 的根节点上调用 .backward() 方法。这个方法会触发自动微分开始计算梯度。
  • 自动微分会从每个 .grad_fn 中计算梯度,这个 .grad_fn 是一个函数对象,它保存了操作的梯度函数。例如,如果一个操作是 y = x + 1,那么它的 .grad_fn 就是 dy/dx = 1。
  • 自动微分会将计算出的梯度累加到相应张量的 .grad 属性中,这个 .grad 属性是一个张量,它保存了结果张量对输入张量的梯度。例如,如果一个结果张量是 y = [2, 3, 4],那么它的 .grad 属性就是 [1, 1, 1],表示 y 对 x 的梯度是 1。
  • 使用链式法则(chain rule),自动微分会一直向后传播,直到到达叶子张量。链式法则是一种数学公式,它可以将复合函数的梯度分解为简单函数的梯度的乘积。例如,如果一个复合函数是 z = f(g(x)),那么它的梯度是 dz/dx = dz/dg * dg/dx。

import torch
import torch.nn as nn
M = nn.Linear(2, 2) # neural network module
M.eval() # set M to evaluation mode
with torch.no_grad(): # disable gradient computationfor param in M.parameters(): # loop over all parametersparam.fill_(1) # fill the parameter with 1
M.requires_grad_(False)a = torch.tensor([1., 2.], requires_grad=True) # leaf node
b = torch.tensor([13., 32.], requires_grad=True) # leaf node
c = M(a) # non-leaf node
c2 = M(b) # non-leaf node
d = c * 2  # non-leaf node
d.sum().backward() # compute gradients
print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(M.weight.grad) # None

构建计算图:当我们调用backward()方法时,PyTorch会自动构建从叶子节点a到损失值d.sum()的计算图,这是一个有向无环图,表示了各个张量之间的运算关系。计算图中还包含了两个中间变量c和d,它们是由a经过M模型的前向传播得到的。计算图的作用是记录反向传播的路径,以便于计算梯度。 计算梯度:在计算图中,每个张量都有一个属性grad,用于存储它的梯度值。当我们调用backward()方法时,PyTorch会沿着计算图按照链式法则计算并填充每个张量的grad属性。由于我们只对叶子节点a的梯度感兴趣,所以只有a的grad属性会被计算出来,而中间变量c和d的grad属性会被忽略。a的grad属性的值是损失值d.sum()对a的偏导数,表示了a的变化对损失值的影响。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/151813.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于梯度算法优化概率神经网络PNN的分类预测 - 附代码

基于梯度算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于梯度算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于梯度优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要:针对PNN神经网络的光滑…

细说MySQL数据类型

TOC 目录 MySQL数据类型 数据类型分类 数值类型 tinyint类型 有符号tinyint范围测试 无符号tinyint范围测试 bit类型 bit类型的显示方式 bit类型的范围测试 float类型 有符号float范围测试 无符号float范围测试 decimal类型 字符串类型 char类型 char类型测试 …

Jmeter 如何监控目标服务的系统资源

下载Jmeter插件管理下载 perfmon 将这个插件管理放到Jmeter的\lib\ext目录下 然后重启Jmeter jmeter-plugins-manager-1.10.jar 下载 perfmon插件 添加 io 内存 磁盘的监听 并且添加监听 在宿主机中安装代理监听程序 并启动 ServerAgent.tar.gz

数据结构-插入排序

插入排序 插入排序的三种常见方法: 直接插入排序、折半插入排序、希尔排序。 数据存储结构 因为我们是用的是C语言来实现算法,因此我们需要创建一个结构体,用来存放初始数据。 结构体定义如下: #define MAX 100 typedef int…

012 C++ AVL_tree

前言 本文将会向你介绍AVL平衡二叉搜索树的实现 引入AVL树 二叉搜索树虽可以缩短查找的效率,但如果数据有序或接近有序普通的二叉搜索树将退化为单支树,查找元素相当于在顺序表中搜索元素,效率低下。因此,两位俄罗斯的数学家G.M…

学习模拟简明教程【Learning to simulate】

深度神经网络是一项令人惊叹的技术。 有了足够的标记数据,他们可以学习为图像和声音等高维输入生成非常准确的分类器。 近年来,机器学习社区已经能够成功解决诸如对象分类、图像中对象检测和图像分割等问题。 上述声明中的加黑字体警告是有足够的标记数…

OpenHarmony源码下载

OpenHarmony源码下载 现在的 OpenHarmony 4.0 源码已经有了,在 https://gitee.com/openharmony 地址中,描述了源码获取的方式,但那是基于 ubuntu 或者说是 Linux 的下载方式。在 windows 平台下的下载方式没有做出介绍。 我自己尝试了 wind…

PCIe协议加持,SD卡9.1规范达到媲美SSD的速度4GB/s

近日,SD协会(SDA)宣布了最新的SD Express存储卡的进化,将microSD Express存储卡的速度提高了一倍,达到2GB/s,并引入了4个新的SD Express速度等级,以确保新的SD 9.1规范中最低的顺序性能水平。这…

【Qt开发流程】之HelloWorld程序

【Qt开发流程】之HelloWorld程序 目的编写程序新建项目文件说明及界面设计 程序运行及发布程序运行程序发布手动构建使用windeployqt进行构建 设置应用程序图标修改快捷键类型列表命令行编译程序命令行编译.ui文件自定义类项目模式及项目文件介绍项目模式项目文件 目的 从Hell…

通过bat脚本控制Oracle服务启动停止

1、将Oracle服务全部设置为手动启动 初始安装Oracle之后服务启动状态: 2、服务功能介绍 3、构建服务启动/停止bat脚本 注意:编码选择ANSI(如果编码不是ANSI运行脚本会显示乱码) echo off :main cls echo 注:请保证该脚本是使用管理员权限…

Iceberg学习笔记(1)—— 基础知识

Iceberg是一个面向海量数据分析场景的开放表格式(Table Format),其设计的目的是解决数据存储和计算引擎之间的适配的问题 表格式(Table Format)可以理解为元数据以及数据文件的一种组织方式,处于计算框架&…

Java —— 抽象类和接口

目录 1. 抽象类 1.1 抽象类概念 1.2 抽象类语法与特性 1.3 抽象类的作用 2. 接口 2.1 接口的概念 2.2 接口的语法规则与特性 2.3 实现多个接口(解决多继承的问题) 2.4 接口间的继承 2.5 抽象类和接口的区别 2.6 接口的使用实例 2.7 Clonable 接口和深拷贝 2.7.1 Cloneable接口 …

探索arkui(2)--- 布局(列表)--- 1(列表数据的展示)

前端开发布局是指前端开发人员宣布他们开发的新网站或应用程序正式上线的活动。在前端开发布局中,开发人员通常会展示新网站或应用程序的设计、功能和用户体验,并向公众宣传新产品的特点和优势。前端开发布局通常是前端开发领域的重要事件,吸…

Apache Airflow (八) :DAG任务依赖设置

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹…

44、echarts图形自动轮播tooltip提示,并显示高亮

自动轮播方法 参数myChart代表echarts的实例名称, options指定图表的配置项和数据, num类目数量(原因:循环时达到最大值后,使其从头开始循环), time轮播间隔时长 //自动轮播显示高亮--tooltip提示 export function autoHover(myChart, option, num, ti…

【漏洞复现】IP-guard WebServer 远程命令执行

漏洞描述 IP-guard是一款终端安全管理软件,旨在帮助企业保护终端设备安全、数据安全、管理网络使用和简化IT系统管理。互联网上披露IP-guard WebServer远程命令执行漏洞情报。攻击者可利用该漏洞执行任意命令,获取服务器控制权限。 免责声明 技术文章仅供参考,任何个人和…

2024年软件测试面试必看系列,看完去面试你会感谢我的!!

朋友圈点赞的测试用例 功能测试 1点赞后是否显示结果 2.点赞后是否可以取消; 3.点赞取消后是否可以重复点赞; 4.共同好友点赞后,是否有消息提醒; 5.非共同好友点赞后,是否有消息提醒; 6.点击点赞人昵称,是否可以跳转到他/她的主页; 7.自己能…

Spring IOC/DI和MVC及若依对应介绍

文章目录 一、Spring IOC、DI注解1.介绍2.使用 二、Spring MVC注解1.介绍2.使用 一、Spring IOC、DI注解 1.介绍 什么是Spring IOC/DI? IOC(Inversion of Control:控制反转)是面向对象编程中的一种设计原则。其中最常见的方式叫做依赖注入(…

【考研】数据结构(更新到顺序表)

线性表的定义和基本操作 学习目标 线性表定义&#xff1a;具有相同数据类型的n个数据元素的有序序列。 顺序表定义&#xff1a; 特点 基本操作 定义 静态&#xff1a; #include<stdio.h> #include<stdlib.h>#define MaxSize 10//静态 typedef struct{int …

Sonar生成PDF错误Can‘t get Compute Engine task status.Retry..... HTTP error: 401

报错及修改&#xff1a; 报错&#xff1a;INFO: Can’t get Compute Engine task status.Retry… org.sonarqube.ws.connectors.ConnectionException: HTTP error: 401, msg: , query: org.apache.commons.httpclient.methods.GetMethod7a021f49 ERROR: Problem generating PD…