《python深度学习》笔记(二十):神经网络的解释方法之CAM、Grad-CAM、Grad-CAM++、LayerCAM

原理优点缺点
GAP将多维特征映射降维为一个固定长度的特征向量①减少了模型的参数量;②保留更多的空间位置信息;③可并行计算,计算效率高;④具有一定程度的不变性①可能导致信息的损失;②忽略不同尺度的空间信息
CAM利用最后一个卷积层的特征图×权重(用GAP代替全连接层,重新训练,经过GAP分类后概率最大的神经元的权重效果已经很不错需要修改原模型的结构,导致需要重新训练该模型,大大限制了使用场景,如果模型已经上线了,或着训练的成本非常高,我们几乎是不可能为了它重新训练的。
Grad-CAM最后一个卷积层的特征图×权重(通过对特征图梯度的全局平均来计算权重①解决了CAM的缺点,适用于任何卷积神经网络;②利用特征图的梯度,可视化结果更准确和精细
Grad-CAM++1. 定位更准确
2. 更适合同类多目标的情况

目录

GAP全局平均池化

CAM

Grad-CAM 

Grad-CAM++


GAP全局平均池化

论文:Network In Network

GAP (Global Average Pooling,全局平均池化),在上述论文中提出,用于避免全连接层的过拟合问题。全局平均池化就是对整个特征映射应用平均池化。

图1:将原本h × w × d的三维特征图,具体大小为6 × 6 × 3,经过GAP池化为1 × 1 × 3 输出值。也就是每一个channel的h × w 平均池化为一个值。特征图经过 GAP 处理后每一个特征图包含了不同类别的信息。 

GAP平均池化的操作步骤如下:

  1. 经过卷积操作和激活函数后,得到最后一个卷积层的特征图。
  2. 对每个通道的特征图进行平均池化,即计算每个通道上所有元素的平均值。这将每个通道的特征图转化为一个标量值。
  3. 将每个通道的标量值组合成一个特征向量。这些标量值的顺序与通道的顺序相同。
  4. 最终得到的特征向量可以作为分类器的输入,用于进行图像分类。

CAM

论文:Learning Deep Features for Discriminative Localization

原理:利用最后一个卷积层的特征图与经过GAP分类后概率最大的神经元权重进行叠加。

图2:解释了在CNN中使用全局平均池化(GAP)生成类激活映射(CAM)的过程:

经过最后一层卷积操作之后,得到的特征图包含多个channel,如图1中的不同颜色的3个channel,也就是在GAP之前所对应的不同的channel特征图,f_{k}就表示第k个channel的特征图。然后经过GAP处理后每个channel的特征图包含了不同类别的信息,w_{k}就表示分类概率最大的神经元(图2黑色神经元)所对应连接的第k个神经元的权重。

Grad-CAM 

Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization (arxiv.org)

Grad-CAM的前身是 CAM,CAM 的基本的思想是求分类网络某一类别得分对高维特征图 (卷积层的输出) 的偏导数,从而可以得到该高维特征图每个通道对该类别得分的权值;而高维特征图的激活信息 (正值) 又代表了卷积神经网络的所感兴趣的信息,加权后使用热力图呈现得到 CAM。

原理:Grad-CAM的关键思想是将输出类别的梯度(相对于特定卷积层的输出)与该层的输出相乘,然后取平均,得到一个“粗糙”的热力图。这个热力图可以被放大并叠加到原始图像上,以显示模型在分类时最关注的区域。

具体步骤如下:

  1. 选择网络的最后一个卷积层,因为它既包含了高级特征,也保留了空间信息。
  2. 前向传播图像到网络,得到你想解释的类别的得分。
  3. 计算此得分相对于我们选择的卷积层输出的梯度。
  4. 对于该卷积层的每个通道,使用上述梯度的全局平均值对该通道进行加权。
  5. 结果是一个与卷积层的空间维度相同的加权热力图。

因为热力图关心的是对分类有正面影响的特征,所以在线性组合的技术上加上了ReLU,以移除负值 。

w_{k}^{c}第 k 个特征图对应于类别 c 的权重,
A^{k}表示:第 k 个特征图,
Z表示特征图的像素个数,
y^{c}表示: 第c类得分的梯度,
A_{ij}^{k}表示: 第 k个特征图中坐标( i , j )位置处的像素值;

Grad-CAM代码:

import torch
import cv2
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Imageclass GradCAM:def __init__(self, model, target_layer):self.model = model  # 要进行Grad-CAM处理的模型self.target_layer = target_layer  # 要进行特征可视化的目标层self.feature_maps = None  # 存储特征图self.gradients = None  # 存储梯度# 为目标层添加钩子,以保存输出和梯度target_layer.register_forward_hook(self.save_feature_maps)target_layer.register_backward_hook(self.save_gradients)def save_feature_maps(self, module, input, output):"""保存特征图"""self.feature_maps = output.detach()def save_gradients(self, module, grad_input, grad_output):"""保存梯度"""self.gradients = grad_output[0].detach()def generate_cam(self, image, class_idx=None):"""生成CAM热力图"""# 将模型设置为评估模式self.model.eval()# 正向传播output = self.model(image)if class_idx is None:class_idx = torch.argmax(output).item()# 清空所有梯度self.model.zero_grad()# 对目标类进行反向传播one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32)one_hot[0][class_idx] = 1output.backward(gradient=one_hot.cuda(), retain_graph=True)# 获取平均梯度和特征图pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])activation = self.feature_maps.squeeze(0)for i in range(activation.size(0)):activation[i, :, :] *= pooled_gradients[i]# 创建热力图heatmap = torch.mean(activation, dim=0).squeeze().cpu().numpy()heatmap = np.maximum(heatmap, 0)heatmap /= torch.max(heatmap)heatmap = cv2.resize(heatmap, (image.size(3), image.size(2)))heatmap = np.uint8(255 * heatmap)heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)# 将热力图叠加到原始图像上original_image = self.unprocess_image(image.squeeze().cpu().numpy())superimposed_img = heatmap * 0.4 + original_imagesuperimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)return heatmap, superimposed_imgdef unprocess_image(self, image):"""反预处理图像,将其转回原始图像"""mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = (((image.transpose(1, 2, 0) * std) + mean) * 255).astype(np.uint8)return imagedef visualize_gradcam(model, input_image_path, target_layer):"""可视化Grad-CAM热力图"""# 加载图像img = Image.open(input_image_path)preprocess = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])input_tensor = preprocess(img).unsqueeze(0).cuda()# 创建GradCAMgradcam = GradCAM(model, target_layer)heatmap, result = gradcam.generate_cam(input_tensor)# 显示图像和热力图plt.figure(figsize=(10,10))plt.subplot(1,2,1)plt.imshow(heatmap)plt.title('热力图')plt.axis('off')plt.subplot(1,2,2)plt.imshow(result)plt.title('叠加后的图像')plt.axis('off')plt.show()# 以下是示例代码,显示如何使用上述代码。
# 首先,你需要加载你的模型和权重。
# model = resnet20()
# model.load_state_dict(torch.load("path_to_your_weights.pth"))
# model.to('cuda')# 然后,调用`visualize_gradcam`函数来查看结果。
# visualize_gradcam(model, "path_to_your_input_image.jpg", model.layer3[-1])

 Grad-CAM++

Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks (arxiv.org) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/129948.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Servlet对象生命周期

Servlet 生命周期包括加载与实例化、初始化、服务请求、销毁等阶段。 ervlet 的生命周期包括以下阶段: 加载与实例化:当容器启动或者第一次请求到达时,Servlet 容器加载 Servlet 类并创建 Servlet 实例。 初始化:在 Servlet 实例…

网络安全演练(一句话木马)

在享受互联网带来的便利的同时,也充满了各种网络安全风险,本文通过搭建实验环境,演示一句话木马获取主机权限。 演示环境 服务端:安装LAMP环境,部署web网站,上传一句话木马文件 客户端:安装A…

qt6-QPushButton无法显示为类

问题 在编写QT程序时,不同颜色表示不同的含义。在设计基本的界面,需要使用QRadioButton时,相应的字符为紫色,紫色为类名。这篇简单说明了下,也可以鼠标点击页面,可以出现提示。 但是上面图片中显示&#…

视频转序列图片:掌握技巧,轻松转换

随着社交媒体和视频平台的日益普及,视频已成为我们生活中不可或缺的一部分。有时,我们需要将视频转换为图片序列,例如制作GIF动图或提取视频中的特定画面。现在一起来看云炫AI智剪如何将视频转换为序列图片,并轻松实现转换。 操作…

OpenShift - 利用容器的特权配置实现对OpenShift攻击

《OpenShift / RHEL / DevSecOps 汇总目录》 说明:本文已经在 OpenShift 4.14 的环境中验证 本文是《容器安全 - 利用容器的特权配置实现对Kubernetes攻击》的后续篇,来介绍 在 OpenShift 环境中的容器特权配置和攻击过程和 Kubernetes 环境的差异。 文…

框架安全-CVE 复现Apache ShiroApache Solr漏洞复现

文章目录 服务攻防-框架安全&CVE 复现&Apache Shiro&Apache Solr漏洞复现中间件列表常见开发框架Apache Shiro-组件框架安全暴露的安全问题漏洞复现Apache Shiro认证绕过漏洞(CVE-2020-1957)CVE-2020-11989验证绕过漏洞CVE_2016_4437 Shiro-…

城市内涝解决方案:实时监测,提前预警,让城市更安全

城市内涝积水问题是指城市地区在短时间内遭遇强降雨后,地面积水过多,导致城市交通堵塞、居民生活不便、财产损失等问题。近年来,随着全球气候变化和城市化进程的加速,城市内涝积水问题越来越突出,成为城市发展中的一大…

基于设深度学习的人脸性别年龄识别系统 计算机竞赛

文章目录 0 前言1 课题描述2 实现效果3 算法实现原理3.1 数据集3.2 深度学习识别算法3.3 特征提取主干网络3.4 总体实现流程 4 具体实现4.1 预训练数据格式4.2 部分实现代码 5 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 基于深度学习机器视觉的…

软件测试/测试开发丨ChatGPT能否成为PPT最佳伴侣

点此获取更多相关资料 简介 PPT 已经渗透到我们的日常工作中,无论是工作汇报、商务报告、学术演讲、培训材料都常常要求编写一个正式的 PPT,协助完成一次汇报或一次演讲。PPT相比于传统文本的就是有布局、图片、动画效果等,可以给到观众更好…

【leetcode】26. 删除有序数组中的重复项(图解)

目录 1. 思路(图解)2. 代码 题目链接: leetcode 26. 删除有序数组中的重复项 题目描述: 注意返回的是去重后的数组长度,但是输出的是去重后的数组元素。 1. 思路(图解) 思路:快慢…

在maven官网中如何下载低版本的maven

链接:https://archive.apache.org/dist/maven/maven-3/

快速了解相似检索方法

一、相似检索方法总体分析 相似检索方法是一种用于从大量数据中找到与查询数据相似的数据项的技术。这种方法通常用于信息检索、推荐系统、图像处理、自然语言处理等领域。相似检索主要方法可以总体分为以下几类: 基于距离度量的方法: 余弦相似度&…

Postman接口测试工具,提高SpringBoot开发效率

文章目录 🌺工具—postman⭐作用🏳️‍🌈安装🎈创建工作空间 🎄简单参数⭐原始方式🎈我们建立springboot项目,输入下面的代码🎈运行 ⭐SpringBoot方式 🎄实体参数&#x…

正点原子嵌入式linux驱动开发——Linux 音频驱动

音频是最常用到的功能,音频也是linux和安卓的重点应用场合。STM32MP1带有SAI接口,正点原子的STM32MP1开发板通过此接口外接了一个CS42L51音频DAC芯片,本章就来学习一下如何使能CS42L51驱动,并且CS42L51通过芯片来完成音乐播放与录…

Day39 QTableWidget类的使用

1.简介 介绍QtableWidget各种属性的用法,以及常用的一些信号,最后利用这些特性,制作一个用于下发设备运行参数的表格。该表格可以实现折叠和取消折叠,在源代码中用了事件过滤器实现,也可以用自带的click信号。显示了图…

“第五十九天”

这是昨天那道题,这个后面自己的处理思路还是差了点,这道题关键感觉就是对进位的处理的,由于进位的存在,所以处理数据的时候只能从最低位开始,我一开始是从高位处理的,而且后面越来越迷,这个点一…

自家开发VS第三方美颜SDK:技术和资源的比较

开发直播平台时,开发人员面临一个关键决策:是选择使用第三方美颜SDK,还是自家开发美颜算法?本文将深入探讨这两种方法的技术和资源方面的比较,帮助开发者更好地决定哪种途径最适合他们的应用。 一、第三方美颜SDK&am…

智能电表和互感器一起安装有什么效果?

智能电表和互感器的普及,为用电管理提供了更为精确和便捷的方式。那么,当智能电表和互感器一起安装时,会产生怎样的"化学反应"呢?下面,小编就来为大家详细的讲解下智能电表和互感器一起安装的作用吧&#xf…

lua-web-utils库

lua--导入所需的库local web_utilsrequire("lua-web-utils")--定义要下载的URLlocal url"https://jshk.com.cn/"--定义代理服务器的主机名和端口号local proxy_port8000--使用web_utils的download函数下载URLlocal file_pathweb_utils.download(url,proxy_…

2023最新C语言编程练习题大全(一)

目录 一、初识C语言1.1 第一个C语言程序1.2 一个完整的C语言程序1.3 输出名言1.4 计算正方形的周长 二、一个简单的C语言程序2.1 输出一个正方形2.2 输出直角三角形2.3 设计一个简单的求和程序2.4 求10!2.5 三个数由小到大排序2.6 猴子吃桃2.7 阳阳买苹果 一、初识C语言 1.1 第…