更新fielddata为true_在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新...

 在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新

2020/4/11 FesianXu

前言

在现在的深度模型软件框架中,如TensorFlow和PyTorch等等,都是实现了自动求导机制的。在深度学习中,有时候我们需要对某些模块的梯度流进行精确地控制,包括是否允许某个模块的参数更新,更新地幅度多少,是否每个模块更新地幅度都是一样的。这些问题非常常见,但是在实践中却很容易出错,我们在这篇文章中尝试对第一个子问题,也就是如果精确控制某些模型是否允许其参数更新,进行总结。如有谬误,请联系指出,转载请注明出处。

本文实验平台:pytorch 1.4.0, ubuntu 18.04, python 3.6

联系方式:

e-mail: FesianXu@gmail.com

QQ: 973926198

github: https://github.com/FesianXu

知乎专栏: 计算机视觉/计算机图形理论与应用

微信公众号213b6956a16f1b5ba973fe5fc7a5ed87.png


为什么我们要控制梯度流

为什么我们要控制梯度流?这个答案有很多个,但是都可以归结为避免不需要更新的模型模块被参数更新。我们在深度模型训练过程中,很可能存在多个loss,比如GAN对抗生成网络,存在G_lossD_loss,通常来说,我们通过D_loss只希望更新判别器(Discriminator),而生成网络(Generator)并不需要,也不能被更新;生成网络只在通过G_loss学习的情况下,才能被更新。这个时候,如果我们不控制梯度流,那么我们在训练D_loss的时候,我们的前端网络GeneratorCNN难免也会被一起训练,这个是我们不期望发生的。

fd3d8e6ca566b66c0342816b5ad181f4.png

Fig 1.1 典型的GAN结构,由生成器和判别器组成。

多个loss的协调只是其中一种情况,还有一种情况是:我们在进行模型迁移的过程中,经常采用某些已经预训练好了的特征提取网络,比如VGG, ResNet之类的,在适用到具体的业务数据集时候,特别是小数据集的时候,我们可能会希望这些前端的特征提取器不要更新,而只是更新末端的分类器(因为数据集很小的情况下,如果贸然更新特征提取器,很可能出现不期望的严重过拟合,这个时候的合适做法应该是更新分类器优先),这个时候我们也可以考虑停止特征提取器的梯度流。

这些情况还有很多,我们在实践中发现,精确控制某些模块的梯度流是非常重要的。笔者在本文中打算讨论的是对某些模块的梯度流的截断,而并没有讨论对某些模块梯度流的比例缩放,或者说最细粒度的梯度流控制,后者我们将会在后文中讨论。

一般来说,截断梯度流可以有几种思路:

  1. 停止计算某个模块的梯度,在优化过程中这个模块还是会被考虑更新,然而因为梯度已经被截断了,因此不能被更新。

  • 设置tensor.detach():完全截断之前的梯度流
  • 设置参数的requires_grad属性:单纯不计算当前设置参数的梯度,不影响梯度流
  • torch.no_grad():效果类似于设置参数的requires_grad属性

在优化器中设置不更新某个模块的参数,这个模块的参数在优化过程中就不会得到更新,然而这个模块的梯度在反向传播时仍然可能被计算。

我们后面分别按照这两大类思路进行讨论。


停止计算某个模块的梯度

在本大类方法中,主要涉及到了tensor.detach()requires_grad的设置,这两种都无非是对某些模块,某些节点变量设置了是否需要梯度的选项。

tensor.detach()

tensor.detach()的作用是:

tensor.detach()会创建一个与原来张量共享内存空间的一个新的张量,不同的是,这个新的张量将不会有梯度流流过,这个新的张量就像是从原先的计算图中脱离(detach)出来一样,对这个新的张量进行的任何操作都不会影响到原先的计算图了。因此对此新的张量进行的梯度流也不会流过原先的计算图,从而起到了截断的目的。

这样说可能不够清楚,我们举个例子。众所周知,我们的pytorch是动态计算图网络,正是因为计算图的存在,才能实现自动求导机制。考虑一个表达式:

如果用计算图表示则如Fig 2.1所示。

990edb40371d3f791edff624440ca580.png

Fig 2.1 计算图示例

考虑在这个式子的基础上,加上一个分支:

那么计算图就变成了:

5429d2dabe34d2d164a390db355169aa.png

