【Pytorch】Visualization of Feature Maps(4)——Saliency Maps

在这里插入图片描述

学习参考来自

  • Saliency Maps的原理与简单实现(使用Pytorch实现)
  • https://github.com/wmn7/ML_Practice/tree/master/2019_07_08/Saliency%20Maps

Saliency Maps 原理

《Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps》(arXiv-2013)

在这里插入图片描述

A saliency map tells us the degree to which each pixel in the image affects the classification score for that image.
To compute it, we compute the gradient of the unnormalized score corresponding to the correct class (which is a scalar)
with respect to the pixels of the image. If the image has shape (3, H, W) then this gradient will also have shape (3, H, W);
for each pixel in the image, this gradient tells us the amount by which the classification score will change if the pixel
changes by a small amount. To compute the saliency map, we take the absolute value of this gradient, then take the maximum value over the 3 input channels; the final saliency map thus has shape (H, W) and all entries are non-negative.

Saliency Maps相当于是计算图像的每一个pixel是如何影响一个分类器的, 或者说分类器对图像中每一个pixel哪些认为是重要的.

会计算图像每一个像素点的梯度。如果图像的形状是(3, H, W),这个梯度的形状也是(3, H, W);对于图像中的每个像素点,
这个梯度告诉我们当像素点发生轻微改变时,正确分类分数变化的幅度。

计算 saliency map 的时候,需要计算出梯度的绝对值,然后再取三个颜色通道的最大值;

因此最后的 saliency map的形状是(H, W)为一个通道的灰度图。


直接来代码,先载入些数据,用的是 cs231n 作业里面的 imagenet_val_25.npz,含有 imagenet 数据中验证集的 25 张图片

import torch
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import ImageSQUEEZENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
SQUEEZENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)def load_imagenet_val(num=None):"""Load a handful of validation images from ImageNet.Inputs:- num: Number of images to load (max of 25)Returns:- X: numpy array with shape [num, 224, 224, 3]- y: numpy array of integer image labels, shape [num]- class_names: dict mapping integer label to class name"""imagenet_fn = 'imagenet_val_25.npz'if not os.path.isfile(imagenet_fn):print('file %s not found' % imagenet_fn)print('Run the following:')print('cd cs231n/datasets')print('bash get_imagenet_val.sh')assert False, 'Need to download imagenet_val_25.npz'f = np.load(imagenet_fn, allow_pickle=True)X = f['X']  # (25, 224, 224, 3)y = f['y']  # (25, )class_names = f['label_map'].item()  # 999if num is not None:X = X[:num]y = y[:num]return X, y, class_names

图像的前处理,resize,变成向量,减均值除以方差

# 辅助函数
def preprocess(img, size=224):transform = T.Compose([T.Resize(size),T.ToTensor(),T.Normalize(mean=SQUEEZENET_MEAN.tolist(),std=SQUEEZENET_STD.tolist()),T.Lambda(lambda x: x[None]),])return transform(img)

在这里插入图片描述

数据集和实验的模型

链接:https://pan.baidu.com/s/1vb2Y0IiHdH_Fb9wibTta4Q?pwd=zuvw
提取码:zuvw


核心代码,计算 saliency maps

def compute_saliency_maps(X, y, model):"""X表示图片, y表示分类结果, model表示使用的分类模型Input : - X : Input images : Tensor of shape (N, 3, H, W)- y : Label for X : LongTensor of shape (N,)- model : A pretrained CNN that will be used to computer the saliency mapReturn :- saliency : A Tensor of shape (N, H, W) giving the saliency maps for the input images"""# 确保model是test模式model.eval()# 确保X是需要gradientX.requires_grad_() # 仅开启了输入图片的梯度saliency = Nonelogits = model.forward(X)  # torch.Size([5, 1000]), 前向获取 logitslogits = logits.gather(1, y.view(-1, 1)).squeeze()  # torch.Size([5]) 得到正确分类 logits (5张图片标签相应类别的 logits)logits.backward(torch.FloatTensor([1., 1., 1., 1., 1.]))  # 只计算正确分类部分的loss(正确类别梯度为 1 回传)saliency = abs(X.grad.data)  # 返回X的梯度绝对值大小, torch.Size([5, 3, 224, 224])saliency, _ = torch.max(saliency, dim=1)  # torch.Size([5, 224, 224]),取 rgb 3通道的最大值return saliency.squeeze()

显示 saliency maps

