PyTorch框架学习四——计算图与动态图机制

PyTorch框架学习四——计算图与动态图机制

  • 一、计算图
  • 二、动态图与静态图
  • 三、torch.autograd
    • 1.torch.autograd.backward()
    • 2.torch.autograd.grad()
    • 3.autograd小贴士
    • 4.代码演示理解
      • (1)构建计算图并反向求导:
      • (2)grad_tensors的理解:
      • (3)autograd.gard与create_graph的结合:
      • (4)小贴士1
      • (5)小贴士2
      • (6)小贴士3

一、计算图

计算图是用来描述运算的有向无环图,它包含了两个主要元素:结点(Node)与边(Edge)。其中结点表示数据,如向量、矩阵、张量等,而边表示运算,如加减乘除卷积等。

下面用计算图来表示:y=(x+w)×(w+1)这样的一个运算。

令 a = x + w,b = w + 1,则y = a×b。

计算图如下图所示:
在这里插入图片描述
构建这样的计算图是很方便求解梯度的,以对w求偏导为例,假设x和w的初始值为2和1:
在这里插入图片描述
强调两个概念:

  1. 叶子结点:用户创建的结点,如x与w,tensor中的is_leaf属性就是指示张量是否为叶子结点。非叶子结点运算后会被释放,叶子结点的梯度会被保留,若想保留非叶子结点的梯度,可以用retain_grad()。
  2. grad_fn:记录创建张量时所用的方法(函数),在梯度的反向传播时会用到,以上述的计算为例:y.grad_fn = <MulBackward 0>,a.grad_fn = <AddBackward 0>。

二、动态图与静态图

动态图静态图
实现方式运算与搭建同时进行先搭建计算图,后计算
特点灵活易调节高效但不灵活
框架PyTorchTensorFlow

三、torch.autograd

这是一个自动求导系统,提供了类和函数用来对任意标量函数进行求导。
下面介绍autograd中两个自动求导的函数:

1.torch.autograd.backward()

没有返回值,但是已经对数据进行了自动求导。

torch.autograd.backward(tensors: Union[torch.Tensor, Sequence[torch.Tensor]], grad_tensors: Union[torch.Tensor, Sequence[torch.Tensor], None] = None, retain_graph: Optional[bool] = None, create_graph: bool = False, grad_variables: Union[torch.Tensor, Sequence[torch.Tensor], None] = None)

在这里插入图片描述

  1. tensors:用于求导的张量们。
  2. grad_tensors:多梯度权重,下面用例子理解。
  3. retain_graph:(布尔,可选)若为False,计算图计算完之后就会被释放,若为True,则会保留。
  4. create_graph:(布尔,可选)若为True,会创建导数的计算图,用于更高阶的求导,默认为False。

2.torch.autograd.grad()

torch.autograd.grad(outputs: Union[torch.Tensor, Sequence[torch.Tensor]], inputs: Union[torch.Tensor, Sequence[torch.Tensor]], grad_outputs: Union[torch.Tensor, Sequence[torch.Tensor], None] = None, retain_graph: Optional[bool] = None, create_graph: bool = False, only_inputs: bool = True, allow_unused: bool = False)

在这里插入图片描述

3.autograd小贴士

  1. 梯度不会自动清零,若不清零,则会叠加上原来的数据,需要手动清零:grad.zero_()。
  2. 依赖于叶子结点的结点,requires_grad默认为True。
  3. 叶子结点不可执行in-place操作,in-place操作为在原始内存中改变数据的操作,如a += torch.ones((1, )) (加等操作a的内存地址不变,所以对张量不能做这项操作)。这是因为,在前向传播时,会记录叶子结点的地址,反向求导时是会依据记录的地址去取值进行运算的,若在中途用in-place操作改变了值,则梯度求解会出错。

4.代码演示理解

(1)构建计算图并反向求导:

# 设置 x 和 w 的初始值
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
# 构建计算图
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 反向求导
y.backward()
print(w.grad, x.grad)