Fig 2.2 添加了新的分支后的计算图

如果我们不detach() 中间的变量z,分别对pqw进行反向传播梯度,我们会有:

x = torch.tensor(([1.0]),requires_grad=True)
y = x**2
z = 2*y
w= z**3

# This is the subpath
# Do not use detach()
p = z
q = torch.tensor(([2.0]), requires_grad=True)
pq = p*q
pq.backward(retain_graph=True)

w.backward()
print(x.grad)

输出结果为 tensor([56.])。我们发现,这个结果是吧pqw的反向传播结果都进行了考虑的,也就是新增加的分支的反向传播影响了原先主要枝干的梯度流。这个时候我们用detach()可以把p给从原先计算图中脱离出来,使得其不会干扰原先的计算图的梯度流,如:

e69242d72bc04f37ae6f21c5970babcb.png

Fig 2.3 用了detach之后的计算图

那么,代码就对应地修改为:

x = torch.tensor(([1.0]),requires_grad=True)
y = x**2
z = 2*y
w= z**3

# detach it, so the gradient w.r.t `p` does not effect `z`!
p = z.detach()
q = torch.tensor(([2.0]), requires_grad=True)
pq = p*q
pq.backward(retain_graph=True)
w.backward()
print(x.grad)

这个时候,因为分支的梯度流已经影响不到原先的计算图梯度流了,因此输出为tensor([48.])

这只是个计算图的简单例子,在实际模块中,我们同样可以这样用,举个GAN的例子,代码如:

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_B
        # fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_AB = self.real_B # GroundTruth
        # real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = self.fake_B
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()


    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # 先调用 forward, 再 D backward, 更新D之后; 再G backward, 再更新G
    def optimize_parameters(self):
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

我们注意看第六行,self.pred_fake = self.netD.forward(fake_AB.detach())使得在反向传播D_loss的时候不会更新到self.netG,因为fake_AB是由self.netG生成的,代码如self.fake_B = self.netG.forward(self.real_A)

设置requires_grad

tensor.detach()是截断梯度流的一个好办法,但是在设置了detach()的张量之前的所有模块,梯度流都不能回流了(不包括这个张量本身,这个张量已经脱离原先的计算图了),如以下代码所示:

x = torch.randn(2, 2)
x.requires_grad = True

lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)
lin3 = nn.Linear(2, 2)
x1 = lin0(x)
x2 = lin1(x1)
x2 = x2.detach() # 此处设置了detach,之前的所有梯度流都不会回传了
x3 = lin2(x2)
x4 = lin3(x3)
x4.sum().backward()
print(lin0.weight.grad)
print(lin1.weight.grad)
print(lin2.weight.grad)
print(lin3.weight.grad)

输出为:

None
None
tensor([[-0.7784, -0.7018],
        [-0.4261, -0.3842]])
tensor([[ 0.5509, -0.0386],
        [ 0.5509, -0.0386]])

我们发现lin0.weight.gradlin0.weight.grad都为None了,因为通过脱离中间张量,原先计算图已经和当前回传的梯度流脱离关系了。

这样有时候不够理想,因为我们可能存在只需要某些中间模块不计算梯度,但是梯度仍然需要回传的情况,在这种情况下,如下图所示,我们可能只需要不计算B_net的梯度,但是我们又希望计算A_netC_net的梯度,这个时候怎么办呢?当然,通过detach()这个方法是不能用了。

a436ca4d79b8f1db33dee53e722aeeb6.png

事实上,我们可以通过设置张量的requires_grad属性来设置某个张量是否计算梯度,而这个不会影响梯度回传,只会影响当前的张量。修改上面的代码,我们有:

x = torch.randn(2, 2)
x.requires_grad = True

lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)
lin3 = nn.Linear(2, 2)
x1 = lin0(x)
x2 = lin1(x1)
for p in lin2.parameters():
    p.requires_grad = False
x3 = lin2(x2)
x4 = lin3(x3)
x4.sum().backward()
print(lin0.weight.grad)
print(lin1.weight.grad)
print(lin2.weight.grad)
print(lin3.weight.grad)

输出为:

tensor([[-0.0117,  0.9976],
        [-0.0080,  0.6855]])
tensor([[-0.0075, -0.0521],
        [-0.0391, -0.2708]])
None
tensor([[0.0523, 0.5429],
        [0.0523, 0.5429]])

