multi task训练torch_手把手教你使用PyTorch(2)-requires_gradamp;computation graph

import torch

1. Requires_grad

但是,模型毕竟不是人,它的智力水平还不足够去自主辨识那些量的梯度需要计算,既然如此,就需要手动对其进行标记。

在PyTorch中,通用的数据结构tensor包含一个attributerequires_grad,它被用于说明当前量是否需要在计算中保留对应的梯度信息,以上文所述的线性回归为例,容易知道参数www为需要训练的对象,为了得到最合适的参数值,我们需要设置一个相关的损失函数,根据梯度回传的思路进行训练。

官方文档中的说明如下

If there’s a single input to an operation that requires gradient, its output will also require gradient.

只要某一个输入需要相关梯度值,则输出也需要保存相关梯度信息,这样就保证了这个输入的梯度回传。

而反之,若所有的输入都不需要保存梯度,那么输出的requires_grad会自动设置为False。既然没有了相关的梯度值,自然进行反向传播时会将这部分子图从计算中剔除。

Conversely, only if all inputs don’t require gradient, the output also won’t require it. Backward computation is never performed in the subgraphs, where all Tensors didn’t require gradients.

对于那些要求梯度的tensor,PyTorch会存储他们相关梯度信息和产生他们的操作,这产生额外内存消耗,为了优化内存使用,默认产生的tensor是不需要梯度的。

而我们在使用神经网络时,这些全连接层卷积层等结构的参数都是默认需要梯度的。

a = torch.tensor([1., 2., 3.])

print('a:', a.requires_grad)

b = torch.tensor([1., 4., 2.], requires_grad = True)

print('b:', b.requires_grad)

print('sum of a and b:', (a+b).requires_grad)

a: False

b: True

sum of a and b: True

2. Computation Graph

从PyTorch的设计原理上来说,在每次进行前向计算得到pred时,会产生一个用于梯度回传的计算图,这张图储存了进行back propagation需要的中间结果,当调用了.backward()后,会从内存中将这张图进行释放

这张计算图保存了计算的相关历史和提取计算所需的所有信息,以output作为root节点,以input和所有的参数为leaf节点,

we only retain the grad of the leaf node with requires_grad =True

在完成了前向计算的同时,PyTorch也获得了一张由计算梯度所需要的函数所组成的图

而从数据集中获得的input其requires_grad为False,故我们只会保存参数的梯度,进一步据此进行参数优化

在PyTorch中,multi-task任务一个标准的train from scratch流程为

for idx, data in enumerate(train_loader):

xs, ys = data

optmizer.zero_grad()

# 计算d(l1)/d(x)

pred1 = model1(xs) #生成graph1

loss = loss_fn1(pred1, ys)

loss.backward() #释放graph1

# 计算d(l2)/d(x)

pred2 = model2(xs)#生成graph2

loss2 = loss_fn2(pred2, ys)

loss.backward() #释放graph2

# 使用d(l1)/d(x)+d(l2)/d(x)进行优化

optmizer.step()

Computation Graph本质上是一个operation的图,所有的节点都是一个operation,而进行相应计算的参数则以叶节点的形式进行输入

借助torchviz库以下面的模型作为示例

import torch.nn.functional as F

import torch.nn as nn

class Conv_Classifier(nn.Module):

def __init__(self):

super(Conv_Classifier, self).__init__()

self.conv1 = nn.Conv2d(1, 5, 5)

self.pool1 = nn.MaxPool2d(2)

self.conv2 = nn.Conv2d(5, 16, 5)

self.pool2 = nn.MaxPool2d(2)

self.fc1 = nn.Linear(256, 20)

self.fc2 = nn.Linear(20, 10)

def forward(self, x):

x = F.relu(self.pool1((self.conv1(x))))

x = F.relu(self.pool2((self.conv2(x))))

x = F.dropout2d(x, training=self.training)

x = x.view(-1, 256)

x = F.relu(self.fc1(x))

x = F.relu(self.fc2(x))

return x

Mnist_Classifier = Conv_Classifier()

from torchviz import make_dot

input_sample = torch.rand((1, 1, 28, 28))

make_dot(Mnist_Classifier(input_sample), params=dict(Mnist_Classifier.named_parameters()))

其对应的计算梯度所需的图(计算图)为

可以看到,所有的叶子节点对应的操作都被记录,以便之后的梯度回传。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/409336.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【原】unity shader(3)反射贴图

