Pytorch:backward()函数详解

.backward()

.backward() 是 PyTorch 中用于自动求导的函数,它的主要作用是计算损失函数对模型参数的梯度,从而实现反向传播算法。

在深度学习中,我们通常使用梯度下降算法来更新模型参数,使得模型能够逐步逼近最优解。
在梯度下降算法中,我们需要计算损失函数关于模型参数的梯度,以便确定参数更新的方向和大小。
这个计算过程就是反向传播算法,而 loss.backward() 就是反向传播算法的实现。

官方:官方文档

Tensor.backward(gradient=None, retain_graph=None, create_graph=False, inputs=None)[source] .
当前Variable(理解成函数Y)对leaf variable(理解成变量X=[x1,x2,x3])求偏导。

Computes the gradient of current tensor wrt graph leaves.
计算当前张量相对于图中叶子节点的梯度。

The graph is differentiated using the chain rule. If the tensor is non-scalar (i.e. its data has more than one element) and requires gradient, the function additionally requires specifying gradient. It should be a tensor of matching type and location, that contains the gradient of the differentiated function w.r.t. self.
“使用链式法则对图进行微分。如果张量是非标量(即其数据具有多个元素)并且需要梯度,则该函数还需要指定梯度。它应该是相同类型和位置的张量,其中包含相对于自身的微分函数的梯度。”

This function accumulates gradients in the leaves - you might need to zero .grad attributes or set them to None before calling it. See Default gradient layouts for details on the memory layout of accumulated gradients.
这个函数在叶子节点中累积梯度,可能需要在调用它之前将 .grad 属性清零或设置为 None。有关累积梯度的内存布局详情,请参阅默认梯度布局。

Parameters

gradient (Tensor or None) – 计算图可以通过链式法则求导。如果Tensor是 非标量(non-scalar)的(即是说Y中有不止一个y,即Y=[y1,y2,…]),且requires_grad=True。那么此函数需要指定gradient,它的形状应该和Variable的长度匹配(这个就很好理解了,gradient的长度体与Y的长度一直才能保存每一个yi的梯度值啊),里面保存了Variable的梯度。
原文链接:https://blog.csdn.net/weixin_43763731/article/details/88982979

retain_graph (bool, optional) – 这个参数默认是False。计算梯度所必要的buffer在经历过一次backward过程后不会被释放。如果你想多次计算某个子图的梯度的时候,设置为True。

create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.

inputs (sequence of Tensor) – Inputs w.r.t. which the gradient will be accumulated into .grad. All other Tensors will be ignored. If not provided, the gradient is accumulated into all the leaf Tensors that were used to compute the attr::tensors.

看到一篇比较清楚地讲解,我把结论写在这里,具体的参看原文:
参考:https://blog.csdn.net/witnessai1/article/details/79763596

  1. Tensor必须是一个一维标量
    如:a = v(t.FloatTensor([2, 3]), requires_grad=True)
    m = v(t.FloatTensor([[2, 3]]), requires_grad=True) 则不行
    会有如下报错:
    在这里插入图片描述
    报错信息:backward只能被应用在一个标量上,也就是一个一维tensor,或者传入跟变量相关的梯度。

  2. 特别注意Tensor里面默认的参数requires_grad=Falserequires_grad == True , 则表示它可以参与求导,也可以从它向后求导。
    requires_grad == True 具有传递性:如果:x.requires_grad == Truey.requires_grad == Falsez=f(x,y)z.requires_grad == True

  3. Tensor.backward(parameters)接受的 参数parameters必须要和Tensor的大小一模一样,然后作为Tensor的系数传回去

  4. 如果Tensor不是一个一维标量,想要获取对应梯度,需要计算jacobian矩阵。

loss.backward()和torch.autograd.grad的区别

