【转载+修改】pytorch中backward求梯度方法的具体解析

原则上,pytorch不支持张量对张量的求导,它只支持标量对张量的求导
我们先看标量对张量求导的情况

import torch
x=torch.ones(2,2,requires_grad=True)
print(x)
print(x.grad_fn)

输出,由于x是被直接创建的,也就是说它是一个叶子节点,所以它的grad_fn属性的值为None
tensor([[1., 1.], [1., 1.]], requires_grad=True) None
接下来对叶子节点x进行第一步操作,y=x+2

y=x+2
print(y)
print(y.grad_fn)

输出:这里可以看到y的grad_fn属性变成了AddBackward,所以grad_fn属性记录的是该张量的上一步操作。
tensor([[3., 3.], [3., 3.]], grad_fn=) <AddBackward0 object at 0x00000206A7EE2108>
然后进行操作z=y * y * 3,再对z求平均值

z=y*y*3 
out=z.mean()
print(z,out)

输出结果:
tensor([[27., 27.], [27., 27.]], grad_fn=) tensor(27., grad_fn=)
此时我们利用backward()函数来求x的梯度,由于out是求平均值得到的一个标量,所以我们可以不用向backward函数传递一个张量,而是直接计算。

out.backward()
print(x.grad)

输出:tensor([[4.5000, 4.5000], [4.5000, 4.5000]])

我们来手动计算,看得到的结果是否与backword函数得到的结果一致
在这里插入图片描述
显然,结果是一致的。

再来看张量对张量求导的情况
前面已经强调过,pytorch不允许张量对张量求导,所以在使用张量对张量求导的时候,必须要传入一个与被求导张量同形的张量,然后pytorch根据传入的张量与被求导张量作加权求和将其转化为标量,这里比较晦涩难懂,没关系,接下来我们用例子来解释
首先创建一个叶子节点x

x=torch.tensor([[1.0,2.0],[3.0,4.0]],requires_grad=True)
print(x)

tensor([[1., 2.], [3., 4.]], requires_grad=True)
接下来计算y=3*x

y=3*x
print(y)

tensor([[ 3., 6.], [ 9., 12.]], grad_fn=)
接下来我们直接用y求导y.backward()
毫无意外,直接报错,这就印证了前面说过的pytorch不支持张量对张量直接求导

RuntimeError: grad can be implicitly created only for scalar outputs

于是我们构建一个与y同形的张量z,(一般是传入一个单位张量,可以参考y.backward(torch.ones_like(y)),这样计算得到的参数梯度没有张量z的影响)把z作为y.backward()的参数求y对x的导。

z=torch.tensor([[1.0,0.1],[0.01,0.001]],dtype=torch.float)
y.backward(z)
print(x.grad)

输出结果:tensor([[3.0000, 0.3000], [0.0300, 0.0030]])(张量的梯度是一个与原张量同形的张量)
事实上,到这里我们仍一头雾水,不知道这个结果是如何得出的,下面给出他的通用计算公式(需要注意的是,该公式只是用来方便计算的,属于计算技巧)
在这里插入图片描述
至于y对x的导数,如果有多层复合函数,利用链式法则计算即可。上面的例子比较简单,y对x求导的结果是3,再乘以张量z,很容验证得到同样的结果。
接下来我们推导上述计算公式,上文已经提到,对于表达式y.backward(z) (y、z为同形张量)的计算过程,实际上将y与z加权求和得到标量m,然后用m对x求导得到结果,也就是说实际上有这样一步计算m=torch.sum(y*z)
我们可以来验证这一步计算的正确性

m = torch.sum(y*z)
print(m)

输出tensor(3.7020, grad_fn=)可以看到的是m是一个标量。
接着再用m对x求导

m.backward()
print(x.grad)

很容易得到上述结果tensor([[3.0000, 0.3000], [0.0300, 0.0030]])
下面给出上述计算公式的推导:
在这里插入图片描述
参考链接:https://blog.csdn.net/weixin_45021364/article/details/105194187

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/10417.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux查看内存的几种方法

PS的拼接方法 ps aux|head -1;ps aux|grep -v PID|sort -rn -k 4|head 进程的 status 比如说你要查看的进程pid是33123 cat /proc/33123/status VmRSS: 表示占用的物理内存 top PID&#xff1a;进程的ID USER&#xff1a;进程所有者 PR&#xff1a;进程的优先级别&#x…

python内置函数

https://www.runoob.com/python/python-built-in-functions.html https://www.runoob.com/python3/python3-function.html

SSD寿命和写放大测试

一、简述 SSD寿命规格&#xff0c;业界标准为TBW&#xff0c;TBW指的是Terabyte Writteb写入的兆兆字节&#xff0c;也有定义为Total Bytes Written&#xff0c;SSD使用寿命结束之前指定工作量可以写入SSD的总数据量&#xff0c;用来表达固态硬盘的寿命指标。 因为 SSD 使用 N…

同步、异步、阻塞、非阻塞

一、概念 同步与异步&#xff08;线程间调用&#xff09;的区别&#xff1a;关注的是调用方与被调用方之间的交互方式。同步调用会等待被调用方的结果返回&#xff0c;而异步调用则不会等待结果立即返回&#xff0c;可以通过回调或其他方式获取结果。 阻塞非阻塞&#xff08;…

springboot集成

maven配置 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </dependency> <dependency><groupId>org.apache.commons</groupId><artifactId>…

wordpress我的个人网站搭建

WordPress介绍 WordPress是一个功能强大且易于使用的网站管理平台。它是基于PHP和MySQL构建的&#xff0c;可以在各种不同的主机上运行。 wordpress对服务器的要求 需求最低版本要求PHP7.4 或更高版本MySQL5.6 或更高版本Web服务器任意&#xff08;如&#xff1a;Apache、Ng…