啊哈,正是我们想要的结果,只有设置了requires_grad=False的模块没有计算梯度,但是梯度流又能够回传。

另外,设置requires_grad经常用在对输入变量和输入的标签进行新建的时候使用,如:

for mat,label in dataloader:
    mat = Variable(mat, requires_grad=False)
    label = Variable(mat,requires_grad=False)
    ...

当然,通过把所有前端网络都设置requires_grad=False,我们可以实现类似于detach()的效果,也就是把该节点之前的所有梯度流回传截断。以VGG16为例子,如果我们只需要训练其分类器,而固定住其特征提取器网络的参数,我们可以采用将前端网络的所有参数的requires_grad设置为False,因为这个时候完全不需要梯度流的回传,只需要前向计算即可。代码如:

model = torchvision.models.vgg16(pretrained=True)
for param in model.features.parameters():
    param.requires_grad = False

torch.no_grad()

在对训练好的模型进行评估测试时,我们同样不需要训练,自然也不需要梯度流信息了。我们可以把所有参数的requires_grad属性设置为False,事实上,我们常用torch.no_grad()上下文管理器达到这个目的。即便输入的张量属性是requires_grad=True,   torch.no_grad()可以将所有的中间计算结果的该属性临时转变为False

如例子所示:

x = torch.randn(3, requires_grad=True)
x1 = (x**2)
print(x.requires_grad)
print(x1.requires_grad)

with torch.no_grad():
    x2 = (x**2)
    print(x1.requires_grad)
    print(x2.requires_grad)

输出为:

True
True
True
False

注意到只是在torch.no_grad()上下文管理器范围内计算的中间变量的属性requires_grad才会被转变为False,在该管理器外面计算的并不会变化。

不过和单纯手动设置requires_grad=False不同的是,在设置了torch.no_grad()之前的层是不能回传梯度的,延续之前的例子如:

x = torch.randn(2, 2)
x.requires_grad = True

lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)
lin3 = nn.Linear(2, 2)
x1 = lin0(x)
with torch.no_grad():
    x2 = lin1(x1)
x3 = lin2(x2)
x4 = lin3(x3)
x4.sum().backward()
print(lin0.weight.grad)
print(lin1.weight.grad)
print(lin2.weight.grad)
print(lin3.weight.grad)

输出为:

None
None
tensor([[-0.0926, -0.0945],
        [-0.2793, -0.2851]])
tensor([[-0.5216,  0.8088],
        [-0.5216,  0.8088]])

此处如果我们打印lin1.weight.requires_grad我们会发现其为True,但是其中间变量x2.requires_grad=False

一般来说在实践中,我们的torch.no_grad()通常会在测试模型的时候使用,而不会选择在选择性训练某些模块时使用[1],例子如:

model.train()
# here train the model, just skip the codes
model.eval() # here we start to evaluate the model
with torch.no_grad():
 for each in eval_data:
  data, label = each
  logit = model(data)
  ... # here we just skip the codes

注意

通过设置属性requires_grad=False的方法(包括torch.no_grad())很多时候可以避免保存中间计算的buffer,从而减少对内存的需求,但是这个也是视情况而定的,比如如[2]的所示

graph LR;
input-->A_net;
A_net-->B_net;
B_net-->C_net;

如果我们不需要A_net的梯度,我们设置所有A_netrequires_grad=False,因为后续的B_netC_net的梯度流并不依赖于A_net,因此不计算A_net的梯度流意味着不需要保存这个中间计算结果,因此减少了内存。

但是如果我们不需要的是B_net的梯度,而需要A_netC_net的梯度,那么问题就不一样了,因为A_net梯度依赖于B_net的梯度,就算不计算B_net的梯度,也需要保存回传过程中B_net中间计算的结果,因此内存并不会被减少。

但是通过tensor.detach()的方法并不会减少内存使用,这一点需要注意。


设置优化器的更新列表

这个方法更为直接,即便某个模块进行了梯度计算,我只需要在优化器中指定不更新该模块的参数,那么这个模块就和没有计算梯度有着同样的效果了。如以下代码所示:

class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model_1 = nn.linear(10,10)
        self.model_2 = nn.linear(10,20)
        self.fc = nn.linear(20,2)
        self.relu = nn.ReLU()
       
    def foward(inputv):
        h = self.model_1(inputv)
        h = self.relu(h)
        h = self.model_2(inputv)
        h = self.relu(h)
        return self.fc(h)

