【Pytorch笔记】4.梯度计算

深度之眼官方账号 - 01-04-mp4-计算图与动态图机制

前置知识:计算图
可以参考我的笔记:
【学习笔记】计算机视觉与深度学习(2.全连接神经网络)

计算图

在这里插入图片描述
以这棵计算图为例。这个计算图中,叶子节点为x和w。

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)# 调用backward()方法,开始反向求梯度
y.backward()
print(w.grad)print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

输出:

tensor([5.])
is_leaf:True True False False False
gradient:tensor([5.]) tensor([2.]) None None None

由此可见,非叶子节点在最后不会被保留梯度。这是出于节省空间的需要而这样设计的。实际的计算图会非常大,如果每个节点都保留梯度,会占用非常大的存储空间,而这些节点的梯度对于我们学习并没有什么帮助。

如果非要看他们的梯度,可以这样操作:在a = torch.add(w, x)的后面加上一句a.retain_grad(),这样a的梯度就会被存储起来。
输出会变成:

tensor([5.])
is_leaf:True True False False False
gradient:tensor([5.]) tensor([2.]) tensor([2.]) None None

对于节点,还可以看这些节点进行的运算。grad_fn,gradient function的缩写,表示这个节点的tensor是什么运算产生的。加一句:

print("gradient function:\n", w.grad_fn, '\n', x.grad_fn, '\n', a.grad_fn, '\n', b.grad_fn, '\n', y.grad_fn)

会输出

gradient function:NoneNone<AddBackward0 object at 0x000001B1DA3651C0><AddBackward0 object at 0x000001B1DA3651F0><MulBackward0 object at 0x000001B1DA3515B0>

retain_graph

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
a.retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)# 调用backward()方法,开始反向求梯度
y.backward()
y.backward()

连续两次调用backward()方法,会报这样的错误:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因是我们进行第一次backward()后,计算图就被自动释放掉了,进行第二次backward()时,没有计算图可以计算梯度,于是报错。

解决方案:backward内部添加一个参数:retain_graph=True,意思是计算完梯度后保留计算图。

# 调用backward()方法,开始反向求梯度
y.backward(retain_graph=True)
y.backward()

这样就不会报错了。

gradient

当计算图末部的节点有1个以上时,有时我们会希望他们之间的梯度有一个权重关系。这时就会用上gradient

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)# 不难看出,y0和y1是两个互不干扰的末部节点
y0 = torch.mul(a, b)
y1 = torch.add(a, b)# 将两个末部节点打包起来
loss = torch.cat([y0, y1], dim=0)
grad_tensors = torch.tensor([1., 2.])# 将grad_tensors中的内容作为权重,变成y0+2y1
loss.backward(gradient=grad_tensors)print(w.grad)

输出

tensor([9.])

如果把grad_tensors改成:

grad_tensors = torch.tensor([1., 3.])

输出变成:

tensor([11.])

torch.autograd.grad()

除了加减乘除法,我们还可以对torch进行求导操作。求的是 d ( o u t p u t s ) d ( i n p u t s ) \frac{d(outputs)}{d(inputs)} d(inputs)d(outputs)

torch.autograd.grad(outputs,inputs,grad_outputs=None,retain_graph=None,create_graph=False)

outputs和inputs已在上述定义中给出;
grad_outputs:多梯度权重;
retain_graph:保留计算图;
create_graph:创建计算图。

import torch# y = x ** 2
x = torch.tensor([3.], requires_grad=True)
y = torch.pow(x, 2)# grad_1 = dy / dx = 2x = 6
grad_1 = torch.autograd.grad(y, x, create_graph=True)
print(grad_1)# grad_2 = d(dy / dx) / dx = 2
grad_2 = torch.autograd.grad(grad_1, x)
print(grad_2)

输出

(tensor([6.], grad_fn=<MulBackward0>),)
(tensor([2.]),)

autograd注意事项

1.梯度不会自动清零

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)for i in range(4):a = torch.add(w, x)b = torch.mul(w, x)y = torch.mul(a, b)y.backward()print("w's grad: ", w.grad)# w.grad.zero_()

输出:

w's grad:  tensor([8.])
w's grad:  tensor([16.])
w's grad:  tensor([24.])
w's grad:  tensor([32.])

由此可以看出,在不加上注释掉的那一行时,梯度在w处是不断累积的。而如果我们把print后面的那句w.grad.zero_()加上,输出就会变成:

w's grad:  tensor([8.])
w's grad:  tensor([8.])
w's grad:  tensor([8.])
w's grad:  tensor([8.])

w.grad.zero_()的意思就是把w处积累的梯度清零。

