深度学习绘制热力图heatmap、使模型具有可解释性

思路

获取想要解释的那一层的特征图,然后根据特征图梯度计算出权重值,加在原图上面。

Demo

在这里插入图片描述
加上类激活(cam)
在这里插入图片描述
可以看到,cam将模型认为有利于分类的特征标注了出来。
下面以ResNet50为例:
Trick:
使用

for i in model._modules.items():

可以获得模型名称和对应层。

# coding: utf-8
import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as pltimport torch
import torch.autograd as autograd
import torchvision.transforms as transformsimport torchvision.models as models# 训练过的模型路径
#resume_path = r"D:\TJU\GBDB\set113\cross_validation\test1\epoch_0257_checkpoint.pth.tar"
# 输入图像路径
single_img_path = r'bicycle.jpg'
# 绘制的热力图存储路径
save_path = r'heatmap/bicycle_layer4.jpg'# 网络层的层名列表, 需要根据实际使用网络进行修改
layers_names = ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']
# 指定层名
out_layer_name = "layer4"features_grad = 0# 为了读取模型中间参数变量的梯度而定义的辅助函数
def extract(g):global features_gradfeatures_grad = gdef draw_CAM(model, img_path, save_path, transform=None, visual_heatmap=False, out_layer=None):"""绘制 Class Activation Map:param model: 加载好权重的Pytorch model:param img_path: 测试图片路径:param save_path: CAM结果保存路径:param transform: 输入图像预处理方法:param visual_heatmap: 是否可视化原始heatmap(调用matplotlib):return:"""# 读取图像并预处理global layer2img = Image.open(img_path).convert('RGB')if transform:img = transform(img)img = img.unsqueeze(0)  # (1, 3, 448, 448)# model转为eval模式model.eval()# 获取模型层的字典layers_dict = {layers_names[i]: None for i in range(len(layers_names))}for name,module in model._modules.items():#print(i, (name, module))layers_dict[name] = module# 遍历模型的每一层, 获得指定层的输出特征图# features: 指定层输出的特征图, features_flatten: 为继续完成前端传播而设置的变量features = imgstart_flatten = Falsefeatures_flatten = Nonefor name, layer in layers_dict.items():if name != out_layer and start_flatten is False:    # 指定层之前features = layer(features)elif name == out_layer and start_flatten is False:  # 指定层features = layer(features)start_flatten = Trueelse:   # 指定层之后if name == "fc":breakif features_flatten is None:features_flatten = layer(features)else:features_flatten = layer(features_flatten)#print(features_flatten.shape)features_flatten = torch.flatten(features_flatten, 1)#print(features_flatten.shape)output = model.fc(features_flatten)# 预测得分最高的那一类对应的输出scorepred = torch.argmax(output, 1).item()pred_class = output[:, pred]# 求中间变量features的梯度# 方法1# features.register_hook(extract)# pred_class.backward()# 方法2features_grad = autograd.grad(pred_class, features, allow_unused=True)[0]grads = features_grad  # 获取梯度pooled_grads = torch.nn.functional.adaptive_avg_pool2d(grads, (1, 1))# 此处batch size默认为1,所以去掉了第0维(batch size维)pooled_grads = pooled_grads[0]features = features[0]print("pooled_grads:", pooled_grads.shape)print("features:", features.shape)# features.shape[0]是指定层feature的通道数for i in range(features.shape[0]):features[i, ...] *= pooled_grads[i, ...]# 计算heatmapheatmap = features.detach().cpu().numpy()heatmap = np.mean(heatmap, axis=0)heatmap = np.maximum(heatmap, 0)heatmap /= np.max(heatmap)# 可视化原始热力图if visual_heatmap:plt.matshow(heatmap)plt.show()img = cv2.imread(img_path)  # 用cv2加载原始图像heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))  # 将热力图的大小调整为与原始图像相同heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将热力图应用于原始图像superimposed_img = heatmap * 0.7 + img  # 这里的0.4是热力图强度因子cv2.imwrite(save_path, superimposed_img)  # 将图像保存到硬盘if __name__ == '__main__':model = models.resnet50(pretrained=True)#model.eval()transform = transforms.Compose([transforms.Resize(448),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])# 构建模型并加载预训练参数#seresnet50 = FineTuneSEResnet50(num_class=113).cuda()#checkpoint = torch.load(resume_path)#seresnet50.load_state_dict(checkpoint['state_dict'])draw_CAM(model, single_img_path, save_path, transform=transform, visual_heatmap=True, out_layer=out_layer_name)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/759048.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaScript和HTML/CSS之间有什么区别?它们之间的关系是什么?

