pytorch forward_pytorch使用hook打印中间特征图、计算网络算力等

fc2acf1e55d9cc946bfefc2a72a63ee5.png

0、参考

https://oldpan.me/archives/pytorch-autograd-hook

https://pytorch.org/docs/stable/search.html?q=hook&check_keywords=yes&area=default

https://github.com/pytorch/pytorch/issues/598

https://github.com/sksq96/pytorch-summary

https://github.com/allensll/test/blob/591c7ce3671dbd9687b3e84e1628492f24116dd9/net_analysis/viz_lenet.py

1、背景

在神经网络的反向传播当中个,流程只保存叶子节点的梯度,对于中间变量的梯度没有进行保存。

import torch
x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
z.backward()
x.data -= lr*x.grad.data
print(y.grad)

此时输出就是:None,这个时候hook的作用就派上,hook可以通过自定义一些函数,从而完成中间变量的输出,比如中间特征图、中间层梯度修正等。

​ 在pytorch docs搜索hook,可以发现有四个hook相关的函数,分别为register_hook,register_backward_hook,register_forward_hook,register_forward_pre_hook。其中register_hook属于tensor类,而后面三个属于moudule类。

  • register_hook函数属于torch.tensor类,函数在tensor梯度计算的时候就会执行,这个函数主要处理梯度相关的数据,表现形式$hook(grad) rightarrow Tensor or None$.
import torch
x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
y = x * 2
y.register_hook(print)
<torch.utils.hooks.RemovableHandle at 0x7f765e876f60>
z = torch.mean(y)
z.backward()
tensor([ 0.5000,  0.5000])
  • Register_backward_hook等三个属于torch.nn,属于moudule中的方法。
hook(module, grad_input, grad_output) -> Tensor or None

写个demo,参考:

下面的计算为

import torch
import torch.nn as nn
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def print_hook(grad):print ("register hook:", grad)return gradclass TestNet(nn.Module):def __init__(self):super(TestNet, self).__init__()self.f1 = nn.Linear(4, 1, bias=True)self.weights_init()def weights_init(self):self.f1.weight.data.fill_(4)self.f1.bias.data.fill_(0.1)def forward(self, input):self.input = inputout = input * 0.75out = self.f1(out)out = out / 4return outdef back_hook(self, moudle, grad_input, grad_output):print ("back hook in:", grad_input)print ("back hook out:", grad_output)# 修改梯度# grad_input = list(grad_input)# grad_input[0] = grad_input[0] * 100# print (grad_input)return tuple(grad_input)if __name__ == '__main__':input = torch.tensor([1, 2, 3, 4], dtype=torch.float32, requires_grad=True).to(device)net = TestNet()net.to(device)net.register_backward_hook(net.back_hook)ret = net(input)print ("result", ret)ret.backward()print('input.grad:', input.grad)for param in net.parameters():print('{}:grad->{}'.format(param, param.grad))

输出:

result tensor([7.5250], grad_fn=<DivBackward0>)
back hook in: (tensor([0.2500]), None)
back hook out: (tensor([1.]),)
input.grad: tensor([0.7500, 0.7500, 0.7500, 0.7500])
Parameter containing:
tensor([[4., 4., 4., 4.]], requires_grad=True):grad->tensor([[0.1875, 0.3750, 0.5625, 0.7500]])
Parameter containing:
tensor([0.1000], requires_grad=True):grad->tensor([0.2500])

输出结果以及梯度都很明显,简单分析一下w权重的梯度,

另外,hook中有个bug,假设我们bug,假设我们注释掉out = out / 4这行,可以发现输出变成back hook in: (tensor([1.]), tensor([1.]))。这种情况就不符合上面我们的梯度计算公式,是因为这个时候:

则此时的偏导只是对

进行计算,所以都是1,1。这是pytorch的设计缺陷

c0f3c2eaea4270329b9560d7b13622ff.png
  • register_forward_hook跟Register_backward_hook差不多,就不过多复述。
  • register_forward_pre_hook,可以发现其输入只有hook(module, input) -> None
    其主要是针对推理时的hook.

2、应用

2.1 特征图打印

​ 直接利用pytorch已有的resnet18进行特征图打印,只打印卷积层的特征图,

