- object.grad.zero_()的意思是清0object的梯度值。
下面做个实验。
x = torch.arange(4.0)
x.requires_grad_(True)
x.grad
# 注意此时为None,不为0
y = 2 * torch.dot(x, x)
y.backward()
x.grad
# tensor([ 0., 4., 8., 12.])
x.grad.zero_()
x.grad
# tensor([0., 0., 0., 0.])
- 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值,假如不清0会出现什么现象,看下面的实验。
x = torch.arange(5.0)
x.requires_grad_(True)
y = 2 * torch.dot(x, x)
y.backward()
x.grad
# Out[58]: tensor([ 0., 4., 8., 12., 16.])
z = 2 * torch.dot(x, x)
z.backward()
x.grad
# Out[61]: tensor([ 0., 8., 16., 24., 32.]),结果不对
-
那么上面这个错误结果是怎么来的呢?
PyTorch会累积梯度,tensor([ 0., 8., 16., 24., 32.]) = tensor([ 0., 4., 8., 12., 16.]) + tensor([ 0., 4., 8., 12., 16.])得到的结果; -
所以下面这段代码的意思是迭代param时不需要构建计算图,并且迭代完成后就把param.grad清0,因为再一次调用sgd时就是下一个batch得到的param.grad,batch和batch是没有关系的。
def sgd(params, lr, batch_size): #@save"""小批量随机梯度下降。"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()
参考资料
- https://zh-v2.d2l.ai/chapter_preliminaries/autograd.html;