神经网络的解释方法之CAM、Grad-CAM、Grad-CAM++、LayerCAM

原理优点缺点
GAP将多维特征映射降维为一个固定长度的特征向量①减少了模型的参数量;②保留更多的空间位置信息;③可并行计算,计算效率高;④具有一定程度的不变性①可能导致信息的损失;②忽略不同尺度的空间信息
CAM利用最后一个卷积层的特征图×权重(用GAP代替全连接层,重新训练,经过GAP分类后概率最大的神经元的权重效果已经很不错需要修改原模型的结构,导致需要重新训练该模型,大大限制了使用场景,如果模型已经上线了,或着训练的成本非常高,我们几乎是不可能为了它重新训练的。
Grad-CAM最后一个卷积层的特征图×权重(通过对特征图梯度的全局平均来计算权重①解决了CAM的缺点,适用于任何卷积神经网络;②利用特征图的梯度,可视化结果更准确和精细
Grad-CAM++1. 定位更准确
2. 更适合同类多目标的情况

GAP全局平均池化

论文:Network In Network

GAP (Global Average Pooling,全局平均池化),在上述论文中提出,用于避免全连接层的过拟合问题。全局平均池化就是对整个特征映射应用平均池化。

图1:将原本h × w × d的三维特征图,具体大小为6 × 6 × 3,经过GAP池化为1 × 1 × 3 输出值。也就是每一个channel的h × w 平均池化为一个值。特征图经过 GAP 处理后每一个特征图包含了不同类别的信息。 

GAP平均池化的操作步骤如下:

  1. 经过卷积操作和激活函数后,得到最后一个卷积层的特征图。
  2. 对每个通道的特征图进行平均池化,即计算每个通道上所有元素的平均值。这将每个通道的特征图转化为一个标量值。
  3. 将每个通道的标量值组合成一个特征向量。这些标量值的顺序与通道的顺序相同。
  4. 最终得到的特征向量可以作为分类器的输入,用于进行图像分类。

CAM

论文:Learning Deep Features for Discriminative Localization

原理:利用最后一个卷积层的特征图与经过GAP分类后概率最大的神经元权重进行叠加。

图2:解释了在CNN中使用全局平均池化(GAP)生成类激活映射(CAM)的过程:

经过最后一层卷积操作之后,得到的特征图包含多个channel,如图1中的不同颜色的3个channel,也就是在GAP之前所对应的不同的channel特征图,f_{k}就表示第k个channel的特征图。然后经过GAP处理后每个channel的特征图包含了不同类别的信息,w_{k}就表示分类概率最大的神经元(图2黑色神经元)所对应连接的第k个神经元的权重。

Grad-CAM 

Grad-CAM的前身是 CAM,CAM 的基本的思想是求分类网络某一类别得分对高维特征图 (卷积层的输出) 的偏导数,从而可以得到该高维特征图每个通道对该类别得分的权值;而高维特征图的激活信息 (正值) 又代表了卷积神经网络的所感兴趣的信息,加权后使用热力图呈现得到 CAM。

原理:Grad-CAM的关键思想是将输出类别的梯度(相对于特定卷积层的输出)与该层的输出相乘,然后取平均,得到一个“粗糙”的热力图。这个热力图可以被放大并叠加到原始图像上,以显示模型在分类时最关注的区域。

具体步骤如下:

  1. 选择网络的最后一个卷积层,因为它既包含了高级特征,也保留了空间信息。
  2. 前向传播图像到网络,得到你想解释的类别的得分。
  3. 计算此得分相对于我们选择的卷积层输出的梯度。
  4. 对于该卷积层的每个通道,使用上述梯度的全局平均值对该通道进行加权。
  5. 结果是一个与卷积层的空间维度相同的加权热力图。

因为热力图关心的是对分类有正面影响的特征,所以在线性组合的技术上加上了ReLU,以移除负值 。

w_{k}^{c}第 k 个特征图对应于类别 c 的权重,
A^{k}表示:第 k 个特征图,
Z表示特征图的像素个数,
y^{c}表示: 第c类得分的梯度,
A_{ij}^{k}表示: 第 k个特征图中坐标( i , j )位置处的像素值;

Grad-CAM代码:

import torch
import cv2
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Imageclass GradCAM:def __init__(self, model, target_layer):self.model = model  # 要进行Grad-CAM处理的模型self.target_layer = target_layer  # 要进行特征可视化的目标层self.feature_maps = None  # 存储特征图self.gradients = None  # 存储梯度# 为目标层添加钩子,以保存输出和梯度target_layer.register_forward_hook(self.save_feature_maps)target_layer.register_backward_hook(self.save_gradients)def save_feature_maps(self, module, input, output):"""保存特征图"""self.feature_maps = output.detach()def save_gradients(self, module, grad_input, grad_output):"""保存梯度"""self.gradients = grad_output[0].detach()def generate_cam(self, image, class_idx=None):"""生成CAM热力图"""# 将模型设置为评估模式self.model.eval()# 正向传播output = self.model(image)if class_idx is None:class_idx = torch.argmax(output).item()# 清空所有梯度self.model.zero_grad()# 对目标类进行反向传播one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32)one_hot[0][class_idx] = 1output.backward(gradient=one_hot.cuda(), retain_graph=True)# 获取平均梯度和特征图pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])activation = self.feature_maps.squeeze(0)for i in range(activation.size(0)):activation[i, :, :] *= pooled_gradients[i]# 创建热力图heatmap = torch.mean(activation, dim=0).squeeze().cpu().numpy()heatmap = np.maximum(heatmap, 0)heatmap /= torch.max(heatmap)heatmap = cv2.resize(heatmap, (image.size(3), image.size(2)))heatmap = np.uint8(255 * heatmap)heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)# 将热力图叠加到原始图像上original_image = self.unprocess_image(image.squeeze().cpu().numpy())superimposed_img = heatmap * 0.4 + original_imagesuperimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)return heatmap, superimposed_imgdef unprocess_image(self, image):"""反预处理图像,将其转回原始图像"""mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = (((image.transpose(1, 2, 0) * std) + mean) * 255).astype(np.uint8)return imagedef visualize_gradcam(model, input_image_path, target_layer):"""可视化Grad-CAM热力图"""# 加载图像img = Image.open(input_image_path)preprocess = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])input_tensor = preprocess(img).unsqueeze(0).cuda()# 创建GradCAMgradcam = GradCAM(model, target_layer)heatmap, result = gradcam.generate_cam(input_tensor)# 显示图像和热力图plt.figure(figsize=(10,10))plt.subplot(1,2,1)plt.imshow(heatmap)plt.title('热力图')plt.axis('off')plt.subplot(1,2,2)plt.imshow(result)plt.title('叠加后的图像')plt.axis('off')plt.show()# 以下是示例代码,显示如何使用上述代码。
# 首先,你需要加载你的模型和权重。
# model = resnet20()
# model.load_state_dict(torch.load("path_to_your_weights.pth"))
# model.to('cuda')# 然后,调用`visualize_gradcam`函数来查看结果。
# visualize_gradcam(model, "path_to_your_input_image.jpg", model.layer3[-1])

 Grad-CAM++

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/126578.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端 :用HTML , CSS ,JS 做一个秒表

1.HTML&#xff1a; <body><div id "content"><div id "top"><div id"time">00:00:000</div></div><div id "bottom"><div id "btn_start">开始</div><div …

04.Oracle的体系架构

Oracle的体系架构 一、主要组件 一、主要组件 下面是一张网图&#xff0c;大家可以了解一下oracle的体系架构 Oracle数据库的体系架构可以分为以下几个主要组件&#xff1a;实例&#xff08;Instance&#xff09;、数据库&#xff08;Database&#xff09;、表空间&#xff…

瑞数专题五

今日文案&#xff1a;焦虑&#xff0c;想象力过度发酵的产物。 网址&#xff1a;https://www.iyiou.com/ 专题五主要是分享瑞数6代。6代很少见&#xff0c;所以找理想哥要的&#xff0c;感谢感谢。 关于瑞数作者之前已经分享过4篇文章&#xff0c;全都收录在瑞数专栏中了&am…

21. 合并两个有序链表、Leetcode的Python实现

博客主页&#xff1a;&#x1f3c6;看看是李XX还是李歘歘 &#x1f3c6; &#x1f33a;每天不定期分享一些包括但不限于计算机基础、算法、后端开发相关的知识点&#xff0c;以及职场小菜鸡的生活。&#x1f33a; &#x1f497;点关注不迷路&#xff0c;总有一些&#x1f4d6;知…

正式启航!指导品牌开拓下一个增长蓝海

种草的商品总在不经意间推送到面前&#xff0c;深夜刷了会儿短视频&#xff0c;不小心又下单了一个不太熟悉的产品&#xff0c;明星达人素人全部入局直播带货&#xff0c;社交平台演变成购物场&#xff0c;无人幸免的兴趣电商时代强势来临。尤其到了每年一度的双11大促节点&…

数据库概念和sql语句

数据库概念和sql语句 数据&#xff1a;数&#xff1a;数字信息 据&#xff1a;属性 对一系列对象的具体属性的描述的集合 数据库&#xff1a;数据库就是用来组织&#xff08;各个数据之间是有关联&#xff0c;是按照规则组织起来的&#xff09;&#xff0c;存储和管理&…

音视频rtsp rtmp gb28181在浏览器上的按需拉流

按需拉流是从客户视角来看待音视频的产品功能&#xff0c;直观&#xff0c;好用&#xff0c;为啥hls flv大行其道也是这个原因&#xff0c;不过上述存在的问题是延迟没法降到实时毫秒级延迟&#xff0c;也不能随心所欲的控制。通过一段时间的努力&#xff0c;结合自己闭环技术栈…

C++新版本学习资源整理

链接资源推荐&#xff1a; C11/14/17/20 特性介绍 转 | 有点博客

Web APIs——日期对象的使用

1、日期对象 日期对象&#xff1a;用来表示时间的对象 作用&#xff1a;可以得到当前系统时间 1.1实例化 在代码中发现了new关键字时&#xff0c;一般将这个操作称为实例化 创建一个时间对象并获取时间 获得当前时间 const date new Date() <script>// 实例化 new //…

UE5 Android下载zip文件并解压缩到指定位置

一、下载是使用市场的免费插件 二、解压缩是使用市场的免费插件 三、Android路径问题 windows平台下使用该插件没有问题&#xff0c;只是在Android平台下&#xff0c;只有使用绝对路径才能进行解压缩&#xff0c;所以如何获得Android下的绝对路径&#xff1f;增加C文件获得And…

铁轨(Rails, ACM/ICPC CERC 1997, UVa 514)rust解法

有一个火车站&#xff0c;铁轨铺设如图6-1所示。有n节车厢从A方向驶入车站&#xff0c;按进站顺序编号为1&#xff5e;n。你的任务是判断是否能让它们按照某种特定的顺序进入B方向的铁轨并驶出车站。例如&#xff0c;出栈顺序(5 4 1 2 3)是不可能的&#xff0c;但(5 4 3 2 1)是…

python使用requests+excel进行接口自动化测试

在当今的互联网时代中&#xff0c;接口自动化测试越来越成为软件测试的重要组成部分。Python是一种简单易学&#xff0c;高效且可扩展的语言&#xff0c;自然而然地成为了开发人员的首选开发语言。而requests和xlwt这两个常用的Python标准库&#xff0c;能够帮助我们轻松地开发…

29、枚举

枚举 枚举使用场景枚举语法及特性特性&#xff1a; 手动给枚举赋值手动赋值项和未手动赋值项重复手动赋值项智能赋值数字&#xff1f;NO常数项和计算项常数枚举外部枚举 枚举使用场景 枚举类型 用于取值被限定在一定范围内的场景。 demo&#xff1a; 一周只能有七天&#xff0…

sqlLite 如何使用数据库连接池

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 一、前言 编写的一个jar包工具中&#xff…

JS(JavaScript) 实现延迟等待(sleep方法)

起因&#xff1a; 只使用 setTimeout 会产生嵌套等方面的问题&#xff0c;达不到想要的效果。 解决方法&#xff1a; 使用 async/await 还有 Promise 相结合的方式来解决问题。 直接上代码&#xff1a; function sleep(time) {return new Promise((resolve) > setTimeout…

公众号留言功能报价是多少?值得开通吗?

为什么公众号没有留言功能&#xff1f;根据要求&#xff0c;自2018年2月12日起&#xff0c;新申请的微信公众号默认无留言功能。有些人听过一个说法&#xff1a;公众号粉丝累计到一定程度或者原创文章数量累计到一定程度就可以开通留言功能。其实这个方法是2018年之前才可以&am…

三氧化二铁纳米片

&#xff08;西&#xff09;三氧化二铁纳米片 &#xff08;安&#xff09;名称&#xff1a;三氧化二铁纳米片 &#xff08;瑞&#xff09;CAS&#xff1a;1309-37-1 &#xff08;禧&#xff09;分子式&#xff1a;Fe2O3 &#xff08;生&#xff09;外观&#xff1a;白色粉末…

链表的引入

什么是链表 链表一种线性的数据结构&#xff0c;通过指针将一个个零散的内存块连接起来&#xff0c;链表的每个内存块称为结点。结构体指针在这里得到了充分的利用。 为什么要使用链表 链表可以动态的进行存储分配&#xff0c;也就是说&#xff0c;链表是一个功能极为强大的数…

518抽奖软件,是否支持作弊~内定~指定中奖人~设置范围

518抽奖软件简介 518抽奖软件&#xff0c;518我要发&#xff0c;超好用的年会抽奖软件&#xff0c;简约设计风格。 包含文字号码抽奖、照片抽奖两种模式&#xff0c;支持姓名抽奖、号码抽奖、数字抽奖、照片抽奖。(www.518cj.net) 主打纯净&#xff0c;不可作弊 市面上&…

7. 一文快速学懂常用工具——Makefile

本章讲解知识点 引言 Makefile Makefile 入门 本专栏适合于软件开发刚入职的学生或人士&#xff0c;有一定的编程基础&#xff0c;帮助大家快速掌握工作中必会的工具和指令。本专栏针对面试题答案进行了优化&#xff0c;尽量做到好记、言简意赅。如专栏内容有错漏&#xff0…