torch.autograd.Function
是 PyTorch 提供的一个接口,用于自定义前向传播和反向传播的操作。自定义操作需要继承 torch.autograd.Function 并重载 forward 和 backward 方法。
下面是一个简单的示例,展示如何自定义一个平方操作的前向传播和反向传播。
示例一:
import torch
from torch.autograd import Function
class SquareFunction(Function):@staticmethoddef forward(ctx, input):# ctx 是一个上下文对象,用于存储反向传播所需的信息ctx.save_for_backward(input)return input * input@staticmethoddef backward(ctx, grad_output):# 从上下文对象中取回前向传播保存的信息input, = ctx.saved_tensorsgrad_input = grad_output * 2 * inputreturn grad_input
# 输入张量
input = torch.tensor([2.0, 3.0], requires_grad=True)# 使用自定义的 SquareFunction
output = SquareFunction.apply(input)# 进行反向传播
output.backward(torch.tensor([1.0, 1.0]))# 打印梯度
print(input.grad) # 输出:tensor([4., 6.])
示例二:
import torchclass SignWithSigmoidGrad(torch.autograd.Function):@staticmethoddef forward(ctx, x):result = (x > 0).float()sigmoid_result = torch.sigmoid(x)ctx.save_for_backward(sigmoid_result)return result@staticmethoddef backward(ctx, grad_result):(sigmoid_result,) = ctx.saved_tensorsif ctx.needs_input_grad[0]:grad_input = grad_result * sigmoid_result * (1 - sigmoid_result)else:grad_input = Nonereturn grad_input
这段代码定义了一个自定义的 PyTorch autograd 函数 SignWithSigmoidGrad,这个函数在前向传播中计算输入张量 x 的符号函数(sign function),在反向传播中计算与 sigmoid 函数有关的梯度。
示例三:
import torch
from torch.autograd import Functionclass SquareFunction(Function):@staticmethoddef forward(ctx, input):# ctx 是一个上下文对象,用于存储反向传播所需的信息ctx.save_for_backward(input)return torch.sum(input)@staticmethoddef backward(ctx, grad_output):# 从上下文对象中取回前向传播保存的信息input, = ctx.saved_tensorsgrad_input = grad_output * 2 * inputreturn grad_input# 输入张量
input = torch.tensor([2.0, 3.0], requires_grad=True)# 使用自定义的 SquareFunction
output = SquareFunction.apply(input)# 进行反向传播
output.backward(torch.tensor(2.0))# 打印梯度
print(input.grad) # 输出:tensor([8., 12.])