import torch
1. Requires_grad
但是,模型毕竟不是人,它的智力水平还不足够去自主辨识那些量的梯度需要计算,既然如此,就需要手动对其进行标记。
在PyTorch中,通用的数据结构tensor包含一个attributerequires_grad,它被用于说明当前量是否需要在计算中保留对应的梯度信息,以上文所述的线性回归为例,容易知道参数www为需要训练的对象,为了得到最合适的参数值,我们需要设置一个相关的损失函数,根据梯度回传的思路进行训练。
官方文档中的说明如下
If there’s a single input to an operation that requires gradient, its output will also require gradient.
只要某一个输入需要相关梯度值,则输出也需要保存相关梯度信息,这样就保证了这个输入的梯度回传。
而反之,若所有的输入都不需要保存梯度,那么输出的requires_grad会自动设置为False。既然没有了相关的梯度值,自然进行反向传播时会将这部分子图从计算中剔除。
Conversely, only if all inputs don’t require gradient, the output also won’t require it. Backward computation is never performed in the subgraphs, where all Tensors didn’t require gradients.
对于那些要求梯度的tensor,PyTorch会存储他们相关梯度信息和产生他们的操作,这产生额外内存消耗,为了优化内存使用,默认产生的tensor是不需要梯度的。
而我们在使用神经网络时,这些全连接层卷积层等结构的参数都是默认需要梯度的。
a = torch.tensor([1., 2., 3.])
print('a:', a.requires_grad)
b = torch.tensor([1., 4., 2.], requires_grad = True)
print('b:', b.requires_grad)
print('sum of a and b:', (a+b).requires_grad)
a: False
b: True
sum of a and b: True
2. Computation Graph
从PyTorch的设计原理上来说,在每次进行前向计算得到pred时,会产生一个用于梯度回传的计算图,这张图储存了进行back propagation需要的中间结果,当调用了.backward()后,会从内存中将这张图进行释放
这张计算图保存了计算的相关历史和提取计算所需的所有信息,以output作为root节点,以input和所有的参数为leaf节点,
we only retain the grad of the leaf node with requires_grad =True
在完成了前向计算的同时,PyTorch也获得了一张由计算梯度所需要的函数所组成的图
而从数据集中获得的input其requires_grad为False,故我们只会保存参数的梯度,进一步据此进行参数优化
在PyTorch中,multi-task任务一个标准的train from scratch流程为
for idx, data in enumerate(train_loader):
xs, ys = data
optmizer.zero_grad()
# 计算d(l1)/d(x)
pred1 = model1(xs) #生成graph1
loss = loss_fn1(pred1, ys)
loss.backward() #释放graph1
# 计算d(l2)/d(x)
pred2 = model2(xs)#生成graph2
loss2 = loss_fn2(pred2, ys)
loss.backward() #释放graph2
# 使用d(l1)/d(x)+d(l2)/d(x)进行优化
optmizer.step()
Computation Graph本质上是一个operation的图,所有的节点都是一个operation,而进行相应计算的参数则以叶节点的形式进行输入
借助torchviz库以下面的模型作为示例
import torch.nn.functional as F
import torch.nn as nn
class Conv_Classifier(nn.Module):
def __init__(self):
super(Conv_Classifier, self).__init__()
self.conv1 = nn.Conv2d(1, 5, 5)
self.pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(5, 16, 5)
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(256, 20)
self.fc2 = nn.Linear(20, 10)
def forward(self, x):
x = F.relu(self.pool1((self.conv1(x))))
x = F.relu(self.pool2((self.conv2(x))))
x = F.dropout2d(x, training=self.training)
x = x.view(-1, 256)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return x
Mnist_Classifier = Conv_Classifier()
from torchviz import make_dot
input_sample = torch.rand((1, 1, 28, 28))
make_dot(Mnist_Classifier(input_sample), params=dict(Mnist_Classifier.named_parameters()))
其对应的计算梯度所需的图(计算图)为
可以看到,所有的叶子节点对应的操作都被记录,以便之后的梯度回传。