一套流程6个步骤,教你如何正确采购询价

采购询价&#xff08;RFQ&#xff09;是一种竞争性投标文件&#xff0c;用于邀请供应商或承包商就标准化或重复生产的产品或服务提交报价。 询价通常用于大批量/低价值项目&#xff0c;买方必须提供技术规格和商业要求&#xff0c;该文件有时也称为招标书或投标邀请书。询价流…

git恢复删除的分支

1.查看被删除的分支 git remote prune --dry-run origin 被删除的分支是191 2.找到被删除分支的最后一次提交记录的commit SHA值 git reflog 最后一次提交的commit SHA值是3fa7532 3.恢复分支 git checkout -b xiaomeng 3fa7532 4.恢复成功后提交到远端&#xff0c;over&…

ubuntu20.04 安装 docker engine

打开docker官网 点击上图中间的Linux&#xff0c;会是这样&#xff1a; 点击上图的左边栏的 Docker Engine,点击install, 点击 Ubuntu&#xff0c;会是这样&#xff1a; 把页面翻下来&#xff0c;先按照 Insstallation methods 中的 set up thre repository&#xff0c;执行这些…

pytorch工具——认识pytorch

目录 pytorch的基本元素操作创建一个没有初始化的矩阵创建一个有初始化的矩阵创建一个全0矩阵并可指定数据元素类型为long直接通过数据创建张量通过已有的一个张量创建相同尺寸的新张量利用randn_like方法得到相同尺寸张量&#xff0c;并且采用随机初始化的方法为其赋值采用.si…

压力测试-商场项目

1.压力测试 压力测试是给软件不断加压&#xff0c;强制其在极限的情况下运行&#xff0c;观察它可以运行到何种程度&#xff0c;从而发现性能缺陷&#xff0c;是通过搭建与实际环境相似的测试环境&#xff0c;通过测试程序在同一时间内或某一段时间内&#xff0c;向系统发送预…

项目:点餐系统1

项目简介&#xff1a;实现一个http点餐系统服务器&#xff0c;能够支持用户在浏览器访问服务器获取餐馆首页&#xff0c;进行菜品以及订单管理。 具体模型如下&#xff1a; 用户分类&#xff1a; 管理员&#xff1a;进行订单以及菜品管理&#xff08;菜品&订单的增删改查&a…

STM32MP157驱动开发——按键驱动(异步通知)

文章目录 “异步通知 ”机制&#xff1a;信号的宏定义&#xff1a;信号注册 APP执行过程驱动编程做的事应用编程做的事异步通知方式的按键驱动程序(stm32mp157)button_test.cgpio_key_drv.cMakefile修改设备树文件编译测试 “异步通知 ”机制&#xff1a; 信号的宏定义&#x…

protobuf安装教程

protobuf安装 一&#xff0c;Windows下安装下载protobuf配置环境变量检查是否安装成功 二&#xff0c;Linux下安装下载protobuf安装protobuf检查是否安装成功 一&#xff0c;Windows下安装 下载protobuf 下载地址 本次下载以v21.11为例&#xff0c;根据自己需求下载即可。 配…

2023年深圳杯数学建模 D题 基于机理的致伤工具推断

致伤工具的推断一直是法医工作中的热点和难点。由于作用位置、作用方式的不同&#xff0c;相同的致伤工具在人体组织上会形成不同的损伤形态&#xff0c;不同的致伤工具也可能形成相同的损伤形态。致伤工具品种繁多、形态各异&#xff0c;但大致可分为两类&#xff1a;锐器&…

【12】STM32·HAL库开发-STM32时钟系统 | F1/F4/F7时钟树 | 配置系统时钟

目录 1.认识时钟树&#xff08;掌握&#xff09;1.1什么是时钟&#xff1f;1.2认识时钟树&#xff08;F1&#xff09;1.2.1STM32F103时钟树简图1.2.2STM32CubeMX时钟树&#xff08;F103&#xff09; 1.3认识时钟树&#xff08;F4&#xff09;1.3.1F407时钟树1.3.2F429时钟树1.3…

FFmpeg AVFilter的原理(三)- filter是如何被驱动的

首先上官方filter的链接&#xff1a;https://ffmpeg.org/ffmpeg-filters.html 关于filter命令行&#xff1a;FFmpeg-4.0 的filter机制的架构与实现.之一 Filter原理 1、下面是一个avfilter的graph 上图是ffmpeg中doc/examples中filtering_video.c案例的示意图。 特别注意上面蓝…

HPC集群调度系统和计算系统

什么是计算云&#xff1f; 所谓的计算云指的是为计算业务优化的类云基础架构&#xff0c;它强调用云的方式解决计算问题&#xff0c;而不是将“计算”搬到现有的公有云或者容器云上。 目前公有云或者容器云&#xff08;例如k8s&#xff09;上的HPC解决方案本质上都是将现有的H…

Linux CentOS监控系统的运行情况工具 - top/htop/glances/sar/nmon

在CentOS系统中&#xff0c;您可以使用以下工具来监控系统的运行情况&#xff1a; 1. top&#xff1a; top 是一个命令行工具&#xff0c;用于实时监控系统的进程、CPU、内存和负载情况。您可以使用以下命令来启动 top&#xff1a; top 输出 2. htop&#xff1a; htop 是一…

【MySQL】数据库基础

今天我们来讲一讲使用MySQL必备的基础知识 目录 一、MySQL的登录选项 二、什么是数据库 三、一些主流数据库 四、 服务器&#xff0c;数据库&#xff0c;表关系 五、mysql的架构 5.1 存储引擎 5.1.1 查看存储引擎 5.1.2 不同存储引擎的对比 六、SQL语言的分类 一、MyS…