在设置优化器时,我们只需要更新fc层model_2层,那么则是:

curr_model = model()
opt_list = list(curr_model.fc.parameters())+list(curr_model.model_2.parameters())
optimizer = torch.optim.SGD(opt_list, lr=1e-4)

当然你也可以通过以下的方法去设置每一个层的学习率来避免不需要更新的层的更新[3]:

optim.SGD([
                {'params': model.model_1.parameters()},
                {'params': model.mode_2.parameters(), 'lr': 0},
         {'params': model.fc.parameters(), 'lr': 0}
            ], lr=1e-2, momentum=0.9)

这种方法不需要更改模型本身结构,也不需要添加模型的额外节点,但是需要保存梯度的中间变量,并且将会计算不需要计算的模块的梯度(即便最后优化的时候不考虑更新),这样浪费了内存和计算时间。

Reference

[1]. https://blog.csdn.net/LoseInVain/article/details/82916163

[2]. https://discuss.pytorch.org/t/requires-grad-false-does-not-save-memory/21936

[3]. https://pytorch.org/docs/stable/optim.html#module-torch.optim

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/564500.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

php 自动切图,前端工程师技能之photoshop巧用系列扩展篇自动切图

目录[1]初始设置 [2]自动切图前面的话随着photoshop版本的不断升级,软件本身增加了很多新的功能,也为切图工作增加了很多的便利。photoshop最新的版本新增了自动切图功能,本文将详细介绍photoshop的这个新功能初始设置当然首先还是要进行一些…

Spring通知类型及使用ProxyFactoryBean创建AOP代理

Spring 通知类型 通知(Advice)其实就是对目标切入点进行增强的内容,Spring AOP 为通知(Advice)提供了 org.aopalliance.aop.Advice 接口。 Spring 通知按照在目标类方法的连接点位置,可以分为以下五种类型…

matlab二维谐振子,基于有限差分法求解的二维谐振子的MATLAB程序如下。哪位大神能帮我做个注明啊,完全看不懂啊,,急...

基于有限差分法求解的二维谐振子的MATLAB程序如下。哪位大神能帮我做个注明啊,完全看不懂啊,,急0____丿呆呆丶2017.04.15浏览20次分享举报tic clc clear L20; W20; N20; M20; hxL/(2*N); hyW/(2*M); Szeros((2*M-1)*(2*N-1)); for m1:2*M-1 D…

typescript get方法_.NET手撸绘制TypeScript类图——上篇

.NET手撸绘制TypeScript类图——上篇近年来随着交互界面的精细化,TypeScript越来越流行,前端的设计也越来复杂,而类图正是用简单的箭头和方块,反映对象与对象之间关系/依赖的好方式。许多工具都能生成C#类图,有些工具也…

Spring使用AspectJ开发AOP

AspectJ 是一个基于 Java 语言的 AOP 框架,它扩展了 Java 语言。Spring 2.0 以后,新增了对 AspectJ 方式的支持,新版本的 Spring 框架,建议使用 AspectJ 方式开发 AOP。 使用 AspectJ 开发 AOP 通常有两种方式: 基于 …

matlab特征方程的根,MATLAB 求解特征方程的根轨迹图稳定性分析

原文:http://tecdat.cn/?p3871根轨迹分析在下文中,我们提供了用于根轨迹分析的强大MATLAB命令的简要描述。读者可能想知道为什么当强大的MATLAB命令可用时,教师强调学习手工计算。对于给定的一组开环极点和零点,MATLAB立即绘制根…

python如何把一张图像的所有像素点的值都显示出来_情人节,教你用 Python 向女神表白...

点击上方 “AirPython”,选择 “加为星标”第一时间关注 Python 技术干货!2020年,这个看起来如此浪漫的年份,你还是一个人吗?难不成我还能是一条狗?提醒你一下,后天就是 2月14日了。什么&#x…

Spring事务管理接口

Spring 的事务管理是基于 AOP 实现的,而 AOP 是以方法为单位的。Spring 的事务属性分别为传播行为、隔离级别、只读和超时属性,这些属性提供了事务应用的方法和描述策略。 在 Java EE 开发经常采用的分层模式中,Spring 的事务处理位于业务逻…

倒计时小工具_送你3个倒数计日的小程序,让你不再遗忘重要事

