关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 feature、gradient,从而诊断神经网络中可能出现的问题,分析网络有效性。

Hook函数机制:不改变主体,实现额外的功能,像一个挂件一样;

Hook函数本身不是本文介绍的重点,网上介绍的文章颇多,本文主要是记录一下笔者在使用hook函数时遇到的一些问题及解决过程。

register_forward_hook

首先看一下一个最简单的使用register_forward_hook的例子:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):out = F.relu(self.conv1(x))     #1 out = F.max_pool2d(out, 2)      #2out = F.relu(self.conv2(out))   #3out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return outfeatures = []
def hook(module, input, output): # module: model.conv2 # input :in forward function  [#2]# output:is  [#3 self.conv2(out)]print('*'*100)features.append(output.clone().detach())# output is saved  in a list net = LeNet() ## 模型实例化 
x = torch.randn(2, 3, 32, 32) ## input 
handle = net.conv2.register_forward_hook(hook) ## 获取整个Lenet模型 conv2的中间结果
y = net(x)  ## 获取的是 关于 input x 的 conv2 结果 print(features[0].size()) # 即 [#3 self.conv2(out)]
handle.remove() ## hook删除 ,防止多次保存hook内容占用空间

输出

****************************************************************************************************
torch.Size([2, 16, 10, 10])

形状是我们想要的结果,打印一串*是为了直观地验证hook函数被调用了。

其中conv2的名称,我们可以打印模型的state_dict()来查看自己要的是哪个module

for k in model.state_dict():print(k)

输出:

conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias

我们上面直接拿conv2做例子了。

出现的问题

在实际使用中,我想打印最近的transformer模型alt_gvt_large的位置编码来看一下,但是遇到了问题。

我查看了一下模型中的module,找到自己想要的

import torch
import timm
import numpy as np
import cv2
import seaborn as sns
import gvt
from PIL import Image
from torchvision import transformsfmap_block = []
def forward_hook(module, data_input, data_output):print('*'*100)fmap_block.append(data_output.clone().detach())model = timm.create_model('alt_gvt_large',pretrained=False,num_classes=1000,drop_rate=0.1,drop_path_rate=0.1,drop_block_rate=None,)
pipeline = transforms.Compose([transforms.RandomCrop(224),transforms.ToTensor(),])for k in model.state_dict():print(k)

输出:

# ...
patch_embeds.3.norm.weight
patch_embeds.3.norm.bias
norm.weight
norm.bias
head.weight
head.bias
pos_block.0.proj.0.weight
pos_block.0.proj.0.bias
pos_block.1.proj.0.weight
pos_block.1.proj.0.bias
pos_block.2.proj.0.weight
pos_block.2.proj.0.bias
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
blocks.0.0.norm1.weight
blocks.0.0.norm1.bias
# ...

那肯定就是pos_block喽。

开始hook:


image = Image.open('125.jpg')
image = pipeline(image).unsqueeze(dim=0)handle = model.pos_block.register_forward_hook(forward_hook)pred = model(image)
print(fmap_block[0].shape)
handle.remove()

出大问题,根本没有输出,连我们设置来验证hook函数运行的*也没有出现,hook函数肯定没有被执行,这是怎么回事呢?

解决过程

经过仔细比对以上两次成功和失败hook经历:

conv2.bias
conv2.weight
--------
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias

简单分析不难有如此猜测:只有下面直接能点( . )到weight和bias的module才能被直接hook。

但是直接将输出结果粘贴过去会出现:

handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)

直接报语法错误,数字肯定是不能直接点的。

handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)^
SyntaxError: invalid syntax

于是笔者一层一层查看进去:

for k in model.pos_block:print(k)for _k in k.proj.state_dict():print(_k)breakbreak 
print(type(model.pos_block))

发现上面出现数字的地方的类型其实是:<class ‘torch.nn.modules.container.ModuleList’>,也就是一个list,那是不是直接可以用[ ]进行索引。

于是我们可以改为:

handle = model.pos_block[3].proj[0].register_forward_hook(forward_hook)

输出:

****************************************************************************************************
torch.Size([1, 256, 28, 28])

终于成功。

总结

