Pytorch register_forward_hook()

一、hook的意义:

在不改动网络结构的情况下获取网络中间层输出

没有使用hook的时候,想要得到conv2的输出,就要将在forward函数中经过conv2后的结果保存下来,然后和最终结果一起返回。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)
def forward(self, x):out = self.conv1(x)out = F.relu(out)     out = F.max_pool2d(out, 2)      out = self.conv2(out)out_conv2 = outout = F.relu(out)out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out, out_conv2

缺点: 

  1. 太麻烦 
  2. 但很多时候,我们并没有办法去直接修改网络的源代码,比如在pytorch中已经封装好的网络,那么这个时候就可以利用hook从外部获取Module的中间输出结果了。

所以可以通过使用hook的方式来得到model的中间结果,并不修改model的代码

二、使用方法:

1、定义hook函数

hook()函数是register_forward_hook()函数必须提供的参数,好处是“用户可以自行决定拦截了中间信息之后要做什么!”,比如自己想单纯的记录网络的输入输出(也可以进行修改等更加复杂的操作)。首先定义几个容器用于记录:

hook函数需要三个参数(这三个参数的名字你可以自己定义,但是必须是三个),这三个参数是系统传给hook函数的,自己不能修改这三个参数:

hook(module, input, output) -> None or modified output

# 1:定义module_name用于记录相应的module名字、定义用于获取网络各层输入输出tensor的容器
module_name = []
features_in_hook = []
features_out_hook = []
# 2:hook函数负责将相应的module名字、获取的输入输出 添加到feature列表中
def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None

2、在需要hook的网络层进行register

# load model
net = LeNet()# 确定取出哪一层的输出,“net.conv2”要和init函数中的self.conv2保持一致
# 在forward中第一次使用“conv2”时hook住,并将结果存储进hook函数
handle = net.conv2.register_forward_hook(hook)

3、走整个forward,然后得到hook的输入

# 将输入输入进model,让输出走过整个forward
x = torch.randn(2, 3, 32, 32)
y = net(x)# 得到hook的输出
print(module_name)
print(features_in_hook)
print(features_out_hook)

4、移除hook

# 将hook移除
handle.remove()

三、完整代码:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):out = self.conv1(x)out = F.relu(out)out = F.max_pool2d(out, 2)# 在这里hook住,因为这是第一次出现conv2的地方out = self.conv2(out)# hook结束后,得到结果,然后继续forwardout = F.relu(out)out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out# 1:定义用于获取网络各层输入输出tensor的容器
# 并定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []
# 2:hook函数负责将相应的module名字、获取的输入输出 添加到feature列表中
def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None# load model
net = LeNet()# 确定取出哪一层的输出,“net.conv2”要和init函数中的self.conv2保持一致
# 在forward中第一次使用“conv2”时hook住,并将结果存储进hook函数
handle = net.conv2.register_forward_hook(hook)# 将输入输入进model,让输出走过整个forward
x = torch.randn(2, 3, 32, 32)
y = net(x)# 得到hook的输出
print(module_name)
print(features_in_hook)
print(features_out_hook)# 将hook移除
handle.remove()

pytorch的hook机制之register_forward_hook - 知乎

Pytorch register_forward_hook()简单用法_pytorch forward hook-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/795358.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nginx 安装与实践

目录 一、安装 Nginx1、先安装 Brew2、再安装 Nginx 二、常用的 Nginx 命令三、简单的 Nginx 配置四、查看日志的 Linux 命令1、查看日志的 Linux 命令2、实时查看项目运行时打印的日志 一、安装 Nginx 推荐使用 HomeBrew 来安装 Nginx。 1、先安装 Brew 详见:Home…

.NET面试题20道

1、 switch的break作用: 如果加了break,则break的作用是在相应的位置跳出整个循环: 2、Sting和 StringBuilder的区别: String 对象是不可改变,空间不足时需要为该新对象分配新的空间&#xff0c…

Linux:进程等待究竟是什么?如何解决子进程僵尸所带来的内存泄漏问题?

Linux:进程等待究竟是什么?如何解决子进程僵尸所带来的内存泄漏问题? 一、进程等待的概念二、进程等待存在的意义三、如何进行进程等待3.1 wait()是实现进程等待1、wait()原型2. 验证wait()能回收僵尸子进程的空间 3.2 waitpid()实现进程等待…

Kali Linux汉化教程

以下是Kali Linux的汉化教程: 打开终端,并切换为root用户。可以使用命令“sudo su root”来切换到root用户。 更新源。使用命令“apt-get update”来更新系统的软件包列表。 安装中文字体。可以使用命令“apt install ttf-wqy-zenhei”来安装中文字体。…

560.和为K的子数组

560.和为K的子数组 给你一个整数数组 nums 和一个整数 k ,请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 示例 1: 输入:nums [1,1,1], k 2 输出:2示例 2: 输入&#xf…

Vue3:优化-从响应式数据中获取纯数据

