【Pytorch】Visualization of Feature Maps(4)——Saliency Maps

在这里插入图片描述

学习参考来自

  • Saliency Maps的原理与简单实现(使用Pytorch实现)
  • https://github.com/wmn7/ML_Practice/tree/master/2019_07_08/Saliency%20Maps

Saliency Maps 原理

《Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps》(arXiv-2013)

在这里插入图片描述

A saliency map tells us the degree to which each pixel in the image affects the classification score for that image.
To compute it, we compute the gradient of the unnormalized score corresponding to the correct class (which is a scalar)
with respect to the pixels of the image. If the image has shape (3, H, W) then this gradient will also have shape (3, H, W);
for each pixel in the image, this gradient tells us the amount by which the classification score will change if the pixel
changes by a small amount. To compute the saliency map, we take the absolute value of this gradient, then take the maximum value over the 3 input channels; the final saliency map thus has shape (H, W) and all entries are non-negative.

Saliency Maps相当于是计算图像的每一个pixel是如何影响一个分类器的, 或者说分类器对图像中每一个pixel哪些认为是重要的.

会计算图像每一个像素点的梯度。如果图像的形状是(3, H, W),这个梯度的形状也是(3, H, W);对于图像中的每个像素点,
这个梯度告诉我们当像素点发生轻微改变时,正确分类分数变化的幅度。

计算 saliency map 的时候,需要计算出梯度的绝对值,然后再取三个颜色通道的最大值;

因此最后的 saliency map的形状是(H, W)为一个通道的灰度图。


直接来代码,先载入些数据,用的是 cs231n 作业里面的 imagenet_val_25.npz,含有 imagenet 数据中验证集的 25 张图片

import torch
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import ImageSQUEEZENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
SQUEEZENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)def load_imagenet_val(num=None):"""Load a handful of validation images from ImageNet.Inputs:- num: Number of images to load (max of 25)Returns:- X: numpy array with shape [num, 224, 224, 3]- y: numpy array of integer image labels, shape [num]- class_names: dict mapping integer label to class name"""imagenet_fn = 'imagenet_val_25.npz'if not os.path.isfile(imagenet_fn):print('file %s not found' % imagenet_fn)print('Run the following:')print('cd cs231n/datasets')print('bash get_imagenet_val.sh')assert False, 'Need to download imagenet_val_25.npz'f = np.load(imagenet_fn, allow_pickle=True)X = f['X']  # (25, 224, 224, 3)y = f['y']  # (25, )class_names = f['label_map'].item()  # 999if num is not None:X = X[:num]y = y[:num]return X, y, class_names

图像的前处理,resize,变成向量,减均值除以方差

# 辅助函数
def preprocess(img, size=224):transform = T.Compose([T.Resize(size),T.ToTensor(),T.Normalize(mean=SQUEEZENET_MEAN.tolist(),std=SQUEEZENET_STD.tolist()),T.Lambda(lambda x: x[None]),])return transform(img)

在这里插入图片描述

数据集和实验的模型

链接:https://pan.baidu.com/s/1vb2Y0IiHdH_Fb9wibTta4Q?pwd=zuvw
提取码:zuvw


核心代码,计算 saliency maps

def compute_saliency_maps(X, y, model):"""X表示图片, y表示分类结果, model表示使用的分类模型Input : - X : Input images : Tensor of shape (N, 3, H, W)- y : Label for X : LongTensor of shape (N,)- model : A pretrained CNN that will be used to computer the saliency mapReturn :- saliency : A Tensor of shape (N, H, W) giving the saliency maps for the input images"""# 确保model是test模式model.eval()# 确保X是需要gradientX.requires_grad_() # 仅开启了输入图片的梯度saliency = Nonelogits = model.forward(X)  # torch.Size([5, 1000]), 前向获取 logitslogits = logits.gather(1, y.view(-1, 1)).squeeze()  # torch.Size([5]) 得到正确分类 logits (5张图片标签相应类别的 logits)logits.backward(torch.FloatTensor([1., 1., 1., 1., 1.]))  # 只计算正确分类部分的loss(正确类别梯度为 1 回传)saliency = abs(X.grad.data)  # 返回X的梯度绝对值大小, torch.Size([5, 3, 224, 224])saliency, _ = torch.max(saliency, dim=1)  # torch.Size([5, 224, 224]),取 rgb 3通道的最大值return saliency.squeeze()

显示 saliency maps

def show_saliency_maps(X, y):# Convert X and y from numpy arrays to Torch TensorsX_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0) # torch.Size([5, 3, 224, 224])y_tensor = torch.LongTensor(y)# Compute saliency maps for images in Xsaliency = compute_saliency_maps(X_tensor, y_tensor, model)# Convert the saliency map from Torch Tensor to numpy array and show images# and saliency maps together.saliency = saliency.numpy()N = X.shape[0]  # 5for i in range(N):plt.subplot(2, N, i + 1)plt.imshow(X[i])plt.axis('off')plt.title(class_names[y[i]])plt.subplot(2, N, N + i + 1)plt.imshow(saliency[i], cmap=plt.cm.hot)plt.axis('off')plt.gcf().set_size_inches(12, 5)plt.show()