改编自《cg教程--可编程实时图形学权威指南》上的demo。 反射向量计算公式 RI-2N(N*I) 备注N*I是点乘 I入射光线,N法向量 函数实现: float3 reflect(float3 I,float3 N) { return I-2.0*N*dot(N,I); } Shader "CG shader Reflect"{Propertie…

分时技术用户可以独占计算机资源,计算机基础第二章选择题(带答案修改版 )校过...

一,选择题(选择最确切的一个答案,将其代码填入括号中)1、操作系统是一种( ).A, 应用软件 B, 系统软件 C, 通用软件 D, 工具软件2、计算机系统的组成包括( ).A,程序和数据 B,处理器和内存 C,计算机硬件和计算机软件 D,处理器,存储器和外围设备3、下面关于计算机软件的描述正确的…

div显示在上层_DIV重叠 如何优先显示(div浮在重叠的div上面)

如果有2个div有重叠,默认是根据html解析顺序,最后加载的优先级最高(浮在最上面)。问题:如果想把前面加载的div显示在最上面?关键字:z-index举例:--原来的页面:first div是被second div盖住了&am…

计算机专业英语的时态特点,英语时态表的学习与整理

学习英语必须要了解英语的各种时态,不了解时态会说话时闹出笑话,也会引出歧义,做题的时候也会出错。所以英语的时态一定要分清也要记牢,无论什么时候运用英语的时候都不要忘记。现在就和沪江小编一起了解了解吧!英语时态表 时态名…

el replace 表达式_EL表达式运算符、常用函数详解

运算符&#xff1a;1.算术运算符有五个&#xff1a;、-、*或$、/或div、%或mod2.关系运算符有六个&#xff1a;或eq、!或ne、或gt、<或le、>或ge3.逻辑运算符有三个&#xff1a;&&或and、||或or、!或not4.其它运算符有三个&#xff1a;Empty运算符、条件运算符、…

[短彩信]C#短彩信模块开发设计(2)——配置

准备从以下几个方面简单的谈谈短彩信模块的实现&#xff1a; [短彩信]C#短彩信模块开发设计&#xff08;1&#xff09;——架构&#xff08;http://www.cnblogs.com/CopyPaster/archive/2012/12/07/2806776.html&#xff09;[短彩信]C#短彩信模块开发设计&#xff08;2&#xf…

python管理工具ports_Python options.port方法代码示例

本文整理汇总了Python中tornado.options.port方法的典型用法代码示例。如果您正苦于以下问题&#xff1a;Python options.port方法的具体用法&#xff1f;Python options.port怎么用&#xff1f;Python options.port使用的例子&#xff1f;那么恭喜您, 这里精选的方法代码示例或…

html5 canvas文字颜色,我可以通过HTML5 Canvas中的字符文本颜色来做吗?

我告诉你这个解决方法.基本上你一次输出一个字符,并使用内置的measureText()函数来确定每个字母的宽度.然后我们将我们想要绘制的位置偏移相同的数量.您可以修改此代码段,以产生所需的效果.假设我们有这样的HTML&#xff1a;和Javascript一样&#xff1a;var canvas document.…

转移指令总结

转移指令&#xff1a;可以修改ip的指令。无条件转移 jmp(1) jmp short s 标号&#xff0c;短转移&#xff1a;用一个字节表示大小&#xff0c;范围为-128--127 (2) jmp near ptr s 标号&#xff0c;近转移&#xff1a;用两个字节表示大小&#xff0c;范围为-32768--32767(3) …

浅谈对程序员的认识_浅谈IT界程序员大佬普遍对性的追求

原标题&#xff1a;浅谈IT界程序员大佬普遍对性的追求业界程序员大佬跟普通程序员的差别&#xff0c;别的不说&#xff0c;对于完成一个需求来说&#xff0c;除了更少的 bug&#xff0c;还有什么优势&#xff1f;还有程序员对性的追求。下面谈谈最顶级的程序员对20个性的追求可…

乔治敦大学计算机专业排名,2020USNEWS数据科学与分析专业综合排名(上)

2020年USNEWS专业排名已经陆续放出了&#xff0c;今天慧德留学就带大家看一下2020年美国USNEWSS数据科学与分析专业的综合排名&#xff0c;供大家参考。独立项目综合排名 学校名称 专业名称 专业英文名 开设学位 所属科系1 哈佛大学 计算科学与工程 Computational Science and …

Javascript事件绑定this

在FF中的事件绑定是使用addEventListener&#xff0c;其中函数中的this就是被绑定事件的元素&#xff1b;而在IE下的attachEvent函数中的this是指window。 DRY&#xff1a;Don‘t Repeat Yourself&#xff1b; 对于自己声明的函数&#xff0c;如果参数是多个&#xff0c;并且可…

python xlutils教程_Python基于xlutils修改表格内容过程解析

一、xlutils是什么是一个提供了许多操作修改excel文件方法的库&#xff1b;属于python的第三方模块xlrd库用于读取excel文件中的数据&#xff0c;xlwt库用于将数据写入excel文件&#xff0c;修改用xlutils模块&#xff1b;xlutils库也仅仅是通过复制一个副本进行操作后保存一个…

html 显示不吃,20180902_html_第二次_张旺

Frequently Asked QuestionsIs it secure to send my companys information to COMIS?How can I enable SSL for my computer?1. Is it secure to send my companys information to COMIS?Your company information is protected by your unique user name and passwordwhic…

电脑键盘按钮功能注释大全

F1帮助 F2改名 F3搜索 F4地址 F5刷新 F6切换 F10菜单 CTRLA全选 CTRLC复制 CTRLX剪切 CTRLV粘贴 CTRLZ撤消 CTRLO打开 SHIFTDELETE永久删除 DELETE删除 ALTENTER属性 ALTF4关闭 CTRLF4关闭 ALTTAB切换 ALTESC切换 ALT空格键窗口菜单 CTRLESC开始菜单 拖动某一项时按CTRL复制所选…