torch中关于张量是否是叶子结点,张量梯度是否会被计算,张量梯度是否会被保存的感悟

先上结论:

1、叶子结点定义:  

(1)不依赖其它任何结点的张量

(2)依赖其它张量,但其依赖的所有张量的require_grad=False

#  判断方法:查看is_leaf属性

2、张量梯度是否会被计算:  

require_grad=True,且依赖其的张量不全为require_grad=False,该张量梯度会被计算

# 判断方法:backward之后查看张量的.grad属性(中间变量满足上述要求的梯度肯定也会被计算,只是backward之后会被释放掉,无法查看。中间变量其下面的叶子结点梯度被计算,根据链式法则,侧面也可证明中间变量的梯度肯定被计算了,因此本文只采用叶子结点说明该规律)

3、张量梯度是否会被保存(前提: 张量梯度可以被计算):  

(1)是叶子结点

(2)是非叶子结点,但retain_grad=True

#  判断方法: backward之后查看张量的.grad属性

然后看例子:

# 两个例子先证明第一条:叶子结点定义

import torch# 两个例子先证明第一条:叶子结点定义
a = torch.tensor(1.,requires_grad=True)
b = torch.tensor(1.,requires_grad=False)
d = a+b
d.backward()print(a.is_leaf) # True  (不依赖其它任何张量的结点)
print(b.is_leaf) # True  (不依赖其它任何张量的结点)
print(d.is_leaf) # False (依赖其它张量a和b,但张量a的require_grad=True,所以d不是叶子结点)c = torch.tensor(1., requires_grad=False)
e = b+cprint(b.is_leaf)    # True  (不依赖其它任何张量的结点)
print(c.is_leaf)    # True  (不依赖其它任何张量的结点)
print(e.is_leaf)    # True (依赖其它张量b和c,但张量a和c的require_grad=False,所以e仍然是叶子结点)
# e.backward()会报错,因为所有结点的require_grad=False,因此不需要求梯度

# 两个例子再证明第二条:张量梯度是否会被计算

import torch# 两个例子再证明第二条:张量梯度是否会被计算
# require_grad=True的张量,且依赖其的张量不全为require_grad=False梯度会被计算
a = torch.tensor(1.,requires_grad=True)   # a require_grad=True,因此a的梯度会被计算
b = torch.tensor(1.,requires_grad=False)  # b require_grad=False,因此b的梯度不会被计算
d = a+b                                   # d require_grad=True,因此d的梯度会被计算
c = torch.tensor(1., requires_grad=True)  # c require_grad=True,因此c的梯度会被计算
e = a+c                                   # e require_grad=True,因此e的梯度会被计算
d = d.detach()
print(d.requires_grad)                    # d require_grad=False,因此d的梯度不会被计算f = d+e 
f.backward()
print(a.grad)                             # a有梯度# require_grad=True的张量,但依赖其的张量全为require_grad=False,梯度不会被计算
a = torch.tensor(1.,requires_grad=True)   # a require_grad=True, 因此a的梯度会被计算
b = torch.tensor(1.,requires_grad=False)  # b require_grad=False, 因此b的梯度不会被计算
d = a+b                                   # d require_grad=True, 因此d的梯度会被计算
c = torch.tensor(1., requires_grad=True)  # c require_grad=True,因此c的梯度会被计算
d = d.detach()                            # d require_grad=False, 因此d的梯度不会被计算
f = d+c                                   # f require_grad=True,因此f的梯度会被计算
f.backward()
print(a.grad)                             # a没有梯度

# 再证明第三条: 张量梯度是否会被保存,前提是张量的梯度能被计算(既满足第二条)

