基于pytorch使用特征图输出进行特征图可视化

使用特征图输出进行特征图可视化

文章目录

  • 前言
  • 效果展示
  • 获取某一层特征图输出
      • 原图
      • 方法一:使用IntermediateLayerGetter类
      • 方法二:使用hook机制(推荐)
  • 总结


前言

提示:这里可以添加本文要记录的大概内容:

例如:随着人工智能的不断发展,机器学习这门技术也越来越重要,很多人都开启了学习机器学习,本文就介绍了基于pytorch使用特征图输出进行特征图可视化的方法

特征图输出就是某个图像(序列)经过该层时的输出


以下是本篇文章正文内容

效果展示

在这里插入图片描述

获取某一层特征图输出

原图

在这里插入图片描述

方法一:使用IntermediateLayerGetter类

# 返回输出结果
import randomimport cv2
import torchvision
import torch
from matplotlib import pyplot as plt
import numpy as np
from torchvision import transforms
from torchvision import models# 定义函数,随机从0-end的一个序列中抽取size个不同的数
def random_num(size, end):range_ls = [i for i in range(end)]num_ls = []for i in range(size):num = random.choice(range_ls)range_ls.remove(num)num_ls.append(num)return num_lspath = "img_1.png"
transformss = transforms.Compose([transforms.ToTensor(),transforms.Resize((224, 224)),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 注意如果有中文路径需要先解码,最好不要用中文
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)# 转换维度
img = transformss(img).unsqueeze(0)model = models.resnet50(pretrained=True)
new_model = torchvision.models._utils.IntermediateLayerGetter(model, {'layer1': '1', 'layer2': '2', "layer3": "3"})
out = new_model(img)tensor_ls = [(k, v) for k, v in out.items()]# 这里选取layer2的输出画特征图
v = tensor_ls[1][1]# 选择目标卷积层
target_layer = model.layer2[2]
"""
如果要选layer3的输出特征图只需把第一个索引值改为2,即:
v=tensor_ls[2][1]
只需把第一个索引更换为需要输出的特征层对应的位置索引即可
"""
# 取消Tensor的梯度并转成三维tensor,否则无法绘图
v = v.data.squeeze(0)print(v.shape)  # torch.Size([512, 28, 28])# 随机选取25个通道的特征图
channel_num = random_num(25, v.shape[0])
plt.figure(figsize=(10, 10))
for index, channel in enumerate(channel_num):ax = plt.subplot(5, 5, index + 1, )plt.imshow(v[channel, :, :])
plt.savefig("./img/feature.jpg", dpi=300)

输出的结果如下:
在这里插入图片描述

方法二:使用hook机制(推荐)

如下代码所示:

# 返回输出结果
import randomimport cv2
import torchvision
import torch
from matplotlib import pyplot as plt
import numpy as np
from torchvision import transforms
from torchvision import models# 定义函数,随机从0-end的一个序列中抽取size个不同的数
def random_num(size, end):range_ls = [i for i in range(end)]num_ls = []for i in range(size):num = random.choice(range_ls)range_ls.remove(num)num_ls.append(num)return num_lspath = "img_1.png"
transformss = transforms.Compose([transforms.ToTensor(),transforms.Resize((224, 224)),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 注意如果有中文路径需要先解码,最好不要用中文
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)# 转换维度
img = transformss(img).unsqueeze(0)model = models.resnet50(pretrained=True)# 选择目标层
target_layer = model.layer2[2]
# 注册钩子函数,用于获取目标卷积层的输出
outputs = []
def hook(module, input, output):outputs.append(output)hook_handle = target_layer.register_forward_hook(hook)_ = model(img)v = outputs[-1]"""
如果要选layer3的输出特征图只需把第一个索引值改为2,即:
v=tensor_ls[2][1]
只需把第一个索引更换为需要输出的特征层对应的位置索引即可
"""
# 取消Tensor的梯度并转成三维tensor,否则无法绘图
v = v.data.squeeze(0)print(v.shape)  # torch.Size([512, 28, 28])# 随机选取25个通道的特征图
channel_num = random_num(25, v.shape[0])
plt.figure(figsize=(10, 10))
for index, channel in enumerate(channel_num):ax = plt.subplot(5, 5, index + 1, )plt.imshow(v[channel, :, :])
plt.savefig("./img/feature2.jpg", dpi=300)

总结

以上就是今天要讲的内容

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/137501.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【 云原生 | K8S 】kubectl 详解

目录 1 kubectl 2 基本信息查看 2.1 查看 master 节点状态 2.2 查看命名空间 2.3 查看default命名空间的所有资源 2.4 创建命名空间app 2.5 删除命名空间app 2.6 在命名空间kube-public 创建副本控制器(deployment)来启动Pod(nginx-wl…

大数据-之LibrA数据库系统告警处理(ALM-12036 license文件即将过期)

告警解释 系统每天零点检查一次当前系统中的license文件,如果当前时间距离过期时间不足60天,则license文件即将过期,产生该告警。 当重新导入一个正常license,告警恢复。 说明: 如果当前集群使用节点数小于等于10节…

linux环境安装SVN,以及常用的SVN操作

1、检查系统是否已经安装如果安装就卸载 检查: svnserve --version 卸载: yum remove subversion 2、安装 yum install subversion 3、建立SVN库(文件位置可自由) 创建仓库文件夹: mkdir -p /opt/svn/repositor…

RLHF的替代算法之DPO原理解析:从Zephyr的DPO到Claude的RAILF

前言 本文的成就是一个点顺着一个点而来的,成文过程颇有意思 首先,如上文所说,我司正在做三大LLM项目,其中一个是论文审稿GPT第二版,在模型选型的时候,关注到了Mistral 7B(其背后的公司Mistral AI号称欧洲…

049-第三代软件开发-软件部署脚本(一)

第三代软件开发-软件部署脚本(一) 文章目录 第三代软件开发-软件部署脚本(一)项目介绍软件部署脚本(一)其他方式 关键字: Qt、 Qml、 bash、 shell、 脚本 项目介绍 欢迎来到我们的 QML & C 项目!这个项目结合了 QML(Qt Meta-Object…

使用iperf3在macOS上进行网络性能测试

iperf3是一个用于测量网络性能的工具,它可以帮助你了解两台服务器之间的带宽和延迟。本博客将指导你在macOS上安装iperf3,并展示如何连接服务器进行网络性能测试。 步骤1:安装Homebrew 如果你尚未安装Homebrew,可以通过以下步骤…

nfs配置

1.NFS介绍 NFS就是Network File System的缩写,它最大的功能就是可以通过网络,让不同的机器、不同的操 作系统可以共享彼此的文件。 NFS服务器可以让PC将网络中的NFS服务器共享的目录挂载到本地端的文 件系统中,而在本地端的系统中来看&#…

P1908 逆序对 题解

文章目录 题目描述输入格式输出格式样例样例输入样例输出 数据范围与提示完整代码 题目描述 猫猫 TOM 和小老鼠 JERRY 最近又较量上了,但是毕竟都是成年人,他们已经不喜欢再玩那种你追我赶的游戏,现在他们喜欢玩统计。 最近,TOM…

【Git】Gui图形化管理、SSH协议私库集成IDEA使用

一、Gui图形化界面使用 1、根据自己需求打开管理器 2、克隆现有的库 3、图形化界面介绍 1、首先在本地仓库更新一个代码文件,进行使用: 2、进入图形管理界面刷新代码资源: 3、点击Stage changed 跟踪文件,将文件处于暂存区 4、通过…

后端Java日常实习生面试(2023年11月10日)

面试岗位为:Java 后端开发实习生 面试时长:30分钟 面试时间:2023年11月10日 首先介绍一下项目吧 这里介绍时有一个失误,没有主动把屏幕共享给打开,因为我在面试之前已经在 processon 上画好了项目的流程图&#xf…

详解机器学习最优化算法

前言 对于几乎所有机器学习算法,无论是有监督学习、无监督学习,还是强化学习,最后一般都归结为求解最优化问题。因此,最优化方法在机器学习算法的推导与实现中占据中心地位。在这篇文章中,小编将对机器学习中所使用的…

算法之路(一)

🖊作者 : D. Star. 📘专栏 :算法小能手 😆今日分享 : 如何学习? 在学习的过程中,不仅要知道如何学习,还要知道避免学习的陷阱。1. 睡眠不足;2. 被动学习和重读;3. 强调标记或画线&am…

解析邮件文本内容; Mime文本解析; MimeStreamParser; multipart解析

原始文本 ------_Part_46705_715015081.1699589700255 Content-Type: text/html;charsetUTF-8 Content-Transfer-Encoding: base64PGh0bWwCiAgICA8aGVhZD4KICAgICAgICA8bWV0YSBodHRwLW VxdWl2PSJDb250ZW50LVR5cGUiIGNvbnRlbnQ9InRleHQvaHRt bDsgY2hhcnNldD1VVEYtOCICiAgICAgIC…

使用Ruby编写通用爬虫程序

目录 一、引言 二、环境准备 三、爬虫程序设计 1. 抓取网页内容 2. 解析HTML内容 3. 提取特定信息 4. 数据存储 四、优化和扩展 五、结语 一、引言 网络爬虫是一种自动抓取互联网信息的程序。它们按照一定的规则和算法,遍历网页并提取所需的信息。使用Rub…

初识Linux:目录路径

目录 提示:以下指令均在Xshell 7 中进行 一、基本指令: 二、文件 文件内容文件属性 三、ls 指令拓展 1、 ls -l : 2、ls -la: 3、ls [目录名] : 4、ls -ld [目录名]: 四、Linux中的文件和…

串口通信(11)-CRC校验介绍算法

本文为博主 日月同辉,与我共生,csdn原创首发。希望看完后能对你有所帮助,不足之处请指正!一起交流学习,共同进步! > 发布人:日月同辉,与我共生_单片机-CSDN博客 > 欢迎你为独创博主日月同…

2023.11.10联赛 T3题解

题目大意 题目思路 感性理解一下,将一个数的平方变成多个数平方的和,为了使代价最小,这些数的大小应该尽可能的平均。 我们可以将 ∣ b i − a i ∣ |b_i-a_i| ∣bi​−ai​∣放入大根堆,同时将这个数划分的次数以及多划分一段减…

Xmake v2.8.5 发布,支持链接排序和单元测试

Xmake 是一个基于 Lua 的轻量级跨平台构建工具。 它非常的轻量,没有任何依赖,因为它内置了 Lua 运行时。 它使用 xmake.lua 维护项目构建,相比 makefile/CMakeLists.txt,配置语法更加简洁直观,对新手非常友好&#x…

tcpreplay命令后加上“--maxsleep=num“,num表示最大延迟时间(单位毫秒)

这个参数的含义是控制在发送每个数据包之间的最大延迟时间,单位是毫秒。它可以用来模拟真实网络中的一些延迟情况,比如网络拥塞、带宽限制等。 使用方法是在tcpreplay命令后加上"--maxsleepnum",num表示最大延迟时间,例…

java传base64返回给数据报404踩坑

一、问题复现 1.可能因为base64字符太长,导致后端处理时出错,表现为前端请求报400错误; 这一步debug进去发现base64数据是正常传值的 所以排除掉不是后端问题,但是看了下前端请求,猜测可能是转换base64时间太长数据过大导致的404 2.前端传…