2.依赖于叶子节点的节点,requires_grad默认为True

可以从上面的代码中发现,我们只有在定义w和x两个tensor时,设置requires_grad为True。这个参数在定义tensor时默认为False。后面我们的a、b、y都没有设置这个参数。

如果我们定义w和x的时候不加上requires_grad=True,那么y.backward()这一步就会报错,因为我们的预设,这两个tensor不需要梯度,于是就无法求梯度。而w和x是我们计算图上的叶子节点,所以必须加上requires_grad=True。

而后面通过w和x延伸定义出的a、b、y,由于依赖的w、x的requires_grad是True,那么a、b、y的这个参数也被默认设置为了True,不需要我们手动添加。

3.叶子节点不可执行in-place操作

计算图上叶子节点处的tensor不能进行原地修改。

什么是in-place操作?
t = torch.tensor([1., 2.])
t.add_(3.)
print(t)

输出

tensor([4., 5.])

torch.Tensor.add_就是torch.add的in-place版本。所谓in-place,就是在tensor上进行原地修改。大部分的torch.tensor的运算,名字后面加一个下划线,就变成inplace操作了。

再比如求绝对值:

t = torch.tensor([-1., -2.])
t.abs_()
print(t)

输出

tensor([1., 2.])

知道什么是in-place操作后,我们尝试一下在requires_grad=True的叶子节点上原地修改,代码如下:

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.mul(w, x)
y = torch.mul(a, b)w.add_(1)y.backward()

报错信息:

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/93759.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于.Net Core实现自定义皮肤WidForm窗口

前言 今天一起来实现基于.Net Core、Windows Form实现自定义窗口皮肤&#xff0c;并实现窗口移动功能。 素材 准备素材&#xff1a;边框、标题栏、关闭按钮图标。 窗体设计 1、创建Window窗体项目 2、窗体设计 拖拉4个Panel控件&#xff0c;分别用于&#xff1a;标题栏、关…

【Redis】基础数据结构-字典

Redis 字典 基本语法 字典是Redis中的一种数据结构&#xff0c;底层使用哈希表实现&#xff0c;一个哈希表中可以存储多个键值对&#xff0c;它的语法如下&#xff0c;其中KEY为键&#xff0c;field和value为值&#xff08;也是一个键值对&#xff09;&#xff1a; HSET key…

基于SSM农产品商城系统

基于SSM农产品商城系统的设计与实现&#xff0c;前后端分离&#xff0c;文档 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringSpringMVCMyBatisVue工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 农产品列表 产品详情 个人中心 登陆界面 管…

以太网基础学习(二)——ARP协议

一、什么是MAC地址 MAC地址&#xff08;英语&#xff1a;Media Access Control Address&#xff09;&#xff0c;直译为媒体访问控制位址&#xff0c;也称为局域网地址&#xff08;LAN Address&#xff09;&#xff0c;MAC位址&#xff0c;以太网地址&#xff08;Ethernet Addr…

【算法训练-字符串 三】字符串相加

废话不多说&#xff0c;喊一句号子鼓励自己&#xff1a;程序员永不失业&#xff0c;程序员走向架构&#xff01;本篇Blog的主题是【字符串相加】&#xff0c;使用【字符串】这个基本的数据结构来实现&#xff0c;这个高频题的站点是&#xff1a;CodeTop&#xff0c;筛选条件为&…

电脑突然提示mfc140u.dll丢失,缺失mfc140u.dll无法运行程序的解决方法

在当今信息化社会&#xff0c;电脑已经成为我们生活和工作中不可或缺的一部分。然而&#xff0c;随着技术的不断发展&#xff0c;电脑也会出现各种问题。其中&#xff0c;最常见的问题之一就是“mfc140u.dll丢失”。那么&#xff0c;当我们遇到这个问题时&#xff0c;应该如何解…

ISP图像信号处理——白平衡校正和标定介绍以及C++实现

从数码相机直接输出的未经过处理过的RAW图到平常看到的JEPG图有一系列复杂的图像信号处理过程&#xff0c;称作ISP&#xff08;Image Signal Processing&#xff09;。这个过程会经过图像处理和压缩。 参考文章1&#xff1a;http://t.csdn.cn/LvHH5 参考文章2&#xff1a;htt…

WebSocket实战之四WSS配置

一、前言 上一篇文章WebSocket实战之三遇上PAC &#xff0c;碰到的问题只能上安全的WebSocket&#xff08;WSS&#xff09;才能解决&#xff0c;配置证书还是挺麻烦的&#xff0c;主要是每年都需要重新更新证书&#xff0c;我配置过的证书最长有效期也只有两年&#xff0c;搞不…