每天我们忙于工作,忙于生活,在很多重要事情,重要人的生日,以及重要有意义的日子总会在忙碌中被遗忘,那么这该怎么办呢?别紧张,小编为你带来3个倒数计日的小程序,让你不再遗忘重要事情…

Spring声明式事务管理

Spring 的事务管理有两种方式:一种是传统的编程式事务管理,即通过编写代码实现的事务管理;另一种是基于 AOP 技术实现的声明式事务管理。由于在实际开发中,编程式事务管理很少使用,所以我们只对 Spring 的声明式事务管…

无法检查指定的位置是否位于cfs上_打印机知识普及:七大原因导致的打印机无法打印及解决方法...

打印机无法打印的原因有很多,如果我们遇到打印机无法打印应该首先从简单到复杂入手。首先必须排除一些最简单的问题,比如打印机是否正常安装。另外打印机内部是不是已经放置有墨盒以及打印纸等,这些基本问题必须排除,另外还有一个…

Spring基于Annotation实现事务管理

在 Spring 中,除了使用基于 XML 的方式可以实现声明式事务管理以外,还可以通过 Annotation 注解的方式实现声明式事务管理。 使用 Annotation 的方式非常简单,只需要在项目中做两件事,具体如下。 1 在 Spring 容器中注册驱动&…

python扩展,用python扩展列

我试图在python中扩展数据帧的某一列。在r中,我将使用这个函数:在python中,我发现df.pivot_table(),但我刚发现一个错误:pandas.pivot_table(df, values Value, index[Day, Money, Product],columns[Account])^更新结果:数据帧没有更改。它只返回相同的数据文件而不进行扩展。考…

基于matlab的pcm系统仿真_深入理解基于RISC-V ISS Spike的仿真系统:探索Spike,pk和fesrv...

Spike, the RISC-V ISA Simulator, implements a functional model of one or more RISC-V processors.Spike is named after the golden spike used to celebrate the completion of the US transcontinental railway.一些同学初接触RISC-V,总逃脱不了被Rocketchip…

单机安装oracle,单机安装oracle系统

一、安装oracle6.9操作系统1、硬盘划分为30G(以下分区皆强制为主分区)/boot 200Mswap 4G/ 剩余全部空间安装下一步的步骤就不说了勾选图形选项基本系统--》备份客户端,大系统性能,存储工具,安全工具,性能工具,目录客户…

redis timeout设置多少合适_热水器怎么调温度?一般热水器温度设置多少度比较合适?...

对于热水器的温度设置您知道怎么操作吗?一般情况下电热水器设置多少度比较合适呢?今天蜜罐蚁装修网小编给大家介绍下热水器如何调节温度,以及热水器调节温度在什么范围比较合适。下面请看小编以美的电热水器某个型号产品举例图文讲解热水器调…

用swing设计一个打地鼠小游戏_这7个风靡欧美的英语小游戏,学会胜过刷100道题!...

精彩导读小编为大家搜罗了一些在国外家喻户晓的语言类小游戏。好的方法胜过刷上100道题,真正让孩子觉得好玩,教学才会事半功倍!01Would You Rather...最近牛津大学的面试考题惊天地爆出了一题:你想要变成吸血鬼还是想要变成僵尸&a…

oracle数据库9i安装,Oracle 9i数据库服务器的安装和辅助软件安装教程

安装数据库服务器以Oracle 9i数据库服务器软件的安装过程为例,介绍数据库服务器的安装过程。14.3.1 安装数据库服务器系统环境数据库服务器安装之前,一般都需要检测系统安装环境,以避免系统不支持、内存不够、硬盘空间不足等情况发生。下面从…

oracle 触发器 行级,oracle的行级触发器使用

行级触发器:当触发器被触发时,要使用被插入、更新或删除的记录中的列值,有时要使用操作前、后列的值.:NEW 修饰符访问操作完成后列的值:OLD 修饰符访问操作完成前列的值例1: 建立一个触发器, 当职工表 emp 表被删除一条记录时,把被…

司铭宇老师:如何让企业销售培训效果落地

如何让企业销售培训效果落地 在企业销售培训中,我们经常听到一个词,那就是“落地”。所谓的“落地”,简单来说就是将培训中所学到的知识和技能转化为实际的工作行动,从而提高销售业绩。但是,如何才能让销售培训效果真…