def show_saliency_maps(X, y):# Convert X and y from numpy arrays to Torch TensorsX_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0) # torch.Size([5, 3, 224, 224])y_tensor = torch.LongTensor(y)# Compute saliency maps for images in Xsaliency = compute_saliency_maps(X_tensor, y_tensor, model)# Convert the saliency map from Torch Tensor to numpy array and show images# and saliency maps together.saliency = saliency.numpy()N = X.shape[0]  # 5for i in range(N):plt.subplot(2, N, i + 1)plt.imshow(X[i])plt.axis('off')plt.title(class_names[y[i]])plt.subplot(2, N, N + i + 1)plt.imshow(saliency[i], cmap=plt.cm.hot)plt.axis('off')plt.gcf().set_size_inches(12, 5)plt.show()

下面开始调用,首先载入模型,使其梯度冻结,仅打开输入图片的梯度,这样反向传播的时候会更新图片,得到我们想要的 saliency maps

# Download and load the pretrained SqueezeNet model.
model = torchvision.models.squeezenet1_1(pretrained=True)# We don't want to train the model, so tell PyTorch not to compute gradients
# with respect to model parameters.
for param in model.parameters():param.requires_grad = False

加载一些图片看看,25 张中抽出来 5 张

X, y, class_names = load_imagenet_val(num=5)  # X: (5, 224, 224, 3) | y: (5,) | class_names: 999"show images"plt.figure(figsize=(12, 6))
for i in range(5):plt.subplot(1, 5, i + 1)plt.imshow(X[i])plt.title(class_names[y[i]])plt.axis('off')
plt.gcf().tight_layout()
plt.show()

显示图片
在这里插入图片描述
把五张图片的 saliency maps 画出来

show_saliency_maps(X, y)

我把 25 张都画出来了
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述


核心代码中涉及到了 gather 函数,下面来个简单的例子就明白了

# Example of using gather to select one entry from each row in PyTorch
# 用来返回matrix指定行某个位置的值
import torchdef gather_example():N, C = 4, 5s = torch.randn(N, C) # 随机生成 4 行 5 列的 tensory = torch.LongTensor([1, 2, 1, 3])print(s)print(y)print(torch.LongTensor(y).view(-1, 1))print(s.gather(1, y.view(-1, 1)).squeeze()) # 抽取每行相应的列数位置上的数值gather_example()"""
tensor([[ 0.8119,  0.2664, -1.4168, -0.1490, -0.0675],[ 0.5335,  0.6304, -0.7200, -0.0974, -0.9934],[-0.8305,  0.5189,  0.7359,  1.5875,  0.0505],[ 0.4335, -1.1389, -0.7771,  0.5779,  0.3515]])
tensor([1, 2, 1, 3])
tensor([[1],[2],[1],[3]])
tensor([ 0.2664, -0.7200,  0.5189,  0.5779])
"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/192799.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue权限管理解决方案

一. 什么是权限管理 权限控制是确保用户只能访问其被授权的资源和执行其被授权的操作的重要方面。而前端权限归根结底是请求的发起权,请求的发起可能有下面两种形式触发 页面加载触发页面上的按钮点击触发 总体而言,权限控制可以从前端路由和视图两个…

javaScript函数总结

一、函数是什么? 函数,就是一个一系列JavaScript语句的集合,这是为了完成某一个会重复使用的特定功能。在需要该功能的时候,直接调用函数即可,而不必每次都编写一大堆重复的代码。并且在需要修改该功能的时候&#xf…

LeetCode841. Keys and Rooms

文章目录 一、题目二、题解 一、题目 There are n rooms labeled from 0 to n - 1 and all the rooms are locked except for room 0. Your goal is to visit all the rooms. However, you cannot enter a locked room without having its key. When you visit a room, you m…

深度学习记录--广播(Broadcasting)

什么是广播? 广播(Broadcasting),在python中是一种矩阵初等运算的手段,用于将一个常数扩展成一个矩阵,使得运算可行 广播的作用 比如: 一个1*n的矩阵要和常数b相加,广播使得常数b扩展成一个1*n的矩阵 …

zemax之初级像差理论与像差校正——慧差

通过上节介绍,我们已经知道在轴上视场产生的球差是旋转对称的像差。在进行光学系统设计时,同时需要保证轴上物点和轴外物点的成像质量。轴外物点成像时会引入轴外像差,即轴外视场产生的慧差(coma aberration) 1.慧差概…

申请Azure学生订阅——人工验证

一:联系客服进行人工验证 点击 Services Hub 填写资料申请人工验证 点击 Azure - Sign up 进行学生验证 二:与客服的邮件沟通的记录 ​​​​一、结果(输入客服给的验证码后,笔者便得到了学生订阅): 二…

强化学习Markov重要公式推导过程

Markov决策过程(Markov Decision Process,MDP) Markov过程是一种用于描述决策问题的数学框架,是强化学习的基础。MDP中,决策者面对一系列的状态和动作,每个状态下采取不同的动作会获得不同的奖励&#xff…

k8s中批量处理Pod应用的Job和CronJob控制器、处理守护型pod的DaemonSet控制器介绍