结果如下,与手动计算结果一致:

tensor([5.]) tensor([2.])

(2)grad_tensors的理解:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)     # retain_grad()
b = torch.add(w, 1)y0 = torch.mul(a, b)    # y0 = (x+w) * (w+1)
y1 = torch.add(a, b)    # y1 = (x+w) + (w+1)    dy1/dw = 2loss = torch.cat([y0, y1], dim=0)       # [y0, y1]
grad_tensors = torch.tensor([1., 2.])loss.backward(gradient=grad_tensors)    # gradient 传入 torch.autograd.backward()中的grad_tensorsprint(w.grad)

grad_tensors的作用就类似于一个权重,当要求导的对象有多个梯度时,它就是各个梯度加和的权重,比如这里 dy0 / dw = 5,dy1/dw = 2,那么w的总梯度值为 5×1 + 2×2 = 9。
结果如下:

tensor([9.])

(3)autograd.gard与create_graph的结合:

x = torch.tensor([3.], requires_grad=True)
y = torch.pow(x, 2)     # y = x**2grad_1 = torch.autograd.grad(y, x, create_graph=True)   # grad_1 = dy/dx = 2x = 2 * 3 = 6
print(grad_1)grad_2 = torch.autograd.grad(grad_1[0], x)              # grad_2 = d(dy/dx)/dx = d(2x)/dx = 2
print(grad_2)

其中y就是用于求导的张量,x就是需要梯度的张量,grad_1就是第一次求导后x的梯度,因为create_graph为True,所以已经构建了导数的计算图,可以对grad_1再次求导,得到第二次求导后x的梯度grad_2:

(tensor([6.], grad_fn=<MulBackward0>),)
(tensor([2.]),)

(4)小贴士1

这里我们构建了四次计算图,四次一模一样的计算:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)for i in range(4):a = torch.add(w, x)b = torch.add(w, 1)y = torch.mul(a, b)y.backward()print(w.grad)

若我们不对梯度手动清零,结果就如下所示:

tensor([5.])
tensor([10.])
tensor([15.])
tensor([20.])

因为每次的梯度都一样都为5,所以若不手动清零,则梯度会叠加起来。
我们在原来的基础上添加上手动清零:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)for i in range(4):a = torch.add(w, x)b = torch.add(w, 1)y = torch.mul(a, b)y.backward()print(w.grad)w.grad.zero_()

结果为:

tensor([5.])
tensor([5.])
tensor([5.])
tensor([5.])

这样才是正确的。

(5)小贴士2

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)print(a.requires_grad, b.requires_grad, y.requires_grad)

结果为:

True True True

(6)小贴士3

a = torch.ones((1, ))
print(id(a), a)a += torch.ones((1, ))
print(id(a), a)

结果为:

2379593847576 tensor([1.])
2379593847576 tensor([2.])

可见加等操作是in-place操作,是在原始内存中改变数据的操作。
而加法操作就不是,如下所示:

a = torch.ones((1, ))
print(id(a), a)a = a + torch.ones((1, ))
print(id(a), a)

内存不一样:

3019154559688 tensor([1.])
3019156480632 tensor([2.])

如果我们对叶子结点进行in-place操作:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)w += torch.ones((1,))
# w.add_(1)y.backward()print(w.grad)

会报如下错误提示:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/491714.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

美国准备跳过5G直接到6G 用上万颗卫星包裹全球,靠谱吗?

来源&#xff1a;瞭望智库这项2015年提出的计划&#xff0c;规模极其巨大&#xff0c;总计要在2025年前发射近12000颗卫星。有自媒体认为&#xff0c;该计划表示美国将在太空中建立下一代宽带网络&#xff0c;绕过5G&#xff0c;直接升级到6G&#xff0c;并据此认为“6G并不遥远…

PyTorch框架学习五——图像预处理transforms(一)

