pytorch如何计算导数_PyTorch怎么用?来看这里

构建深度学习模型的基本流程就是:搭建计算图,求得损失函数,然后计算损失函数对模型参数的导数,再利用梯度下降法等方法来更新参数。

搭建计算图的过程,称为“正向传播”,这个是需要我们自己动手的,因为我们需要设计我们模型的结构。由损失函数求导的过程,称为“反向传播”,求导是件辛苦事儿,所以自动求导基本上是各种深度学习框架的基本功能和最重要的功能之一,PyTorch也不例外。

我们今天来体验一下PyTorch的自动求导吧,好为后面的搭建模型做准备。

一、设置Tensor的自动求导属性

所有的tensor都有.requires_grad属性,都可以设置成自动求导。具体方法就是在定义tensor的时候,让这个属性为True:

x = tensor.ones(2,4,requires_grad=True)

In [1]: import torchIn [2]: x = torch.ones(2,4,requires_grad=True)In [3]: print(x)tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]], requires_grad=True)

只要这样设置了之后,后面由x经过运算得到的其他tensor,就都有equires_grad=True属性了。

可以通过x.requires_grad来查看这个属性。

In [4]: y = x + 2In [5]: print(y)tensor([[3., 3., 3., 3.], [3., 3., 3., 3.]], grad_fn=)In [6]: y.requires_gradOut[6]: True

如果想改变这个属性,就调用tensor.requires_grad_()方法:

In [22]: x.requires_grad_(False)Out[22]:tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]])In [21]: print(x.requires_grad,y.requires_grad)False True

这里,注意区别tensor.requires_grad和tensor.requires_grad_()两个东西,前面是调用变量的属性值,后者是调用内置的函数,来改变属性。

二、求导

下面我们来试试自动求导到底怎么样。

我们首先定义一个计算图(计算的步骤):

In [28]: x = torch.tensor([[1.,2.,3.],[4.,5.,6.]],requires_grad=True)In [29]: y = x+1In [30]: z = 2*y*yIn [31]: J = torch.mean(z)

这里需要注意的是,要想使x支持求导,必须让x为浮点类型,也就是我们给初始值的时候要加个点:“.”。不然的话,就会报错。

即,不能定义[1,2,3],而应该定义成[1.,2.,3.],前者是整数,后者才是浮点数。

上面的计算过程可以表示为:

9004db61ec6f72ac8c55f895328f4075.png

好了,重点注意的地方来了!

x、y、z都是tensor,但是size为(2,3)的矩阵。但是J是对z的每一个元素加起来求平均,所以J是标量。

求导,只能是【标量】对标量,或者【标量】对向量/矩阵求导!

所以,上图中,只能J对x、y、z求导,而z则不能对x求导。

我们不妨试一试:

  • PyTorch里面,求导是调用.backward()方法。直接调用backward()方法,会计算对计算图叶节点的导数。获取求得的导数,用.grad方法。

试图z对x求导:

In [31]: z.backward()# 会报错:Traceback (most recent call last) in ()----> 1 z.backward()RuntimeError: grad can be implicitly created only for scalar outputs

正确的应该是J对x求导:

In [33]: J.backward()In [34]: x.gradOut[34]:tensor([[1.3333, 2.0000, 2.6667], [3.3333, 4.0000, 4.6667]])

检验一下,求的是不是对的。

J对x的导数应该是什么呢?

9047fbfc9f55a09571225a1d9795e5eb.png

检查发现,导数就是:

[[1.3333, 2.0000, 2.6667],

[3.3333, 4.0000, 4.6667]]

总结一下,构建计算图(正向传播,Forward Propagation)和求导(反向传播,Backward Propagation)的过程就是:

c8b8d0f0ea84b7cb7b97718457654cc9.png

三、关于backward函数的一些其他问题:

1. 不是标量也可以用backward()函数来求导?

在看文档的时候,有一点我半天没搞懂:

他们给了这样的一个例子:

947acbb92c149d22d7e123d093b9a190.png

我在前面不是说“只有标量才能对其他东西求导”么?它这里的y是一个tensor,是一个向量。按道理不能求导呀。这个参数gradients是干嘛的?

但是,如果看看backward函数的说明,会发现,里面确实有一个gradients参数:

18d6ad4513caa5ed88630e81a1bf07ff.png
6459d0721136202797e512232899e414.png

从说明中我们可以了解到:

  • 如果你要求导的是一个标量,那么gradients默认为None,所以前面可以直接调用J.backward()就行了如果你要求导的是一个张量,那么gradients应该传入一个Tensor。那么这个时候是什么意思呢?

在StackOverflow有一个解释很好:

24cba549f87ee73450c53f7965264483.png

一般来说,我是对标量求导,比如在神经网络里面,我们的loss会是一个标量,那么我们让loss对神经网络的参数w求导,直接通过loss.backward()即可。