下面开始调用,首先载入模型,使其梯度冻结,仅打开输入图片的梯度,这样反向传播的时候会更新图片,得到我们想要的 saliency maps

# Download and load the pretrained SqueezeNet model.
model = torchvision.models.squeezenet1_1(pretrained=True)# We don't want to train the model, so tell PyTorch not to compute gradients
# with respect to model parameters.
for param in model.parameters():param.requires_grad = False

加载一些图片看看,25 张中抽出来 5 张

X, y, class_names = load_imagenet_val(num=5)  # X: (5, 224, 224, 3) | y: (5,) | class_names: 999"show images"plt.figure(figsize=(12, 6))
for i in range(5):plt.subplot(1, 5, i + 1)plt.imshow(X[i])plt.title(class_names[y[i]])plt.axis('off')
plt.gcf().tight_layout()
plt.show()

显示图片
在这里插入图片描述
把五张图片的 saliency maps 画出来

show_saliency_maps(X, y)

我把 25 张都画出来了
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述


核心代码中涉及到了 gather 函数,下面来个简单的例子就明白了

# Example of using gather to select one entry from each row in PyTorch
# 用来返回matrix指定行某个位置的值
import torchdef gather_example():N, C = 4, 5s = torch.randn(N, C) # 随机生成 4 行 5 列的 tensory = torch.LongTensor([1, 2, 1, 3])print(s)print(y)print(torch.LongTensor(y).view(-1, 1))print(s.gather(1, y.view(-1, 1)).squeeze()) # 抽取每行相应的列数位置上的数值gather_example()"""
tensor([[ 0.8119,  0.2664, -1.4168, -0.1490, -0.0675],[ 0.5335,  0.6304, -0.7200, -0.0974, -0.9934],[-0.8305,  0.5189,  0.7359,  1.5875,  0.0505],[ 0.4335, -1.1389, -0.7771,  0.5779,  0.3515]])
tensor([1, 2, 1, 3])
tensor([[1],[2],[1],[3]])
tensor([ 0.2664, -0.7200,  0.5189,  0.5779])
"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/192799.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue权限管理解决方案

一. 什么是权限管理 权限控制是确保用户只能访问其被授权的资源和执行其被授权的操作的重要方面。而前端权限归根结底是请求的发起权,请求的发起可能有下面两种形式触发 页面加载触发页面上的按钮点击触发 总体而言,权限控制可以从前端路由和视图两个…

深度学习记录--广播(Broadcasting)

什么是广播? 广播(Broadcasting),在python中是一种矩阵初等运算的手段,用于将一个常数扩展成一个矩阵,使得运算可行 广播的作用 比如: 一个1*n的矩阵要和常数b相加,广播使得常数b扩展成一个1*n的矩阵 …

zemax之初级像差理论与像差校正——慧差

通过上节介绍,我们已经知道在轴上视场产生的球差是旋转对称的像差。在进行光学系统设计时,同时需要保证轴上物点和轴外物点的成像质量。轴外物点成像时会引入轴外像差,即轴外视场产生的慧差(coma aberration) 1.慧差概…

申请Azure学生订阅——人工验证

一:联系客服进行人工验证 点击 Services Hub 填写资料申请人工验证 点击 Azure - Sign up 进行学生验证 二:与客服的邮件沟通的记录 ​​​​一、结果(输入客服给的验证码后,笔者便得到了学生订阅): 二…

k8s中批量处理Pod应用的Job和CronJob控制器、处理守护型pod的DaemonSet控制器介绍

目录 一.Job控制器 1.简介 2.Jobs较完整解释 3.示例演示 4.注意:如上例的话,执行“kubectl delete -f myJob.yaml”就可以将job删掉 二.CronJob(简写为cj) 1.简介 2.CronJob较完整解释 3.案例演示 4.如上例的话&#xf…

[原创][2]探究C#多线程开发细节-“线程的无顺序性“

[简介] 常用网名: 猪头三 出生日期: 1981.XX.XX QQ: 643439947 个人网站: 80x86汇编小站 https://www.x86asm.org 编程生涯: 2001年~至今[共22年] 职业生涯: 20年 开发语言: C/C、80x86ASM、PHP、Perl、Objective-C、Object Pascal、C#、Python 开发工具: Visual Studio、Delph…

golang Pool实战与底层实现