import torch# 再证明第三条: 张量梯度是否会被保存,前提是张量的梯度能被计算(既满足第二条)
# (1)叶子结点的梯度会被保存
a = torch.tensor(1.,requires_grad=True)     # a的require_grad= True,且a是一个叶子结点(a.is_leaf=True),所以backward之后a的梯度会被保存。(满足条件3)
b = torch.tensor(1.,requires_grad=False)    # b是一个叶子结点(b.is_leaf=True),但b的require_grad= False,所以backward之后b的梯度不会被保存.(不满足条件3)
d = a+b                                     # d的require_grad= True, 但d不是一个叶子结点,所以backward之后b的梯度不会被保存。(不满足条件3)
d.backward()
print(a.is_leaf)
print(a.grad)                               # a的梯度被保存# 感悟: 神经网络各层里面的参数require_grad=True(属于Parameter类型,其初始化的时候默认require_grad=True),并且如果上层不会被断开(满足梯度可以被计算条件)。且神经网络里面各层的参数都是叶子结点(从计算图可以得知满足1里面第一条),因此满足梯度保存条件第一条。因此其梯度一定会被保存。满足了以上两条,因此backward的时候其梯度一定会被计算并且保存,从而step的时候才能用于梯度更新)# (2)非叶子结点,但retain_grad=True的张量梯度也会被保存。
a = torch.tensor(1.,requires_grad=True)     # a的require_grad= True,且a是一个叶子结点(a.is_leaf=True),所以backward之后a的梯度会被保存。(满足条件3)
b = torch.tensor(1.,requires_grad=False)    # b是一个叶子结点(b.is_leaf=True),但b的require_grad= False,所以backward之后b的梯度不会被保存.(不满足条件3)
d = a+b                                     # d的require_grad= True, 但d不是一个叶子结点,所以backward之后b的梯度不会被保存。(不满足条件3)
print(d.is_leaf)                            # False
d.retain_grad()                             # retain_grad=True
d.backward()
print(d.is_leaf)                            # False
print(d.grad)                               # d的梯度被保存

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/240764.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringMVC之文件的下载

系列文章目录 提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加 SpringMVC之文件的下载 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 系列文章目录前言一、文件下载实现…

认识Linux背景

1.发展史 Linux从哪里来?它是怎么发展的?在这里简要介绍Linux的发展史 要说Linux,还得从UNIX说起 UNIX发展的历史 1968年,一些来自通用电器公司、贝尔实验室和麻省理工学院的研究人员开发了一个名叫Multics的特殊操作系统。Mu…

分布式锁常见问题及其解决方案

一、为什么要使用分布式锁? 因为在集群下,相当于多个JVM,就相当于多个锁,集群之间锁是没有关联的,会照成锁失效从而导致线程安全问题 分布式锁可以分别通过MySQL、Redis、Zookeeper来进行实现 二、redis分布式锁的实…

华为发布全闪备份一体机旗舰新品,并宣布备份软件开源

[中国,上海,2023年12月20日]在20日举行的OceanProtect数据保护新品发布会上,华为发布全闪备份一体机旗舰新品,并宣布备份软件开源,以应对智慧金融、自动驾驶等场景对数据备份效率及数据安全方面的新诉求,为…

工业信息采集平台的五大核心优势

关键字:工业信息采集平台,蓝鹏数据采集系统,蓝鹏测控系统, 生产管控系统, 生产数据处理平台,MES系统数据采集, 蓝鹏数据采集平台通过实现和构成其他工业数据信息平台的一级设备进行通讯,从而完成平台之间的无缝对接。这里我们采用的最多的方式是和PLC进行…

神经网络:深度学习基础

1.反向传播算法(BP)的概念及简单推导 反向传播(Backpropagation,BP)算法是一种与最优化方法(如梯度下降法)结合使用的,用来训练人工神经网络的常见算法。BP算法对网络中所有权重计算…

力扣labuladong——一刷day77

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、力扣207. 课程表 前言 图这种数据结构有一些比较特殊的算法,比如二分图判断,有环图无环图的判断,拓扑排序,以…

Redis取最近10条记录

有时候我们有这样的需求,就是取最近10条数据展示,这些数据不需要存数据库,只用于暂时最近的10条,就没必要在用到Mysql类似的数据库,只需要用redis即可,这样既方便也快! 具体取最近10条的方法&a…

19-二分-值域二分-有序矩阵中第 K 小的元素