目录 一.Job控制器 1.简介 2.Jobs较完整解释 3.示例演示 4.注意:如上例的话,执行“kubectl delete -f myJob.yaml”就可以将job删掉 二.CronJob(简写为cj) 1.简介 2.CronJob较完整解释 3.案例演示 4.如上例的话&#xf…

[原创][2]探究C#多线程开发细节-“线程的无顺序性“

[简介] 常用网名: 猪头三 出生日期: 1981.XX.XX QQ: 643439947 个人网站: 80x86汇编小站 https://www.x86asm.org 编程生涯: 2001年~至今[共22年] 职业生涯: 20年 开发语言: C/C、80x86ASM、PHP、Perl、Objective-C、Object Pascal、C#、Python 开发工具: Visual Studio、Delph…

在CentOS7版本下安装Docker与Docker Compose

目录 Docker简介 Docker安装 Docker Compose简介 Docker Compose安装 Docker简介 Docker是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的镜像中,然后发布到任何流行的Linux或Windows操作系统的机器上,也…

【Numpy】一组标量

在NumPy中,当你对一个数组和一个向量进行逐元素相乘时,它们的形状需要满足广播规则,才能够进行元素级的乘法操作。广播是一组规则,允许NumPy在不同形状的数组上执行操作,从而使得某些操作更加灵活和高效。 在你的情况…

golang Pool实战与底层实现

使用的go版本为 go1.21.2 首先我们写一个简单的Pool的使用代码 package mainimport "sync"var bytePool sync.Pool{New: func() interface{} {b : make([]byte, 1024)return &b}, }func main() {for j : 0; j < 10; j {obj : bytePool.Get().(*[]byte) // …

Python高级数据结构——并查集(Disjoint Set)

Python中的并查集&#xff08;Disjoint Set&#xff09;&#xff1a;高级数据结构解析 并查集是一种用于处理集合的数据结构&#xff0c;它主要支持两种操作&#xff1a;合并两个集合和查找一个元素所属的集合。在本文中&#xff0c;我们将深入讲解Python中的并查集&#xff0…

Java基础-----Date类及其相关类(一)

文章目录 1. Date类1.1 简介1.2 构造方法1.3 主要方法 2. DateFormat 类2.1 简介2.2 实例化方式一&#xff1a;通过静态方法的调用2.2 实例化方式二&#xff1a;通过创建子类对象 3. Calendar类4. GregorianCalendar 1. Date类 1.1 简介 java.util.Date:表示指定的时间信息&a…

vivado实现分析与收敛技巧7-布局规划

关于布局规划 布局规划有助于设计满足时序要求。当设计难以始终如一满足时序要求或者从未满足时序要求时 &#xff0c; AMD 建议您执行布局规划。如果您与设计团队协作并且协作过程中一致性至关重要&#xff0c; 那么布局规划同样可以发挥作用。布局规划可通过减少平均布线延…

Redis-安装、配置和修改配置文件、以及在Ubuntu和CentOS上设置Redis服务的开机启动和防火墙设置,以及客户端连接。

目录 1. Redis简介 2. 离线安装 2.1 准备工作 2.2 解压、安装 2.3 修改配置文件 2.4 redis服务与关闭 2.5 redis服务的开机启动 2.5.1 Ubuntu上的配置 2.5.2 centos上的配置 3. 在线安装 4. 设置防火墙 5. 客户端连接 1. Redis简介 Redis 是完全开源免费的&#x…

鼠标点击效果.html(网上收集6)

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>鼠标点击</title> </head><body> <script>(function () {var a_idx 0;window.onclick function (event) {var a new Array(…

docker 如何在容器内重启 php

要在Docker容器内重启PHP&#xff0c;您可以按照以下步骤进行操作&#xff1a; 1. 进入您的Docker容器。可以使用以下命令启动一个交互式的容器会话&#xff1a; docker exec -it <container_name> /bin/bash其中<container_name>是您的容器的名称。。 2. 在容器内…

【Python从入门到进阶】43.验证码识别工具结合requests的使用

接上篇《42、使用requests的Cookie登录古诗文网站》 上一篇我们介绍了如何利用requests的Cookie登录古诗文网。本篇我们来学习如何使用验证码识别工具进行登录验证的自动识别。 一、图片验证码识别过程及手段 上一篇我们通过requests的session方法&#xff0c;带着原网页登录…

人工智能 - 人脸识别:发展历史、技术全解与实战

目录 一、人脸识别技术的发展历程早期探索&#xff1a;20世纪60至80年代技术价值点&#xff1a; 自动化与算法化&#xff1a;20世纪90年代技术价值点&#xff1a; 深度学习的革命&#xff1a;21世纪初至今技术价值点&#xff1a; 二、几何特征方法详解与实战几何特征方法的原理…