Torch 模型 感受野可视化

前言:感受野是卷积神经网络 (CNN) 中一个重要的概念,它表示 CNN 每一层输出的特征图上的像素点在输入图像上映射的区域。感受野的大小和形状直接影响到网络对输入图像的感知范围和精度,进而调整网络结构、卷积核大小和步长等参数,以改善网络的性能。

效果:本文的实验在 torchvision.models 中的 resnet18 上进行,分别绘制了理论感受野、训练前感受野、训练后感受野

5db41ff89046413db29b9d2546c6e5b9.png

开发环境:PyTorch 1.9.0

适用模型:最大池化层使用 nn.MaxPool 而不是 torch.nn.functional.max_pool 的模型

声明:本文所使用代码不开源,觉得本文的思路可行的话,请加 QQ - 1398173074 购买 (¥40,注明来意)

商品仅包含一份 120+ 行的代码。本文所使用的代码基于 torch、matplotlib 以及其它标准库。其中包含一个名为 ReceptiveField 的类,用于绘制图像识别网络的感受野

代码实现

ReceptiveField 提供了以下函数:

  • _replace:将 MaxPool (这种求最大值的操作会影响感受野的正确性) 替换为 AvgPool
  • __init__:注册前向传播的“挂钩”,用于提取目标层的特征图用于反向传播
  • _backward:前向推导图像,利用“挂钩”获取特征图,从特征图中心点反向传播梯度,进行一系列处理后将梯度图转换为感受野图
  • theoretical:结合 _backward 函数求解理论感受野,其结果经过 sum、sqrt 之后即为理论感受野的尺寸
  • effective:默认情况下结合 _backward 函数求解训练前感受野 (即随机权重的模型);给定 state_dict 时将加载权重,求解训练后的感受野
  • compare:使用 matplotlib 绘制理论感受野、训练前感受野、训练后感受野
class ReceptiveField:""" :param model: 需要进行可视化的模型:param tar_layer: 感兴趣的层, 其所输出特征图需有 4 个维度 [B, C, H, W]:param img_size: 测试时使用的图像尺寸"""def make_input(self, n_sample): ...def __init__(self,model: nn.Module,tar_layer: Union[int, nn.Module],img_size: Union[int, Tuple[int, int]],use_cuda: bool = False,use_copy: bool = False): ...def compare(self, theoretical=True, original=True, state_dict=None, **imshow_kw):""" :param theoretical: 是否绘制理论感受野:param original: 是否绘制训练前的感受野:param state_dict: 模型权值, 如果提供则绘制训练后的感受野"""def effective(self, state_dict=None):""" :param state_dict: 模型权值, 如果提供则绘制训练后的感受野"""def theoretical(self, light=1.):""" :param light: 理论感受野的亮度 [0, 1]"""def _replace(self, model): ...def _backward(self, x): ...

在本文的示例中,对 resnet18 的 layer3 进行了可视化,并计算出理论感受野的尺寸为 211×211

if __name__ == "__main__":from torchvision.models import resnet18# Step 1: 刚完成初始化的模型, 权重<完全随机>, 表 "训练前"m = resnet18()# Step 2: 训练完成后的 state_dict, 等待 ReceptiveField 加载state_dict = resnet18(pretrained=True).state_dict()# Step 3: 绘制感受野 (设置 ReceptiveField 的 use_copy=True, 将创建模型的深拷贝副本)with ReceptiveField(m, tar_layer=m.layer3, img_size=256, use_copy=True) as r:r.compare(state_dict=state_dict)# 理论感受野的尺寸s = round(r.theoretical().sum() ** 0.5)print(f"Theoretical RF: {s}×{s}")plt.show()# Step 4: 加载模型的参数m.load_state_dict(state_dict)

如果将 resnet18 中的某一个卷积改成空洞卷积,感受野将进一步增大到 243×243