这是二分法的第19篇算法,力扣链接。 给你一个 n x n 矩阵 matrix ,其中每行和每列元素均按升序排序,找到矩阵中第 k 小的元素。 请注意,它是 排序后 的第 k 小元素,而不是第 k 个 不同 的元素。 你必须找到一个内存复杂…

ElasticSearch插件手动安装

./plugin install file:///home/es2.4/es/elasticsearch-kopf-master.zipElasticSearch插件手动安装_如何下载elasticsearch-kopf-master.zip-CSDN博客

Go 代码检查工具 golangci-lint

一、介绍 golangci-lint 是一个代码检查工具的集合,聚集了多种 Go 代码检查工具,如 golint、go vet 等。 优点: 运行速度快可以集成到 vscode、goland 等开发工具中包含了非常多种代码检查器可以集成到 CI 中这是包含的代码检查器列表&…

异常(Java)

1.异常的概念 在 Java 中,将程序执行过程中发生的不正常行为称为异常 。 1.算数异常 System.out.println(10 / 0); // 执行结果 Exception in thread "main" java.lang.ArithmeticException: / by zero 2.数组越界异常 int[] arr {1, 2, 3}; System.out.…

ARM 汇编入门

ARM 汇编入门 引言 ARM 汇编语言是 ARM 架构的汇编语言,用于直接控制 ARM 处理器。虽然现代软件开发更多地依赖于高级语言和编译器,但理解 ARM 汇编仍然对于深入了解系统、优化代码和进行低级调试非常重要。本文将为您提供一个简单的 ARM 汇编入门指南…

DBA-MySql面试问题及答案-上

文章目录 1.什么是数据库?2.如何查看某个操作的语法?3.MySql的存储引擎有哪些?4.常用的2种存储引擎?6.可以针对表设置引擎吗?如何设置?6.选择合适的存储引擎?7.选择合适的数据类型8.char & varchar9.Mysql字符集10.如何选择…

第九周算法题(哈希映射,二分,Floyd算法 (含详细讲解) )

第九周算法题 第一题 题目来源&#xff1a;33. 搜索旋转排序数组 - 力扣&#xff08;LeetCode&#xff09; 题目描述&#xff1a;整数数组 nums 按升序排列&#xff0c;数组中的值 互不相同 。 在传递给函数之前&#xff0c;nums 在预先未知的某个下标 k&#xff08;0 <…

全网最全ChatGPT指令大全prompt

全网最全的ChatGPT大全提示词&#xff0c;大家可以进行下载。 AIGC ChatGPT 职场案例 AI 绘画 与 短视频制作 PowerBI 商业智能 68集 数据库Mysql 8.0 54集 数据库Oracle 21C 142集 Office 2021实战应用 Python 数据分析实战&#xff0c; ETL Informatica 数据仓库案例实战 E…

【JAVA面试题】什么是引用传递?什么是值传递?

&#x1f34e;个人博客&#xff1a;个人主页 &#x1f3c6;个人专栏&#xff1a; JAVA ⛳️ 功不唐捐&#xff0c;玉汝于成 前言 博客的正文部分可以详细介绍Java中参数传递的机制&#xff0c;强调Java是按值传递的&#xff0c;并解释了基本数据类型和对象引用在这种传…

二级分销的魅力:无限裂变创造十八亿的流水

有这么一个团队&#xff0c;仅靠这一个二级分销&#xff0c;六个月就打造了十八亿的流水。听着是不是很恐怖&#xff1f;十八亿确实是一个很大的数字&#xff0c;那么这个团队是怎么做到的呢&#xff1f;我们接着往下看。 这是一个销售减脂产品的团队。不靠网店&#xff0c;不…

【JMeter入门】—— JMeter介绍

1、什么是JMeter Apache JMeter是Apache组织开发的基于Java的压力测试工具&#xff0c;用于对软件做压力测试。它最初被设计用于Web应用测试&#xff0c;但后来扩展到其他测试领域。 &#xff08;Apache JMeter是100%纯JAVA桌面应用程序&#xff09; Apache JMeter可以用于对静…

pycharm git 版本回退

参考 https://blog.csdn.net/qq_38175912/article/details/102860195 yoyoketang 悠悠课堂