但是,有时候我们可能会有多个输出值,比如loss=[loss1,loss2,loss3],那么我们可以让loss的各个分量分别对x求导,这个时候就采用:

loss.backward(torch.tensor([[1.0,1.0,1.0,1.0]]))

如果你想让不同的分量有不同的权重,那么就赋予gradients不一样的值即可,比如:

loss.backward(torch.tensor([[0.1,1.0,10.0,0.001]]))

这样,我们使用起来就更加灵活了,虽然也许多数时候,我们都是直接使用.backward()就完事儿了。

2. 一个计算图只能backward一次

一个计算图在进行反向求导之后,为了节省内存,这个计算图就销毁了。

如果你想再次求导,就会报错。

比如你定义了计算图:

ef6ed079cb9991e32c7b4bdcdecf2c7e.png

你先求p求导,那么这个过程就是反向的p对y求导,y对x求导。

求导完毕之后,这三个节点构成的计算子图就会被释放:

765b57889916d654f257036d5afd8172.png

那么计算图就只剩下z、q了,已经不完整,无法求导了。

所以这个时候,无论你是想再次运行p.backward()还是q.backward(),都无法进行,报错如下:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

好,怎么办呢?

遇到这种问题,一般两种情况:

1. 你的实际计算,确实需要保留计算图,不让子图释放。

那么,就更改你的backward函数,添加参数retain_graph=True,重新进行backward,这个时候你的计算图就被保留了,不会报错。

但是这样会吃内存!,尤其是,你在大量迭代进行参数更新的时候,很快就会内存不足,memory out了。

2. 你实际根本没必要对一个计算图backward多次,而你不小心多跑了一次backward函数。

通常,你要是在IPython里面联系PyTorch的时候,因为你会反复运行一个单元格的代码,所以很容易一不小心把backward运行了多次,就会报错。这个时候,你就检查一下代码,防止backward运行多次即可。

文章转自:https://zhuanlan.zhihu.com/p/51385110

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532947.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

crt中 新建的连接存储在哪_数字存储示波器的VPO技术

当使用数字存储示波器测量串行传输信号、数字电路上的地址/数据/控制总线、信号元器件上的噪声、复合视频信号或调制信号时,面临的最大困难在于这些信号随机、变化迅速、杂乱或不具备周期性。因此,为了提高捕获这些信号的几率,减少数字存储示…

计算机在平面设计中的作用,比例设计在平面设计中的作用与意义

随着互联网的不断发展,用户体验在设计师的产品设计中占有的比重越大了,而今天我们就一起来了解一下,比例设计在平面设计中的作用与意义。一、平面设计中的比例是什么?比例尺是指设计元素相对于其他元素的相对大小。一个物体只有在与其他物体…

元组可以直接添加进数据库吗_数据库篇-第一章:数据库基本概念

面试必备基础数据库知识,扫码关注公众号提升01 第一,什么是数据库?维基百科上是这样定义的:所谓“数据库”是以一定方式储存在一起、能予多个用户共享、具有尽可能小的冗余度、与应用程序彼此独立的数据集合。一个数据库由多个表空…

win7计算机找不到脚本文件夹,win7系统TXT文件打开提示找不到脚本文件的解决方法...

很多小伙伴都遇到过win7系统TXT文件打开提示找不到脚本文件的困惑吧,一些朋友看过网上零散的win7系统TXT文件打开提示找不到脚本文件的处理方法,并没有完完全全明白win7系统TXT文件打开提示找不到脚本文件是如何解决的,今天小编准备了简单的解…

剪切文件_转录组测序技术和结果解读(十六)——可变剪切

可变剪切的概念可变剪切是指从一个mRNA前体中通过不同剪接方式,选择不同的剪接位点组合,所产生不同的mRNA剪接异构体的过程。可变剪切的分类:外显子缺失 (Exon skipping);可变的5’端剪切 (Alternative 5’ splicing);…

archlinux详细安装步骤_最新Centos的liunx安装宝塔的详细步骤

很多人买的服务器是win系统或者是liunx系统,要是说win那就基本上不用学习就和自己的电脑一样操作就可以,但是有些新人刚接触liunx系统不知道怎么安装宝塔环境那今天126云就给大家详细介绍一下 步骤和操作请看下图准备的东西是 挂载磁盘 这个简单介绍就是…

卡诺模型案例分析_3个维度看竞品分析!

谁都想站在巨人的肩膀上,问题是怎么上去?ABC分享会线下24期回顾时间:10月24日 下午13:00-17:30地点:上海嘉定U-CUBE创意空间 参与人数:18人主题:怎样做竞品分析这次活动是第二次有上…

intellij服务器证书不受信任,ssl证书不受信任怎么办?ssl证书不受信任解决方案有什么?...