if __name__ == "__main__":from torchvision.models import resnet18# Step 1: 刚完成初始化的模型, 权重<完全随机>, 表 "训练前"m = resnet18()print(m)m.layer3[1].conv1.dilation = 2m.layer3[1].conv1.padding = 2# Step 2: 训练完成后的 state_dict, 等待 ReceptiveField 加载state_dict = resnet18(pretrained=True).state_dict()# Step 3: 绘制感受野 (设置 ReceptiveField 的 use_copy=True, 将创建模型的深拷贝副本)with ReceptiveField(m, tar_layer=m.layer3, img_size=256, use_copy=True) as r:r.compare(state_dict=state_dict)# 理论感受野的尺寸s = round(r.theoretical().sum() ** 0.5)print(f"Theoretical RF: {s}×{s}")plt.show()# Step 4: 加载模型的参数m.load_state_dict(state_dict)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2165.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

javaweb-maven

前端HTML,CSS,JS,Vue&#xff0c;Element&#xff0c;Nginx最后去复习&#xff0c; Java开发工程师 主要学习方向是服务端 所以进入javaweb的服务端的第一个知识点 maven 什么是maven 用于管理和构建java项目的工具 maven的官方网站 Maven – Welcome to Apache Maven …

Flink面试(1)

1.Flink 的并行度的怎么设置的&#xff1f; Flink设置并行度的几种方式 1.代码中设置setParallelism() 全局设置&#xff1a; 1 env.setParallelism(3);  算子设置&#xff08;部分设置&#xff09;&#xff1a; 1 sum(1).setParallelism(3) 2.客户端CLI设置&#xff0…

邀请全球创作者参与 The Sandbox 创作者训练营

作为首屈一指的元宇宙平台之一&#xff0c;The Sandbox 的使命是成为全球创作者的中心。随着我们对 Game Maker 的不断改进、旨在激发创作者灵感的定期 Game Jams、革命性的 "创作者挑战 "以及众多其他活动的开展&#xff0c;我们见证了大量个人加入我们充满活力的创…

opencv_5_图像像素的算术操作

方法1&#xff1a;调用库函数 void ColorInvert::mat_operator(Mat& image) { Mat dst; Mat m Mat::zeros(image.size(), image.type()); m Scalar(2, 2, 2); multiply(image, m, dst); m1 Scalar(50,50, 50); //divide(image, m, dst); //add(im…

WordPress social-warfare插件XSS和RCE漏洞【CVE-2019-9978】

WordPress social-warfare插件XSS和RCE漏洞 ~~ 漏洞编号 : CVE-2019-9978 影响版本 : WordPress social-warfare < 3.5.3 漏洞描述 : WordPress是一套使用PHP语言开发的博客平台&#xff0c;该平台支持在PHP和MySQL的服务器上架设个人博客网站。social-warfare plugin是使用…

AIGC元年大模型发展现状手册

零、AIGC大模型概览 AIGC大模型在人工智能领域取得了重大突破&#xff0c;涵盖了LLM大模型、多模态大模型、图像生成大模型以及视频生成大模型等四种类型。这些模型不仅拓宽了人工智能的应用范围&#xff0c;也提升了其处理复杂任务的能力。a.) LLM大模型通过深度学习和自然语…

MSR是个什么寄存器

MSR 这种寄存器专门用于调试、程序执行跟踪、计算机性能监控、简化软件编程、电源控制等等各种实验性功能。 什么是 MSR MSR 的概念是不易理解&#xff0c;所以这一节只说一些 MSR 的外在&#xff0c;比如形容和指令等&#xff0c;然后展开说说&#xff0c;看完整篇文章你应该…

计算机视觉 CV 八股分享 [自用](更新中......)

目录 一、深度学习中解决过拟合方法 二、深度学习中解决欠拟合方法 三、梯度消失和梯度爆炸 解决梯度消失的方法 解决梯度爆炸的方法 四、神经网络权重初始化方法 五、梯度下降法 六、BatchNorm 七、归一化方法 八、卷积 九、池化 十、激活函数 十一、预训练 十二…

【uniapp】 合成海报组件

之前公司的同事写过一个微信小程序用的 合成海报的组件 非常十分好用 最近的项目是uni的 把组件改造一下也可以用 记录一下 <template><view><canvas type"2d" class"_mycanvas" id"my-canvas" canvas-id"my-canvas" …

RT-Thread电源管理组件