参考:loss.backward()和torch.autograd.grad的区别

  • loss.backward()会将求导结果累加在grad上。这也是为什么在训练每个batch的最开始,需要对梯度清零的原因。
  • torch.autograd.grad不会将求导结果累加在grad上。
  • loss.backward()后,非叶子节点的导数计算完成之后就会被清空。不过,可以在非叶子节点之后,加上
    “非叶子节点.retain_grad()” 来解决这个问题。(作用同:requires_grad == True
  • torch.autograd.grad可以获取非叶子节点的梯度。
  • PS:Pytorch中的张量有一个is_leaf的属性。若一个张量为叶子节点,则其is_leaf属性就为True,若这个张量为非叶子节点,则其is_leaf属性就为False。一般地,由用户自己创建的张量为叶子节点。另外,神经网络中各层的权值weight和偏差bias对应的张量也为叶子节点。由叶子节点得到的中间张量为非叶子节点。在反向传播中,叶子节点可以理解为不依赖其它张量的张量。

Pytorch中的自动求导机制会根据输入和前向传播过程自动构建计算图(节点就是参与运算的变量,图中的边就是变量之间的运算关系),然后根据计算图进行反向传播,计算每个节点的梯度值。

Pytorch提供了两种求梯度的方法,分别是backward()和torch.autograd.grad()。

  • backward()方法可以计算根节点对应的所有叶子节点的梯度。
  • 如果不需要求出当前张量对所有产生该张量的叶子节点的梯度,则可以使用torch.autograd.grad()。

不过需要注意的是,这两种梯度方法都会在反向传播求导的时候释放计算图,如果需要再次做自动求导,因为计算图已经不再了,就会报错。如果要在反向传播的时候保留计算图,可以设置retain_graph=True。

pytorch的计算图

参考:https://zhuanlan.zhihu.com/p/33378444
pytorch是动态图机制,所以在训练模型时候,每迭代一次都会构建一个新的计算图。而计算图其实就是代表程序中变量之间的关系。举个列子: y = ( a + b ) ( b + c ) y = (a+b)(b+c) y=(a+b)(b+c) 在这个运算过程就会建立一个如下的计算图:
在这里插入图片描述
注意图中的 leaf_node,叶子结点就是由用户自己创建的Variable变量,在这个图中仅有a,b,c 是 leaf_node。为什么要关注leaf_node?因为在网络backward时候,需要用链式求导法则求出网络最后输出的梯度,然后再对网络进行优化,如下就是网络的求导过程。

在这里插入图片描述

x = Variable(torch.FloatTensor([[1, 2]]), requires_grad=True)  # 定义一个输入变量
y = Variable(torch.FloatTensor([[3, 4],        [5, 6]]))
loss = torch.mm(x, y)    # 变量之间的运算
loss.backward(torch.FloatTensor([[1, 0]]), retain_graph=True)  # 求梯度,保留图                                    
print(x.grad.data)   # 求出 x_1 的梯度x.grad.data.zero_()  # 最后的梯度会累加到叶节点,所以叶节点清零loss.backward(torch.FloatTensor([[0, 1]]))   # 求出 x_2的梯度
print(x.grad.data)        # 求出 x_2的梯度

这里有一点不太理解:为什么loss.backward(torch.FloatTensor([[1, 0]]), retain_graph=True)是对 x 1 x_1 x1求导,而loss.backward(torch.FloatTensor([[0, 1]])) 是对 x 2 x_2 x2求导。

好像有一点明白了,torch.FloatTensor([[1, 0]]) x 2 x_2 x2为0???

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/234237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux中deadline调度原理与代码注释

简介 deadline调度是比rt调度更高优先级的调度,它没有依赖于优先级的概念,而是给了每个实时任务一定的调度时间,这样的好处是:使多个实时任务场景的时间分配更合理,不让一些实时任务因为优先级低而饿死。deadline调度…

(保姆级教程)一篇文章,搞定所有Linux命令,以及tar解压缩命令,wget、rpm等下载安装命令,Linux的目录结构,以及用户和用户组

文章目录 Linux命令1. Linux目录结构2. 基本命令(了解)3. 目录(文件夹)命令列出目录切换目录创建目录删除目录复制目录移动和重命名目录 4. 文件命令创建文件编辑文件编辑文件时的其他操作 查看文件移动/重命名文件复制文件删除文…

机器学习---聚类(原型聚类、密度聚类、层次聚类)

1. 原型聚类 原型聚类也称为“基于原型的聚类” (prototype-based clustering),此类算法假设聚类结构能通过一 组原型刻画。算法过程:通常情况下,算法先对原型进行初始化,再对原型进行迭代更新求解。著 名的原型聚类算法&#…

基于Redis限流(aop切面+redis实现“令牌桶算法”)

令牌桶算法属于流量控制算法,在一定时间内保证一个键(key)的访问量不超过某个阈值。这里的关键是设置一个令牌桶,在某个时间段内生成一定数量的令牌,然后每次访问时从桶中获取令牌,如果桶中没有令牌&#x…

【机器学习】梯度下降法:从底层手写实现线性回归

【机器学习】Building-Linear-Regression-from-Scratch 线性回归 Linear Regression0. 数据的导入与相关预处理0.工具函数1. 批量梯度下降法 Batch Gradient Descent2. 小批量梯度下降法 Mini Batch Gradient Descent(在批量方面进行了改进)3. 自适应梯度…

C++相关闲碎记录(17)

1、IO操作 (1)class及其层次体系 (2)全局性stream对象 (3)用来处理stream状态的成员函数 前四个成员函数可以设置stream状态并返回一个bool值,注意fail()返回是failbit或者badbit两者中是否任一…

【重点】【DP】152.乘积最大的子数组

题目 法1:DP 参考:https://blog.csdn.net/Innocence02/article/details/128326633 f[i]表示以i结尾的连续子数组的最大乘积,d[i]表示以i结尾的连续子数组的最小乘积 。 如果只有正数,我们只需要考虑最大乘积f[i];有负…

MATLAB - Gazebo 仿真环境

系列文章目录 前言 机器人系统工具箱(Robotics System Toolbox™)为使用 Gazebo 模拟器可视化的模拟环境提供了一个界面。通过 Gazebo,您可以在真实模拟的物理场景中使用机器人进行测试和实验,并获得高质量的图形。 Gazebo 可在…

c# OpenCV 基本绘画(直线、椭圆、矩形、圆、多边形、文本)(四)

我们将在这里演示如何使用几何形状和文本注释图像。 Cv2.Line() 绘制直线 Cv2.Ellipse() 绘制椭圆Cv2.Rectangle() 绘制矩形Cv2.Circle() 绘制圆Cv2.FillPoly() 绘制多边形Cv2.PutText() 绘制文本 一、绘制直线 Cv2.Line(image, start_point, end_point, color, thickness) …

从传统型数据库到非关系型数据库

一 什么是数据库 数据库顾名思义保存数据的仓库,其本质是一个具有数据存储功能的复杂系统软件,数据库最终把数据保存在计算机硬盘,但数据库并不是直接读写数据在硬盘,而是中间隔了一层操作系统,通过文件系统把数据保存…

2023ChatGPT浪潮,2024开源大语言模型会成王者?

《2023ChatGPT浪潮,2024开源大语言模型会成王者?》 一、2023年的回顾 1.1、背景 我们正迈向2023年的终点,回首这一年,技术行业的发展如同车轮滚滚。尽管互联网行业在最近几天基本上处于冬天,但在这一年间我们仍然经…

递归经典三题

目录1.斐波那契数列: 2.青蛙跳台阶问题: 3.汉诺塔问题 1.斐波那契数列: 由斐波那契数列从第三项开始,每一项等于前两项之和,可以使用递归计算给定整数的斐波那契数。 1,1,2,3&am…

酒水品牌网站建设的效果如何

酒是人们餐桌常常出现的饮品,市场中的大小酒品牌或经销商数量非常多,国内国外都有着巨大市场,酒讲究的是品质与品牌,信息发展迅速的时代,商家们都希望通过多种方式获得生意增长。 酒商非常注重品牌,消费者也…

为什么要编写测试用例,测试用例写给谁看?

“为什么要编写测试用例,测试用例写给谁看”,这个问题看似简单,但却涵盖了一系列复杂的考虑因素,并不太好回答。 为了向各位学测试的同学们解释清楚“为什么编写测试用例是至关重要的”,我将通过以下5个方面进行展开&…

EMD、EEMD、FEEMD、CEEMD、CEEMDAN的区别、原理和Python实现(二)EEMD

往期精彩内容: 风速预测(一)数据集介绍和预处理-CSDN博客 风速预测(二)基于Pytorch的EMD-LSTM模型-CSDN博客 风速预测(三)EMD-LSTM-Attention模型-CSDN博客 风速预测(四&#xf…

鸿蒙 - arkTs:渲染(循环 - ForEach,判断 - if)

ForEach循环渲染: 参数: 要循环遍历的数组,Array类型遍历的回调方法,Function类型为每一项生成唯一标识符的方法,有默认生成方法,非必传 使用示例: interface Item {name: String,price: N…

作物模型中引入灌溉参数

在没有设置灌溉时,土壤水分模拟结果如下找到了PCSE包中田间管理文件的标准写法 在agromanager.py中有详细的信息(如何设置灌溉以及施肥量) Version: 1.0 AgroManagement: - 2022-10-15:CropCalendar:crop_name: sugar-beetvariety_name:

HarmonyOS ArkTS 中DatePicker先择时间 路由跳转并传值到其它页

效果 代码 代码里有TextTimerController 这一种例用方法较怪,Text ,Button Datepicker 的使用。 import router from ohos.router’则是引入路由模块。 import router from ohos.router Entry Component struct TextnewClock {textTimerController: TextTimerContr…

管理类联考——数学——真题篇——按题型分类——充分性判断题——蒙猜E

老老规矩,看目录,平均每年2E,跟2D一样,D是全对,E是全错,侧面也看出10道题,大概是3A/B,3C,2D,2E,其实还是蛮平均的。但E为1道的情况居多。 第20题…