关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题
Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 feature、gradient,从而诊断神经网络中可能出现的问题,分析网络有效性。
Hook函数机制:不改变主体,实现额外的功能,像一个挂件一样;
Hook函数本身不是本文介绍的重点,网上介绍的文章颇多,本文主要是记录一下笔者在使用hook函数时遇到的一些问题及解决过程。
register_forward_hook
首先看一下一个最简单的使用register_forward_hook的例子:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):out = F.relu(self.conv1(x)) #1 out = F.max_pool2d(out, 2) #2out = F.relu(self.conv2(out)) #3out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return outfeatures = []
def hook(module, input, output): # module: model.conv2 # input :in forward function [#2]# output:is [#3 self.conv2(out)]print('*'*100)features.append(output.clone().detach())# output is saved in a list net = LeNet() ## 模型实例化
x = torch.randn(2, 3, 32, 32) ## input
handle = net.conv2.register_forward_hook(hook) ## 获取整个Lenet模型 conv2的中间结果
y = net(x) ## 获取的是 关于 input x 的 conv2 结果 print(features[0].size()) # 即 [#3 self.conv2(out)]
handle.remove() ## hook删除 ,防止多次保存hook内容占用空间
输出
****************************************************************************************************
torch.Size([2, 16, 10, 10])
形状是我们想要的结果,打印一串*是为了直观地验证hook函数被调用了。
其中conv2的名称,我们可以打印模型的state_dict()来查看自己要的是哪个module
for k in model.state_dict():print(k)
输出:
conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias
我们上面直接拿conv2做例子了。
出现的问题
在实际使用中,我想打印最近的transformer模型alt_gvt_large的位置编码来看一下,但是遇到了问题。
我查看了一下模型中的module,找到自己想要的
import torch
import timm
import numpy as np
import cv2
import seaborn as sns
import gvt
from PIL import Image
from torchvision import transformsfmap_block = []
def forward_hook(module, data_input, data_output):print('*'*100)fmap_block.append(data_output.clone().detach())model = timm.create_model('alt_gvt_large',pretrained=False,num_classes=1000,drop_rate=0.1,drop_path_rate=0.1,drop_block_rate=None,)
pipeline = transforms.Compose([transforms.RandomCrop(224),transforms.ToTensor(),])for k in model.state_dict():print(k)
输出:
# ...
patch_embeds.3.norm.weight
patch_embeds.3.norm.bias
norm.weight
norm.bias
head.weight
head.bias
pos_block.0.proj.0.weight
pos_block.0.proj.0.bias
pos_block.1.proj.0.weight
pos_block.1.proj.0.bias
pos_block.2.proj.0.weight
pos_block.2.proj.0.bias
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
blocks.0.0.norm1.weight
blocks.0.0.norm1.bias
# ...
那肯定就是pos_block喽。
开始hook:
image = Image.open('125.jpg')
image = pipeline(image).unsqueeze(dim=0)handle = model.pos_block.register_forward_hook(forward_hook)pred = model(image)
print(fmap_block[0].shape)
handle.remove()
出大问题,根本没有输出,连我们设置来验证hook函数运行的*也没有出现,hook函数肯定没有被执行,这是怎么回事呢?
解决过程
经过仔细比对以上两次成功和失败hook经历:
conv2.bias
conv2.weight
--------
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
简单分析不难有如此猜测:只有下面直接能点( . )到weight和bias的module才能被直接hook。
但是直接将输出结果粘贴过去会出现:
handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)
直接报语法错误,数字肯定是不能直接点的。
handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)^
SyntaxError: invalid syntax
于是笔者一层一层查看进去:
for k in model.pos_block:print(k)for _k in k.proj.state_dict():print(_k)breakbreak
print(type(model.pos_block))
发现上面出现数字的地方的类型其实是:<class ‘torch.nn.modules.container.ModuleList’>,也就是一个list,那是不是直接可以用[ ]进行索引。
于是我们可以改为:
handle = model.pos_block[3].proj[0].register_forward_hook(forward_hook)
输出:
****************************************************************************************************
torch.Size([1, 256, 28, 28])
终于成功。
总结
还是对PyTorch中的Model,Module,childeren_module等理解的不到位啊,只会最基本的使用方法,稍微进阶一点的操作就会遇到阻力,以后有时间梳理一下。PyTorch是当今公认比较好用的开源框架了,但是想要随心所欲地实现自己的想法,还是需要花点时间把其中的各个组件及相互之间的关系都理解到位。