使用的go版本为 go1.21.2 首先我们写一个简单的Pool的使用代码 package mainimport "sync"var bytePool sync.Pool{New: func() interface{} {b : make([]byte, 1024)return &b}, }func main() {for j : 0; j < 10; j {obj : bytePool.Get().(*[]byte) // …

Java基础-----Date类及其相关类(一)

文章目录 1. Date类1.1 简介1.2 构造方法1.3 主要方法 2. DateFormat 类2.1 简介2.2 实例化方式一&#xff1a;通过静态方法的调用2.2 实例化方式二&#xff1a;通过创建子类对象 3. Calendar类4. GregorianCalendar 1. Date类 1.1 简介 java.util.Date:表示指定的时间信息&a…

vivado实现分析与收敛技巧7-布局规划

关于布局规划 布局规划有助于设计满足时序要求。当设计难以始终如一满足时序要求或者从未满足时序要求时 &#xff0c; AMD 建议您执行布局规划。如果您与设计团队协作并且协作过程中一致性至关重要&#xff0c; 那么布局规划同样可以发挥作用。布局规划可通过减少平均布线延…

Redis-安装、配置和修改配置文件、以及在Ubuntu和CentOS上设置Redis服务的开机启动和防火墙设置,以及客户端连接。

目录 1. Redis简介 2. 离线安装 2.1 准备工作 2.2 解压、安装 2.3 修改配置文件 2.4 redis服务与关闭 2.5 redis服务的开机启动 2.5.1 Ubuntu上的配置 2.5.2 centos上的配置 3. 在线安装 4. 设置防火墙 5. 客户端连接 1. Redis简介 Redis 是完全开源免费的&#x…

鼠标点击效果.html(网上收集6)

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>鼠标点击</title> </head><body> <script>(function () {var a_idx 0;window.onclick function (event) {var a new Array(…

【Python从入门到进阶】43.验证码识别工具结合requests的使用

接上篇《42、使用requests的Cookie登录古诗文网站》 上一篇我们介绍了如何利用requests的Cookie登录古诗文网。本篇我们来学习如何使用验证码识别工具进行登录验证的自动识别。 一、图片验证码识别过程及手段 上一篇我们通过requests的session方法&#xff0c;带着原网页登录…

人工智能 - 人脸识别:发展历史、技术全解与实战

目录 一、人脸识别技术的发展历程早期探索&#xff1a;20世纪60至80年代技术价值点&#xff1a; 自动化与算法化&#xff1a;20世纪90年代技术价值点&#xff1a; 深度学习的革命&#xff1a;21世纪初至今技术价值点&#xff1a; 二、几何特征方法详解与实战几何特征方法的原理…

python安装与配置:在centos上使用shell脚本一键安装

介绍 Python是一种功能强大且广泛使用的编程语言&#xff0c;但在某些情况下&#xff0c;您可能需要安装和配置特定版本的Python。本教程将向您展示如何使用一个Shell脚本自动完成这个过程&#xff0c;以便您可以快速开始使用Python 3。 使用shell自动化安装教程 1. 复制脚本…

51单片机项目(19)——基于51单片机的传送带产品计数器

1.功能描述 应用背景: 某生产线的传送带上不断地有产品单向传送&#xff0c;传送时会通过光电传感器产生方波信号&#xff0c;将该信号(可以采用方波发生器来模拟该信号)直接传送给51单片机&#xff0c;利用计数器0计量产品(方波信号)的个数&#xff0c;利用.定时器1产…

Python海绵宝宝

目录 系列文章 写在前面 海绵宝宝 写在后面 系列文章 序号文章目录直达链接表白系列1浪漫520表白代码https://want595.blog.csdn.net/article/details/1306668812满屏表白代码https://want595.blog.csdn.net/article/details/1297945183跳动的爱心https://want595.blog.cs…

leetcode 209. 长度最小的子数组(优质解法)

代码&#xff1a; //时间复杂度 O(N) ,空间复杂度 O(1) class Solution {//采用滑动窗口的方法解决public int minSubArrayLen(int target, int[] nums) {int numsLengthnums.length;int minLengthInteger.MAX_VALUE;int left0;int right0;int sum0;while (right<numsLengt…

【SpringBoot】讲清楚日志文件lombok

文章目录 前言一、日志是什么&#xff1f;二、⽇志怎么⽤&#xff1f;三.自定义打印日志3.1在程序中得到日志对象3.2使用日志打印对象 四.⽇志级别4.1日志级别有什么用4.2 ⽇志级别的分类与使⽤ 五.日志持久化六.lombok6.1添加lobok依赖注意&#xff1a;使⽤ Slf4j 注解&#x…

Linux 多线程(C语言) 备查

基础 1&#xff09;线程在运行态和就绪态不停的切换。 2&#xff09;每个线程都有自己的栈区和寄存器 1&#xff09;进程是资源分配的最小单位&#xff0c;线程是操作系统调度执行的最小单位 2&#xff09;线程的上下文切换的速度比进程快得多 3&#xff09;从应用程序A中启用应…

Linux系列-1 Linux启动流程——init与systemd进程

背景&#xff1a; 最近对所有项目完成了一个切换&#xff0c;服务管理方式由&#xff1a; init-> systemd。对相关知识进行总结一下。 1.启动流程 服务器的整体启动流程如下图所示&#xff1a; POST&#xff1a; 计算机通电后进行POST( Power-On Self-Test )加电自检&am…