【数据结构】排序(2)—冒泡排序 快速排序

目录 一. 冒泡排序 基本思想 代码实现 时间和空间复杂度 稳定性 二. 快速排序 基本思想 代码实现 hoare法 挖坑法 前后指针法 时间和空间复杂度 稳定性 一. 冒泡排序 基本思想 冒泡排序是一种交换排序。两两比较数组元素&#xff0c;如果是逆序(即排列顺序与排序后…

定时任务管理平台青龙 QingLong

一、关于 QingLong 1.1 QingLong 介绍 青龙面板是支持 Python3、JavaScript、Shell、Typescript 多语言的定时任务管理平台&#xff0c;支持在线管理脚本和日志等。其功能丰富&#xff0c;能够满足大部分需求场景&#xff0c;值得一试。 主要功能 支持多种脚本语言&#xf…

我的企业证书是正常的但是下载应用app到手机提示无法安装“app名字”无法安装此app,因为无法验证其完整性解决方案

我的企业证书是正常的但是下载应用app到手机提示无法安装“app名字”无法安装此app&#xff0c;因为无法验证其完整性解决方案 首先&#xff0c;确保您从可信任的来源下载并安装企业开发者签名过的应用程序。如果您不确定应用程序的来源&#xff0c;建议您联系应用程序提供者…

你写过的最蠢的代码是?——AI领域的奇妙体验

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

Redis与分布式-哨兵模式

接上文 Redis与分布式-主从复制 1.哨兵模式 启动一个哨兵&#xff0c;只需要修改配置文件即可&#xff0c; sentinel monitor lbwnb 1247.0.0.1 6001 1先将所有服务关闭&#xff0c;然后修改配置文件&#xff0c;redis Master&#xff0c;redis Slave&#xff0c;redis Slave…

源码系列 之 ThreadLocal

简介 ThreadLocal的作用是做数据隔离&#xff0c;存储的变量只属于当前线程&#xff0c;相当于当前线程的局部变量&#xff0c;多线程环境下&#xff0c;不会被别的线程访问与修改。常用于存储线程私有成员变量、上下文&#xff0c;和用于同一线程&#xff0c;不同层级方法间传…

复习C语言数组的用法

实验内容 1.1设计一个函数fun&#xff0c;功能是有N*N的矩阵&#xff0c;根据给定的m值&#xff0c;m<N,将每行元素中的值&#xff0c;均往右移m个位置&#xff0c;左边置0 #include<stdio.h> void fun(int (*a)[3],int m){int n,j,i,k,num;int p2;//右移位置列数nu…

基于体素场景的摄像机穿模处理

基于上一篇一种基于体素的射线检测 使用射线处理第三人称摄像头穿模问题 基于体素的第三人称摄像机拉近简单处理 摄像机移动至碰撞点处 简单的从角色身上发射一条射线到摄像机&#xff0c;中途遇到碰撞就把摄像机移动至该碰撞点 public void UpdateDistance(float defaultDist…

OpenGL之光照贴图

我们需要拓展之前的系统,引入漫反射和镜面光贴图(Map)。这允许我们对物体的漫反射分量和镜面光分量有着更精确的控制。 漫反射贴图 我们希望通过某种方式对物体的每个片段单独设置漫反射颜色。我们仅仅是对同样的原理使用了不同的名字:其实都是使用一张覆盖物体的图像,让我…

软件测试教程 自动化测试selenium篇(二)

掌握Selenium常用的API的使用 一、webdriver API public class Main {public static void main(String[] args) {ChromeOptions options=new ChromeOptions();//参数表示允许所有请求options.addArguments("--remote-allow-origins=*");WebDriver webDriver=new Chr…

【Maven基础篇-黑马程序员】Maven项目管理从基础到高级,一次搞定!

文章目录 前言Maven简介Maven是什么Maven的作用 Maven的下载与安装Maven基础概念仓库坐标仓库配置全局setting与用户setting区别 第一个Maven程序&#xff08;手工制作&#xff09;第一个Maven程序&#xff08;IDEA生成&#xff09;使用模版&#xff08;骨架&#xff09;创建Ma…

vcruntime140.dll如何修复,快速修复vcruntime140.dll丢失的三种方法

vcruntime140.dll是Visual C 2015运行库的一个组件&#xff0c;它包含了许多运行时函数&#xff0c;用于支持各种程序的正常运行。当vcruntime140.dll文件丢失时&#xff0c;可能会导致一些程序无法正常运行。本文将详细介绍vcruntime140.dll的作用、丢失原因以及三种修复方法。…