一、hook的意义:
在不改动网络结构的情况下获取网络中间层输出。
没有使用hook的时候,想要得到conv2的输出,就要将在forward函数中经过conv2后的结果保存下来,然后和最终结果一起返回。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)
def forward(self, x):out = self.conv1(x)out = F.relu(out) out = F.max_pool2d(out, 2) out = self.conv2(out)out_conv2 = outout = F.relu(out)out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out, out_conv2
缺点:
- 太麻烦
- 但很多时候,我们并没有办法去直接修改网络的源代码,比如在pytorch中已经封装好的网络,那么这个时候就可以利用hook从外部获取Module的中间输出结果了。
所以可以通过使用hook的方式来得到model的中间结果,并不修改model的代码
二、使用方法:
1、定义hook函数
hook()函数是register_forward_hook()函数必须提供的参数,好处是“用户可以自行决定拦截了中间信息之后要做什么!”,比如自己想单纯的记录网络的输入输出(也可以进行修改等更加复杂的操作)。首先定义几个容器用于记录:
hook函数需要三个参数(这三个参数的名字你可以自己定义,但是必须是三个),这三个参数是系统传给hook函数的,自己不能修改这三个参数:
hook(module, input, output) -> None or modified output
# 1:定义module_name用于记录相应的module名字、定义用于获取网络各层输入输出tensor的容器
module_name = []
features_in_hook = []
features_out_hook = []
# 2:hook函数负责将相应的module名字、获取的输入输出 添加到feature列表中
def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None
2、在需要hook的网络层进行register
# load model
net = LeNet()# 确定取出哪一层的输出,“net.conv2”要和init函数中的self.conv2保持一致
# 在forward中第一次使用“conv2”时hook住,并将结果存储进hook函数
handle = net.conv2.register_forward_hook(hook)
3、走整个forward,然后得到hook的输入
# 将输入输入进model,让输出走过整个forward
x = torch.randn(2, 3, 32, 32)
y = net(x)# 得到hook的输出
print(module_name)
print(features_in_hook)
print(features_out_hook)
4、移除hook
# 将hook移除
handle.remove()
三、完整代码:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):out = self.conv1(x)out = F.relu(out)out = F.max_pool2d(out, 2)# 在这里hook住,因为这是第一次出现conv2的地方out = self.conv2(out)# hook结束后,得到结果,然后继续forwardout = F.relu(out)out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out# 1:定义用于获取网络各层输入输出tensor的容器
# 并定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []
# 2:hook函数负责将相应的module名字、获取的输入输出 添加到feature列表中
def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None# load model
net = LeNet()# 确定取出哪一层的输出,“net.conv2”要和init函数中的self.conv2保持一致
# 在forward中第一次使用“conv2”时hook住,并将结果存储进hook函数
handle = net.conv2.register_forward_hook(hook)# 将输入输入进model,让输出走过整个forward
x = torch.randn(2, 3, 32, 32)
y = net(x)# 得到hook的输出
print(module_name)
print(features_in_hook)
print(features_out_hook)# 将hook移除
handle.remove()
pytorch的hook机制之register_forward_hook - 知乎
Pytorch register_forward_hook()简单用法_pytorch forward hook-CSDN博客