还是对PyTorch中的Model,Module,childeren_module等理解的不到位啊,只会最基本的使用方法,稍微进阶一点的操作就会遇到阻力,以后有时间梳理一下。PyTorch是当今公认比较好用的开源框架了,但是想要随心所欲地实现自己的想法,还是需要花点时间把其中的各个组件及相互之间的关系都理解到位。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532839.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

geoda权重矩阵导入matlab,空间计量经济学-分析解析.ppt

厦门大学 邓明 空间截面回归模型 地理加权回归模型 地理加权回归模型扩展了普通线性回归模型。在GWR模型中&#xff0c;特定区位的回归系数不再是利用全部信息获得的假定常数&#xff0c;而是利用邻近观测值的子样本数据信息进行局域(Local)回归估计而得&#xff0c;并随着空间…

树莓派摄像头基础配置及测试

树莓派摄像头基础配置 step 1 硬件连接 硬件连接&#xff0c;注意不要接反了&#xff0c;排线蓝色一段朝向网口的方向。&#xff08;笔者的设备是树莓派4B&#xff09; step 2 安装raspi-config 安装 raspi-config raspi-config在raspbian中是预装的&#xff0c;而在kali、…

matlab sobel锐化,sobel锐化 - yirui wu.ppt

sobel锐化 - yirui wu第六章 图像锐化 图像锐化的概念 图像锐化的目的是加强图像中景物的细节边缘和轮廓。 锐化的作用是使灰度反差增强。 因为边缘和轮廓都位于灰度突变的地方。所以锐化算法的实现是基于微分作用。 图像锐化方法 图像的景物细节特征&#xff1b; 一阶微分锐化…

使用百度云智能SDK和树莓派搭建简易的人脸识别系统 Python语言版

硬件 树莓派4B一个CSI摄像头一个 笔者使用的是树莓派4B和CSI摄像头&#xff0c;但是树莓派3和USB摄像头等相似设备均可。 百度云智能设置 Step 1 登录 百度云智能 网址https://cloud.baidu.com/ 首先登录百度账号&#xff0c;与百度云、百度贴吧等互通&#xff0c;可直接…

php 5.6 引用传递,升级到5.6.x后如何在php中修复引用传递

我最近将fom php 5.2升级到5.6,并且有一些代码我无法修复&#xff1a;//Finds users with the same ip- or email-addressfunction find_related_users($user_id) {global $pdo;//print_R($pdo);//Let SQL do the magic!$sth $pdo->prepare(CALL find_related_users(?));$…

RuntimeError: [enforce fail at inline_container.cc:145] . PytorchStreamReader failed reading zip arc

RuntimeError: [enforce fail at inline_container.cc:145] . PytorchStreamReader failed reading zip archive: failed finding central directory 原因分析 这个报错是出现在PyTorch在读入模型参数时&#xff1a; checkpoint torch.load(epoch_15.pth, map_locationcpu)…

xp搭建 php环境,windows xp 下 LAMP环境搭建

1. apache安装步骤如下图在浏览器中输入&#xff1a;localhost&#xff0c;出现下面页面说明已成功安装apache。2. mysql安装如下图显示在运行里面输入cmd &#xff0c;然后连接测试mysql &#xff0c;如图所示&#xff1a;3. php安装(1)将php压缩包解压到安装路径中的php目录…

C++中的虚函数(表)实现机制以及用C语言对其进行的模拟实现

C中的虚函数(表)实现机制以及用C语言对其进行的模拟实现 声明&#xff1a;本文非博主原创&#xff0c;转自https://blog.twofei.com/496/&#xff0c;博主读后受益良多&#xff0c;特地转载&#xff0c;一是希望好文能有更多人看到&#xff0c;二是为了日后自己查阅。 前言 …

php 前端模板 yii,php – Yii2高级模板:添加独立网页