随之愈来愈多的ssl证书错误的状况出現,大伙儿都是有ssl证书不受信任怎么办这类的难题,而且对这种难题很头痛,下边将带大伙儿解析一下ssl证书不受信任的缘故及解决方案。一、ssl证书不会受到信任是什么缘故1、SSL证书并不是来源于认可的SSL证书…

小马源码_Java互联网架构-重新认识Java8-HashMap-不一样的源码解读

欢迎关注头条号:java小马哥周一至周日早九点半!下午三点半!精品技术文章准时送上!!!精品学习资料获取通道,参见文末看源码前我们必须先知道一下ConcurrentHashMap的基本结构。ConcurrentHashMap…

安装默认报表服务器虚拟目录,报表服务器虚拟目录(Reporting Services 配置)

报表服务器虚拟目录(Reporting Services 配置)12/15/2008本文内容使用“报表服务器虚拟目录”页可以配置报表服务器的虚拟目录。用于访问报表服务器 Web 服务的 URL 将包含该虚拟目录名称。完整的 URL 包括前缀(http:// 或 https://)、服务器名称和虚拟目录。服务器名称可能是内…

小程序向webview传参_独家 | 支付宝小程序向个人开发者开放公测

基于兴趣和周围小群体开发的个人小程序,才是为支付宝提供更加多样化的生活服务场景的来源。文 | Tech星球 (微信ID:tech618) 尹非凡、刘宁宁2月26日,Tech星球(微信ID:tech618) 独家获悉,支付宝小程序今日正式面向个人…

原神服务器维护后抽奖池会更新吗,原神:更新维护一小时,补偿60原石,玩家祈求多维护几天!...

10月21号,原神社区发布公告,游戏将会在10月22号7点至11点进行停服维护,所有玩家在这个时间段将无法进入游戏。而作为补偿,官方会赠送5级以上的玩家240原石(停服一小时送60原石)。这是偷偷的更新吗?官方并没有说更新内容…

涉及子模块_COMSOL Multiphysics 5.6 RF模块更新详解

业界领先的多物理场仿真、App 设计与部署的软件解决方案提供商COMSOL 公司发布了全新的COMSOL Multiphysics 软件5.6 版本。新版本为多核和集群计算提供了计算速度更快且内存需求更低的求解器、更加高效的CAD 装配处理功能、仿真App 布局模板,以及一系列包括剪裁平面…

系统参数shell服务器,shell 调用远程服务器shell

shell 调用远程服务器shell 内容精选换一换流程定义文件描述业务逻辑的XML文件,包括workflow.xml、coordinator.xml、bundle.xml三类,最终由Oozie引擎解析并执行。描述业务逻辑的XML文件,包括workflow.xml、coordinator.xml、bundle.xml三类&…

endnote国标_Citavi 与 Endnote 在 Word 插入引用,哪个更适合你?

前言:不黑、不吹,客观讨论,如有补充请留言,我们一定完善内容。我们先看下两者在 Word 界面的显示截图:Endnote :(看起来很简洁)Citavi :(看起来功能多一些&am…

思科服务器如何修改启动项,思科配置tftp服务器

思科配置tftp服务器 内容精选换一换使用mount命令挂载文件系统到云服务器,云服务器系统提示timed out。原因1:网络状态不稳定。原因2:网络连接异常。原因3:云服务器DNS配置错误,导致解析不到文件系统的域名&#xff0c…

社保费客户端显示服务器连接异常,社保费客户端登录服务器异常

社保费客户端登录服务器异常 内容精选换一换本章节指导您使用MongoDB客户端,通过弹性云服务器内网方式连接GaussDB(for Mongo)集群实例。操作系统使用场景:弹性云服务器的操作系统以Linux为例,客户端本地使用的计算机系统以Windows为例。目标…

双继承_在Python中使用双下划线防止类属性被覆盖!

在使用Python编写面向对象的代码时,我们会常常使用“继承”这种开发方式。例如下面这一段代码:class Info:def __init__(self):passdef calc_age(self):print(我是父类的方法) class PeopleInfo(Info):def __init__(self):super().__init__()def calc_ag…

云服务器 自有操作系统,云服务器 自有操作系统

云服务器 自有操作系统 内容精选换一换监控是保持云耀云服务器可靠性、可用性和性能的重要部分,通过监控,用户可以观察云耀云服务器资源。为使用户更好地掌握自己的云耀云服务器运行状态,公有云平台提供了云监控。您可以使用该服务监控您的云…

分割线不显示_90后都30岁了,为什么还不结婚

2020年中国第一批90后已经30岁了。在传统观念里,30岁作为人生的分水岭,成家,立业,结婚,生子,通通要在这之前解决掉,才算赶上了,人生的进度条,然而媒体针对90后&#xff0…