一、情景说明 我们知道,Vue3中,创建变量时,常用ref、reactive来包裹,这样,这个变量就是响应式数据 然而,有时候,我们只需要纯数据 例如,我们在调用后端接口的时候,我们只…

Win10 下 git error unable to create file Invalid argument 踩坑实录

原始解决方案参看:https://stackoverflow.com/questions/26097568/git-pull-error-unable-to-create-file-invalid-argument 本问题解决于 2024-02-18,使用 git 版本 2.28.0.windows.1 解决方案 看 Git 抛出的出错的具体信息,比如如下都来自…

GPU的了解

3D动画揭秘显卡的GPU是如何工作的_哔哩哔哩_bilibili 位于显卡中。 与CPU区别: 100名小学生和1位数学博士 做100道非常简单的算术题,小朋友一个人一道题,比博士快。 做1道非常复杂的数学问题,只有博士可以做出来。 CPU主要用于快…

DeepFM。FM(Factorization Machine,因子分解机)。大规模稀疏矩阵。协同过滤方法。

目录 DeepFM。 FM(Factorization Machine,因子分解机)。 大规模稀疏矩阵中的特征组合问题。

c51 单片机如何控制小灯闪烁?

目录 硬件电路设计 软件编程 烧录程序 测试 调整和优化 C51单片机是一种经典的8位微控制器,广泛应用于各种嵌入式系统和智能控制项目中。 C51单片机控制小灯闪烁主要涉及到硬件电路设计和软件编程两个方面。下面是一个基本的步骤说明: 硬件电路设计…

零基础10 天入门 Web3之第2天

10 天入门 Web3之第2天Web3 是互联网的下一代,它将使人们拥有自己的数据并控制自己的在线体验。Web3 基于区块链技术,该技术为安全、透明和可信的交易提供支持。我准备做一个 10 天的学习计划,可帮助大家入门 Web3: 一、这是第二…

铸铁平台合理布局的重要性

铸铁平台合理布局的重要性是为了确保工作环境的安全和效率。以下是一些重要的原因: 安全性:合理布局可以最大限度地减少工作场所的事故和伤害。通过将设备和材料放置在正确的位置,可以降低工作人员被危险物体击中或跌倒的风险。此外&#xff…

【瑞萨RA6M3】1. 基于 vscode 搭建开发环境

基于 vscode 搭建开发环境 1. 准备2. 安装2.1. 安装瑞萨软件包2.2. 安装编译器2.3. 安装 cmake2.4. 安装 openocd2.5. 安装 ninja2.6. 安装 make 3. 生成初始代码4. 修改 cmake 脚本5. 调试准备6. 仿真 1. 准备 需要瑞萨仓库中的两个软件: MDK_Device_Packs.zipse…

Android 代码自定义drawble文件实现View圆角背景

简介 相信大多数Android开发都会遇到一个场景,给TextView或Button添加背景颜色,修改圆角,描边等需求。一看到这样的实现效果,自然就是创建drawble文件,设置相关属性shap,color,radius等。然后将…

基于单片机电流变送器系统仿真设计

**单片机设计介绍,基于单片机电流变送器系统仿真设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机电流变送器系统的仿真设计,主要目标是利用仿真技术,模拟单片机与电流变送器之间…

二分答案 蓝桥杯 2022 省A 青蛙过河

有些地方需要解释: 1.从学校到家和从家到学校,跳跃都是一样的,直接看作2*x次过河就可以。 2.对于一个跳跃能力 y,青蛙能跳过河 2x 次,当且仅当对于每个长度为 y 的区间,这个区间内 h 的和都大于等于…

hololens 2 投屏 报错

使用Microsoft HoloLens投屏时,ip地址填对了,但是仍然报错,说hololens 2没有打开, 首先检查 开发人员选项 都打开,设备门户也打开 然后检查系统–体验共享,把共享都打开就可以了

计算机网络—HTTP协议:深入解析与应用实践

​ 🎬慕斯主页:修仙—别有洞天 ♈️今日夜电波:ヒステリックナイトガール 1:03━━━━━━️💟──────── 5:06 🔄 ◀️ ⏸ ▶️ ☰…

java算法day45 | 动态规划part07 ● 70. 爬楼梯 (进阶) ● 322. 零钱兑换 ● 279.完全平方数

70. 爬楼梯 &#xff08;进阶&#xff09; 题目描述&#xff1a; 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬至多m (1 < m < n)个台阶。你有多少种不同的方法可以爬到楼顶呢&#xff1f; 注意&#xff1a;给定 n 是一个正整数。 输入描述&#xff1a;输入…

java对象是怎么在jvm中new出来的

java对象是怎么在jvm中new出来的 查看java对象字段属性在内存中的值 java 对象 创建 流程 附上java源码 public class MiDept {private int innerFiled999;public MiDept() {System.out.println("new MiDept--------------");}public String show(int data) {Sy…