电源管理组件 嵌入式系统低功耗管理的目的在于满足用户对性能需求的前提下&#xff0c;尽可能降低系统能耗以延长设备待机时间。 高性能与有限的电池能量在嵌入式系统中矛盾最为突出&#xff0c;硬件低功耗设计与软件低功耗管理的联合应用成为解决矛盾的有效手段。 现在的各种…

排序算法之桶排序

目录 一、简介二、代码实现三、应用场景 一、简介 算法平均时间复杂度最好时间复杂度最坏时间复杂度空间复杂度排序方式稳定性桶排序O(nk )O(nk)O(n^2)O(nk)Out-place稳定 稳定&#xff1a;如果A原本在B前面&#xff0c;而AB&#xff0c;排序之后A仍然在B的前面&#xff1b; 不…

Kotlin语法快速入门--条件控制和循环语句(2)

Kotlin语法入门–条件控制和循环语句&#xff08;2&#xff09; 文章目录 Kotlin语法入门--条件控制和循环语句&#xff08;2&#xff09;二、条件控制和循环语句1、if...else2、when2.1、常规用法2.2、特殊用法--并列&#xff1a;2.3、特殊用法--类型判断&#xff1a;2.4、特殊…

C语言进阶课程学习记录-第48课 - 函数设计原则

C语言进阶课程学习记录 - 函数设计原则 本文学习自狄泰软件学院 唐佐林老师的 C语言进阶课程&#xff0c;图片全部来源于课程PPT&#xff0c;仅用于个人学习记录

无人驾驶 自动驾驶汽车 环境感知 精准定位 决策与规划 控制与执行 高精地图与车联网V2X 深度神经网络学习 深度强化学习 Apollo

无人驾驶 百度apollo课程 1-5 百度apollo课程 6-8 七月在线 无人驾驶系列知识入门到提高 当今,自动驾驶技术已经成为整个汽车产业的最新发展方向。应用自动驾驶技术可以全面提升汽车驾驶的安全性、舒适性,满足更高层次的市场需求等。自动驾驶技术得益于人工智能技术的应用…

端口被占用的解决方案汇总

端口被占用的解决方案汇总 【一】windows系统端口被占用【二】Linux系统端口被占用【三】Linux的ps命令查找&#xff08;1&#xff09;ps命令常用的方式有三种&#xff08;2&#xff09;ps -ef |grep 8080 【一】windows系统端口被占用 &#xff08;1&#xff09;键盘上按住Wi…

【LeetCode刷题记录】21. 合并两个有序链表

21 合并两个有序链表 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1&#xff1a; 输入&#xff1a;l1 [1,2,4], l2 [1,3,4] 输出&#xff1a;[1,1,2,3,4,4] 示例 2&#xff1a; 输入&#xff1a;l1 [], l2 …

# MySQL中的并发控制,读写锁,和锁的粒度

MySQL中的并发控制&#xff0c;读写锁&#xff0c;和锁的粒度 并发控制的概述 在数据库系统中&#xff0c;并发控制是一种用于确保当多个用户同时访问数据库时&#xff0c;系统能够提供数据的一致性和隔离性的机制。MySQL支持多种并发控制技术&#xff0c;其中包括锁机制、多…

调试 WebSocket API 技巧分享

WebSocket 是一种在单个 TCP 连接上实现全双工通信的先进 API 技术。与传统的 HTTP 请求相比&#xff0c;WebSocket 提供了更低的延迟和更高的通信效率&#xff0c;使其成为在线游戏、实时聊天等应用的理想选择。 开始使用 Apifox 的 WebSocket 功能 首先&#xff0c;在项目界…

node和go的列表转树形, 执行速度测试对比

保证数据一致性&#xff0c;先生成4000条json数据到本地&#xff0c;然后分别读取文本执行处理 node代码 node是用midway框架 forNum1:number 0forNum2:number 0//执行测试async index(){// 生成菜单列表// const menuList await this.generateMenuList([], 4000);const men…

双周总结#008 - AIGC

本周参与了公司同事对 AIGC 的分享会&#xff0c;分享了 AIGC 在实际项目中的实践经验&#xff0c;以及如何进行 AIGC 的落地。内容分几项内容&#xff1a; 什么是 AIGCAIGC 能做什么AIGC 工具 以年终总结为例&#xff0c;分享了哪些过程应用了 AIGC&#xff0c;以及 AIGC 落地…