import torch
from torchvision.models import resnet18
import torch.nn as nn
from torchvision import transformsimport matplotlib.pyplot as pltdef viz(module, input):x = input[0][0]#最多显示4张图min_num = np.minimum(4, x.size()[0])for i in range(min_num):plt.subplot(1, 4, i+1)plt.imshow(x[i].cpu())plt.show()import cv2
import numpy as np
def main():t = transforms.Compose([transforms.ToPILImage(),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = resnet18(pretrained=True).to(device)for name, m in model.named_modules():# if not isinstance(m, torch.nn.ModuleList) and #         not isinstance(m, torch.nn.Sequential) and #         type(m) in torch.nn.__dict__.values():# 这里只对卷积层的feature map进行显示if isinstance(m, torch.nn.Conv2d):m.register_forward_pre_hook(viz)img = cv2.imread('./cat.jpeg')img = t(img).unsqueeze(0).to(device)with torch.no_grad():model(img)if __name__ == '__main__':main()

直接放几张中间层的图

70f5f173e24d17147e5b8704cce87918.png
图1 第一层卷积层输入

f955bb70ffca1cf64b78dbd6e7ce59ae.png
图2 第四层卷积层的输入

2.2 模型大小,算力计算

同样的用法,可以直接参考pytorch-summary这个项目。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/246122.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Games101现代图形学入门Lecture 4: Transformation Cont知识点总结

视频链接&#xff1a;https://www.bilibili.com/video/BV1X7411F744?p4 课程主页链接&#xff1a;http://games-cn.org/intro-graphics/ 课件PPT链接&#xff1a;http://games-cn.org/graphics-intro-ppt-video/ 1. 3D变换 缩放和平移矩阵 旋转矩阵 欧拉角&#xff1a;rol…

Hash和红黑树以及其在C#中的应用

参考资料&#xff1a; .Net 中HashTable&#xff0c;HashMap 和 Dictionary<key,value> 和List<T>和DataTable的比较 - 王若伊_恩赐解脱 - 博客园 c#HashSet源码解析_fdyshlk的博客-CSDN博客_c# hashset 红黑树和哈希表的区别 - 安全技术 - 亿速云 一、基本概念…

networkx 标签_networkx绘制BA无标度网络

step1: 导入networkx复杂网络库、matplotlib.pyplot、pandasimport networkx as nximport matplotlib.pyplot as pltimport pandas as pdstep2: 绘制BA无标度网络Gnx.barabasi_albert_graph(1000,1) #generate BA networkposnx.spring_layout(G) #set layoutnodecolorG.degree(…

Unity URP中的多Pass Shader和Planer shadow

一 .Unity移动端软阴影技术总结&#xff1a; https://blog.csdn.net/jxw167/article/details/82422891 二. 平面阴影的原理 https://zhuanlan.zhihu.com/p/42781261 https://zhuanlan.zhihu.com/p/31504088 王者荣耀游戏使用的就是该方法&#xff0c;已经有上线产品验证过…

java连接mongodb_第78天: Python 操作 MongoDB 数据库介绍

MongoDB 是一款面向文档型的 NoSQL 数据库&#xff0c;是一个基于分布式文件存储的开源的非关系型数据库系统&#xff0c;其内容是以 K/V 形式存储&#xff0c;结构不固定&#xff0c;它的字段值可以包含其他文档、数组和文档数组等。其采用的 BSON(二进制 JSON )的数据结构&am…

URP中的2D Light光照在移动端不生效的问题

最近在尝试用URP推出的还在preview阶段的2D Render系统&#xff0c;发现2D光照在打成APK后失效&#xff0c;尝试了些方法后发现把2d光照用到的shader放进设置中的built in shader后可以解决问题&#xff1a;

大连开发区取暖费能微信支付吗_下半年教资报考人数增加,那到底能不能异地报考呢?...

想要每周获取两篇群文件快扫码进群吧~因为教师资格证认定的问题&#xff0c;最近教师资格证备考又被广大考生提上了日程&#xff0c;由于“先上岗&#xff0c;后考证”政策&#xff0c;小编预测下一年教师资格证考试的通过率肯定没有以前那么高了&#xff0c;不少人就想选择异地…

python3项目源代码下载_2019年最值得关注的34个Python开源项目——Let's go!

踏着人工智能、区块链的东风&#xff0c;近年来一路“横冲直撞”的 Python 在实现了从小众语言到主流的完美转身后&#xff0c;一头扎进了 2019&#xff0c;依旧没有透出丝毫停下来的架势&#xff0c;反倒有些越烧越热的味道。本文将为你介绍 2019 年最值得关注的 34 个 Python…

Unity 音频优化方案

参考资料&#xff1a; https://www.cnblogs.com/bearhb/p/11210136.html https://blog.csdn.net/chenfujun818/article/details/81710895 文件格式 mp3:失真小&#xff0c;适合音质要求高的文件&#xff0c;例如BGM wav:资源大&#xff0c;不推荐 ogg:压缩比高&#xff0c;适…

android home键后计时拉起app_使用React Native完成App软件

搭建开发环境安装react-native-cli&#xff1a;npm i -g react-native-cliAndroid SDK安装Android SDK并启动进行配置&#xff1a;配置环境变量export ANDROID_HOME~/Library/Android/sdk export PATH${PATH}:${ANDROID_HOME}/tools export PATH${PATH}:${ANDROID_HOME}/platfo…

Unity AssetBundle内存管理相关问题

AssetBundle机制相关资料收集 最近网友通过网站搜索Unity3D在手机及其他平台下占用内存太大. 这里写下关于Unity3D对于内存的管理与优化. Unity3D 里有两种动态加载机制&#xff1a;一个是Resources.Load&#xff0c;另外一个通过AssetBundle,其实两者区别不大。 Resources.L…

移动超级sim卡 无法下载卡_中国移动发布超级SIM卡:全变了

近日&#xff0c;中国移动正式公布了《中国移动超级SIM卡技术白皮书》&#xff0c;明确乐中国移动对于个人领域SIM卡的发展方向、架构设计、能力要求&#xff0c;旨在为行业规划设计SIM卡相关技术、产品和解决方案时提供参考和指引。据悉&#xff0c;中国移动的超级SIM卡增强了…

echart中拆线点的偏移_Qt中圆弧和扇形的绘制

在超声软件的开发中&#xff0c;超声成像模块需要绘制圆弧&#xff0c;例如绘制一个扇形的取样框&#xff0c;左右是一条直线&#xff0c;上下是一个圆弧&#xff0c;像这样。Qt中使用QPainter::drawArc绘制圆弧&#xff0c;使用QPainter::drawPie绘制扇形。圆弧和扇形的绘制接…

反向Z(Reversed-Z)的深度缓冲原理

参考文章&#xff1a;https://zhuanlan.zhihu.com/p/75517534 https://zjinc36.github.io/2020/03/10/2020-20200309-%E6%B7%B1%E5%85%A5%E7%90%86%E8%A7%A3%E6%B5%AE%E7%82%B9%E6%95%B0%E4%B8%8E%E6%B5%AE%E7%82%B9%E6%95%B0%E7%9A%84%E7%B2%BE%E5%BA%A6%E9%97%AE%E9%A2%98/ …

output怎么用_性能领先,即训即用,快速部署,飞桨首次揭秘服务器端推理库

允中 发自 凹非寺量子位 编辑 | 公众号 QbitAI假如问在深度学习实践中&#xff0c;最难的部分是什么&#xff1f;猜测80%的开发者都会说&#xff1a;“当然是调参啊。”为什么难呢&#xff1f;因为调参就像厨师根据食材找到了料理配方&#xff0c;药剂师根据药材找到了药方&…

GPU架构杂乱备忘——IMR、TBR、TBDR

原文&#xff1a;https://juejin.cn/post/6844904132864655367 GPU架构杂乱备忘——IMR、TBR、TBDR 之前觉得涉及到gpu架构相关的问题只需要知道个大概就好&#xff0c;毕竟在图形api的层面上应该把硬件的细节给隐蔽掉&#xff0c;gpu的架构千千万万&#xff0c;每家厂商每个…

requests下载大文件_11种方法教你用Python高效下载资源!

在本教程中&#xff0c;你将学习如何使用不同的Python模块从web下载文件。此外&#xff0c;你将下载常规文件、web页面、Amazon S3和其他资源。最后&#xff0c;你将学习如何克服可能遇到的各种挑战&#xff0c;例如下载重定向的文件、下载大型文件、完成一个多线程下载以及其他…

android自定义push通知_20个海外Web和App推送通知服务工具

在App和网站中使用推送通知有不同的原因&#xff0c;并且在提高流量和与客户互动方面有很多好处。推送通知是一种交互式可点击消息&#xff0c;可将访问者直接引导至你的网站。它们可以帮助你以指数方式增加流量和参与率。因此&#xff0c;营销人员&#xff0c;广告商&#xff…

linux 删除文件_Linux删除文件夹命令有哪些

今天要和大家分享的Linux常用命令是Linux删除文件夹命令,Linux删除文件夹很简单,常用的命令有rmdir和rm,以下分别介绍一下,大家根据情况选择使用即可。 Linux删除文件夹命令有哪些 ①Linux删除文件夹命令:rmdir rmdir命令使用场景: 当有空目录要删除时,可使用rmdir指令。…

url上接收到 el表达式 不渲染_一文摸透从输入URL到页面渲染的过程

一文摸透从输入URL到页面渲染的过程从输入URL到页面渲染需要Chrome浏览器的多个进程配合&#xff0c;所以我们先来谈谈现阶段Chrome浏览器的多进程架构。一、Chrome架构目前Chrome采用的是多进程的架构模式&#xff0c;可分为主要的五类进程&#xff0c;分别是&#xff1a;浏览…