PyTorch框架学习五——图像预处理transforms&#xff08;一&#xff09;一、transforms运行机制二、transforms的具体方法1.裁剪&#xff08;1&#xff09;随机裁剪&#xff1a;transforms.RandomCrop()&#xff08;2&#xff09;中心裁剪&#xff1a;transforms.CenterCrop()&…

IBM Watson大裁70% 员工,撕掉了国内大批伪AI企业最后一块遮羞布!

来源:新医路Watson 是IBM 的重量级AI 系统&#xff1b;近年IBM 大力发展AI 医疗&#xff0c;在2015 年成立独立的 Watson Health 部门&#xff0c;并收购多家医疗数据公司&#xff0c;前景看好。然而短短三年&#xff0c;这个明星部门就要裁员50% 到70% 的员工&#xff0c;代表…

PyTorch框架学习六——图像预处理transforms(二)

PyTorch框架学习六——图像预处理transforms&#xff08;二&#xff09;&#xff08;续&#xff09;二、transforms的具体方法4.图像变换&#xff08;1&#xff09;尺寸变换&#xff1a;transforms.Resize()&#xff08;2&#xff09;标准化&#xff1a;transforms.Normalize()…

numpy方法读取加载mnist数据集

方法来自机器之心公众号 首先下载mnist数据集&#xff0c;并将里面四个文件夹解压出来&#xff0c;下载方法见前面的博客 import tensorflow as tf import numpy as np import osdataset_path rD:\PycharmProjects\tensorflow\MNIST_data # 这是我存放mnist数据集的位置 is_…

纳米线传感器来了,传感芯片还会远吗

来源&#xff1a;科学网“无旁路电路”纳米线桥接生长方案 黄辉供图微型气体检测仪 黄辉供图人工智能、可穿戴装备、物联网等信息技术迅猛发展&#xff0c;需要海量的传感器提供支持&#xff0c;大数据和云计算等业务也需要各种传感器实时采集数据来支撑。但目前的传感器存在国…

PyTorch框架学习七——自定义transforms方法

PyTorch框架学习七——自定义transforms方法一、自定义transforms注意要素二、自定义transforms步骤三、自定义transforms实例&#xff1a;椒盐噪声虽然前面的笔记介绍了很多PyTorch给出的transforms方法&#xff0c;也非常有用&#xff0c;但是也有可能在具体的问题中需要开发…

美国芯片简史:军方大力扶持下的产物 但一度被日 韩超越

来源&#xff1a;知乎专栏腾讯科技近日发起系列策划&#xff0c;聚焦各个芯片大国的发展历程。第四期&#xff1a;《美国芯片简史》。集成电路是电子信息产业的的基石&#xff0c;电子信息产业对国民经济与社会发展具有重大推动作用。从全球集成电路产业发展历程来看&#xff0…

PyTorch框架学习八——PyTorch数据读取机制(简述)

PyTorch框架学习八——PyTorch数据读取机制&#xff08;简述&#xff09;一、数据二、DataLoader与Dataset1.torch.utils.data.DataLoader2.torch.utils.data.Dataset三、数据读取整体流程琢磨了一段时间&#xff0c;终于对PyTorch的数据读取机制有了一点理解&#xff0c;并自己…

报告 | 2019年全球数字化转型现状研究报告

来源&#xff1a;Prophet2019年&#xff0c;战略数字化转型的重要性已经不止于IT领域&#xff0c;而影响着全公司的竞争力。企业的相关预算直线攀升&#xff0c;利益相关方所关注的颠覆性技术数量急剧增加。数字化项目开始由首席高管主导&#xff0c;并由相互协作的跨职能团队管…

Android调用binder实现权限提升-android学习之旅(81)

当进程A权限较低&#xff0c;而B权限较高时&#xff0c;容易产生提权漏洞 fuzz测试的测试路径 First level Interface是服务 Second level Interface是服务中对应的接口 1.首先获取第一层和第二层接口&#xff0c;及服务以及对应服务提供的接口 2.根据以上信息结合参数类型信息…

PyTorch框架学习九——网络模型的构建