下面是一个简单的代码示例&#xff0c;展示了HTML、CSS和JavaScript是如何一起工作的&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-w…

「Linux系列」Shell介绍及起步

文章目录 一、Shell简介二、Shell脚本三、Shell解释器四、相关链接 一、Shell简介 Shell本身是一个用C语言编写的程序&#xff0c;它既是一种命令语言&#xff0c;又是一种程序设计语言。作为命令语言&#xff0c;它交互式地解释和执行用户输入的命令&#xff1b;作为程序设计…

深度学习500问——Chapter03:深度学习基础(4)

文章目录 3.7 预训练与微调&#xff08;fine tuning&#xff09; 3.7.1 为什么无监督预训练可以帮助深度学习 3.7.2 什么是模型微调 fine tuning 3.7.3 微调时候网络参数是否更新 3.7.4 fine-tuning模型的三种状态 3.8 权重偏差和初始化 3.8.1 全都初始化为0 3.8.2 全都初始化为…

小程序调用相机拍照上传

用的wpy框架&#xff0c;有this&#xff0c;原声小程序就按照你们的调方法就行了 //打开相机openCream() {const _this this;wx.showActionSheet({itemList: ["拍照"],//[拍照,相册]itemColor: "",//成功时回调success: function(res) {if (!res.cancel) …

Android Launcher开发注意事项

在开发Android Launcher时&#xff0c;需要关注性能、用户体验、权限管理、兼容性等方面&#xff0c;同时遵循相关的开发者政策和最佳实践。有几个重要的注意事项&#xff0c;希望对大家有所帮助。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&#xff0c;欢迎…

选择word中的表格VBA