我在backend / views / site下添加了help.php,并在SiteController.php下声明了一个能够识别链接的函数public function behaviors(){return [access > [class > AccessControl::className(),rules > [[actions > [login, error],allow > true,],[actions > […

C++中数组和指针的关系(区别)详解

C中数组和指针的关系&#xff08;区别&#xff09;详解 本文转自&#xff1a;http://c.biancheng.net/view/1472.html 博主在阅读后将文中几个知识点提出来放在前面&#xff1a; 没有方括号和下标的数组名称实际上代表数组的起始地址&#xff0c;这意味着数组名称实际上就是…

安装php独立环境,0507-php独立环境的安装与配置 Web程序 - 贪吃蛇学院-专业IT技术平台...

1.在一个纯英文目录下新建三个文件夹2.安装apache(选择好版本)过程中该填的按格式填好&#xff0c;其余的只更改安装目录即可如果报错1901是安装版本的问题。检查&#xff1a;安装完成后localhost打开为It works!添加到电脑属性环境变量&#xff1a;3.将php文件解压文档放到AMP…

linux中PATH变量-详细介绍

转自&#xff1a;https://blog.csdn.net/haozhepeng/article/details/100584451 转载者勘误 原文最后提到的 echo 命令对于环境变量的修改无影响。这是肯定的&#xff0c;echo 命令相当于只是一个打印的函数&#xff08;比如 Python 中的 print&#xff09;。这里要修改环境变…

php assert eval,代码执行函数之一句话木马

前言大家好&#xff0c;我是阿里斯&#xff0c;一名IT行业小白。非常抱歉&#xff0c;昨天的内容出现瑕疵比较多&#xff0c;今天重新整理后再次发出&#xff0c;修改并添加了细节&#xff0c;另增加了常见的命令执行函数如果哪里不足&#xff0c;还请各位表哥指出。eval和asse…

显卡、显卡驱动、CUDA、CUDA Toolkit、cuDNN 梳理

显卡、显卡驱动、CUDA、CUDA Toolkit、cuDNN 梳理 转自&#xff1a;https://www.cnblogs.com/marsggbo/p/11838823.html#nvccnvidia-smi GPU型号含义 显卡&#xff1a; 简单理解这个就是我们前面说的GPU&#xff0c;尤其指NVIDIA公司生产的GPU系列&#xff0c;因为后面介绍的…

php中msubstr,PHP学习:thinkphp中字符截取函数msubstr()用法分析

《PHP学习&#xff1a;thinkphp中字符截取函数msubstr()用法分析》要点&#xff1a;本文介绍了PHP学习&#xff1a;thinkphp中字符截取函数msubstr()用法分析&#xff0c;希望对您有用。如果有疑问&#xff0c;可以联系我们。本文实例讲述了thinkphp中字符截取函数msubstr()用法…

VS Code的Error: Running the contributed command: ‘_workbench.downloadResource‘ failed解决

VS Code的Error: Running the contributed command: _workbench.downloadResource failed解决 转自&#xff1a;https://blog.csdn.net/ibless/article/details/118610776 1 问题描述 此前&#xff0c;本人参考网上教程在VS Code中配置了“Remote SSH”插件&#xff08;比如这…

Oracle闪回报错,oracle 闪回区满了,ORA-19815

oracle 闪回区满了&#xff0c;查看日志报错&#xff1a;ORA-19815&#xff0c;命令行输入&#xff1a;sqlplus / as sysdbastartup mount //如果你的数据库出现了无法连接的情况时&#xff0c;可以加上这句select file_type, percent_space_used as used,percent_space_rec…

[2021-ICCV] MUSIQ Multi-scale Image Quality Transformer 论文简析

[2021-ICCV] MUSIQ: Multi-scale Image Quality Transformer 论文简析 论文&#xff1a;https://arxiv.org/abs/2108.05997 代码&#xff1a;https://github.com/google-research/google-research/tree/master/musiq 概述 当前SOTA的IQA&#xff08;图像质量评估&#xff0…

安装oracle不动了,windows2008安装ORACLE到2%不动的问题 | 信春哥,系统稳,闭眼上线不回滚!...

最近又有网友遇到在windows2008服务器上安装ORACLE软件时到2%就卡住不动的问题&#xff0c;下面是该网友的描述&#xff1a;oralce 11g r2 windows server 2008 R2安装到最后一步复制数据文件时卡到2% 不走了内存一直飙升求解决这个问题前段时间也有人遇到过&#xff0c;但是他…

手把手教你入门Git --- Git使用指南(Linux)

手把手教你入门Git — Git使用指南&#xff08;Linux&#xff09; 系统&#xff1a;ubuntu 18.04 LTS 本文所有git命令操作实验具有连续性&#xff0c;git小白完全可以从头到尾跟着本文所有给出的命令走一遍&#xff0c;就会对git有一个初步的了解&#xff0c;应当能做到会用并…