PyTorch框架学习九——网络模型的构建一、概述二、nn.Module三、模型容器Container1.nn.Sequential2.nn.ModuleList3.nn.ModuleDict()4.总结笔记二到八主要介绍与数据有关的内容&#xff0c;这次笔记将开始介绍网络模型有关的内容&#xff0c;首先我们不追求网络内部各层的具体…

中国17种稀土有啥军事用途?没它们,美军技术优势将归零

来源&#xff1a;陶慕剑观察 稀土就是化学元素周期表中镧系元素——镧(La)、铈(Ce)、镨(Pr)、钕(Nd)、钷(Pm)、钐(Sm)、铕(Eu)、钆(Gd)、铽(Tb)、镝(Dy)、钬(Ho)、铒(Er)、铥(Tm)、镱(Yb)、镥(Lu)&#xff0c;再加上钪(Sc)和钇(Y)共17种元素。中国稀土占据着众多的世界第一&…

PyTorch框架学习十——基础网络层(卷积、转置卷积、池化、反池化、线性、激活函数)

PyTorch框架学习十——基础网络层&#xff08;卷积、转置卷积、池化、反池化、线性、激活函数&#xff09;一、卷积层二、转置卷积层三、池化层1.最大池化nn.MaxPool2d2.平均池化nn.AvgPool2d四、反池化层最大值反池化nn.MaxUnpool2d五、线性层六、激活函数层1.nn.Sigmoid2.nn.…

PyTorch框架学习十一——网络层权值初始化

PyTorch框架学习十一——网络层权值初始化一、均匀分布初始化二、正态分布初始化三、常数初始化四、Xavier 均匀分布初始化五、Xavier正态分布初始化六、kaiming均匀分布初始化前面的笔记介绍了网络模型的搭建&#xff0c;这次将介绍网络层权值的初始化&#xff0c;适当的初始化…

W3C 战败:无权再制定 HTML 和 DOM 标准!

来源&#xff1a;CSDN历史性时刻&#xff01;——近日&#xff0c;W3C正式宣告战败&#xff1a;HTML和DOM标准制定权将全权移交给浏览器厂商联盟WHATWG。由苹果、Google、微软和Mozilla四大浏览器厂商组成的WHATWG已经与万维网联盟&#xff08;World Wide Web Consortium&#…

PyTorch框架学习十二——损失函数

PyTorch框架学习十二——损失函数一、损失函数的作用二、18种常见损失函数简述1.L1Loss&#xff08;MAE&#xff09;2.MSELoss3.SmoothL1Loss4.交叉熵CrossEntropyLoss5.NLLLoss6.PoissonNLLLoss7.KLDivLoss8.BCELoss9.BCEWithLogitsLoss10.MarginRankingLoss11.HingeEmbedding…

化合物半导体的机遇

来源&#xff1a;国盛证券半导体材料可分为单质半导体及化合物半导体两类&#xff0c;前者如硅&#xff08;Si&#xff09;、锗(Ge&#xff09;等所形成的半导体&#xff0c;后者为砷化镓&#xff08;GaAs&#xff09;、氮化镓&#xff08;GaN&#xff09;、碳化硅&#xff08;…

PyTorch框架学习十三——优化器

PyTorch框架学习十三——优化器一、优化器二、Optimizer类1.基本属性2.基本方法三、学习率与动量1.学习率learning rate2.动量、冲量Momentum四、十种常见的优化器&#xff08;简单罗列&#xff09;上次笔记简单介绍了一下损失函数的概念以及18种常用的损失函数&#xff0c;这次…

最全芯片产业报告出炉,计算、存储、模拟IC一文扫尽

来源&#xff1a;智东西最近几年&#xff0c; 半导体产业风起云涌。 一方面&#xff0c; 中国半导体异军突起&#xff0c; 另一方面&#xff0c; 全球产业面临超级周期&#xff0c;加上人工智能等新兴应用的崛起&#xff0c;中美科技摩擦频发&#xff0c;全球半导体现状如何&am…