打开开发工具 选择Visual Basic插入代码 Sub 选择word中的表格() Dim t As Table an MsgBox("即将选择选区内所有表格&#xff0c;若无选区&#xff0c;则选择全文表格。", vbYesNo, "提示") If an - 6 Then Exit Sub Set rg IIf(Selection.Type wdSel…

[HFCTF 2021 Final]easyflask

[HFCTF 2021 Final]easyflask [[python反序列化]] 首先题目给了提示&#xff0c;有文件读取漏洞&#xff0c;读取源码 #!/usr/bin/python3.6 import os import picklefrom base64 import b64decode from flask import Flask, request, render_template, sessionapp Flask(_…

HarmonyOS NEXT应用开发之侧滑返回事件拦截案例

介绍 在编辑场景中&#xff0c;存在用户误触返回&#xff0c;导致内容未保存就退出编辑页的现象; 本示例介绍使用NavDestination组件的onBackPressed回调对返回事件进行拦截&#xff0c;提示用户保存编辑内容&#xff0c;并使用preferences实例持久化保存内容。 效果预览图 使…

什么是常用的前端开发工具和框架?列举几个常用的前端框架和其特点。

前端开发工具和框架在Web开发中起着至关重要的作用&#xff0c;它们帮助开发者更高效地构建用户界面、管理数据和交互性。以下是一些常用的前端开发工具和框架&#xff1a; 1. **开发工具**&#xff1a; * **Visual Studio Code**&#xff1a;这是一个非常流行的代码编辑器…

使用verillog编写KMP字符串匹配算法

设计思路如下: 定义模块的输入输出信号:包括时钟信号clk、复位信号rst、模式串pattern、文本串text以及输出信号match。定义所需寄存器和变量:使用寄存器来存储状态机的状态以及其他控制变量,如模式串数组P、失配函数数组F、模式串位置p_index、文本串位置t_index等。在时钟…

C数据类型(C语言)---变量的类型决定了什么?

目录 数据类型&#xff08;Data Type&#xff09; 变量的类型决定了什么&#xff1f; &#xff08;1&#xff09;不同类型数据占用的内存大小不同 如何计算变量或类型占内存的大小 &#xff08;2&#xff09;不同数据类型的表数范围不同 &#xff08;3&#xff09;不同类型…

Python基础----数据容器(持续更新中)

学习目标 1、容器里面都有什么 2、容器怎么进行切片 python里面基本的数据类型都有什么 布尔类型、整型、浮点型、字符串 (都是不可变的&#xff0c;一旦创建数据内容不可更改&#xff0c;只能更改指向内存) python中可以划分为&#xff1a;数字型、非数字型 数字型&#xf…

异步操作错误之回调地狱问题

回调地狱指的是在异步编程中回调函数过多嵌套、代码深层次嵌套&#xff0c;导致代码可读性差、难以维护和调试的情况。这种情况通常出现在多个异步操作依赖于前一个异步操作结果的情况下&#xff0c;多次嵌套回调函数会形成回调金字塔&#xff0c;代码呈现出嵌套的结构&#xf…

初学者指南 | PostgreSQL中的加密机制如何运作?

在这篇文章中&#xff0c;我们将介绍可用于加密和解密PostgreSQL数据库中数据的不同方法。拥有一些 Linux 和 PostgreSQL 经验是必要的&#xff0c;但拥有加密经验并不是必需的&#xff0c;有经验当然更好。本文是使用 Ubuntu 23.04上运行的 PostgreSQL16编写的。首先&#xff…

【Golang星辰图】抵御恶意攻击:利用Go语言的安全库构建可靠的应用程序

加固你的代码&#xff1a;了解Go语言中的安全库和技术 前言 在当今数字化的世界中&#xff0c;保护代码和数据的安全性变得至关重要。恶意攻击、数据泄露和其他安全漏洞可能给我们的系统和用户带来巨大的风险和损失。为了增强软件的安全性和可靠性&#xff0c;我们需要利用现…

使用Qt在小米平板上热点使用问题记录

Qt程序安卓平板上在运行了差不多两个月后&#xff0c;突然出现图像画面严重卡顿&#xff0c;经过问题定位发现是热点模块在接收数据后出现延迟 第一次解决是尝试设置平板的设置&#xff0c;重启等等&#xff0c;无法解决&#xff0c;然后平板恢复出厂设置&#xff0c;解决了&a…

元素定位之xpath和css

元素定位 xpath绝对路径相对路径案例xpath策略&#xff08;路径&#xff09;案例xpath策略&#xff08;层级、扩展&#xff09;属性层级与属性层级与属性拓展层级与属性综合 csscss选择器&#xff08;id、类、标签、属性&#xff09;id选择器类选择器标签选择器属性选择器案例-…

Spark源码(一)-SparkRPC示例

一、何为SparkRPC RPC全称为远程过程调用&#xff08;Remote Procedure Call&#xff09;&#xff0c;它是一种计算机通信协议&#xff0c;允许一个计算机程序调用另一个计算机上的子程序&#xff0c;而无需了解底层网络细节。通过RPC&#xff0c;一个计算机程序可以像调用本地…

谷歌Gemma大模型部署记录

谷歌Gemma大模型部署记录 配置信息 1.系统&#xff1a;Ubuntu20 2.显卡&#xff1a;RTX3060 6G 一、安装Ollama 官网地址&#xff1a;https://ollama.com/download/linux 按照指令安装 curl -fsSL https://ollama.com/install.sh | sh二、运行模型 输入指令&#xff1a;…

【Java】:类和对象

1.面向对象的初步认知 1.1 什么是面向对象 Java是一门面向对象的语言&#xff0c;在面向对象的世界里&#xff0c;一切皆为对象。面向对象是解决问题的一种思想&#xff0c;主要依靠对象之间的交互完成一件事情。用面向对象的思想来涉及程序&#xff0c;更符合人们对事物的认知…