深度学习绘制热力图heatmap、使模型具有可解释性

思路

获取想要解释的那一层的特征图,然后根据特征图梯度计算出权重值,加在原图上面。

Demo

在这里插入图片描述
加上类激活(cam)
在这里插入图片描述
可以看到,cam将模型认为有利于分类的特征标注了出来。
下面以ResNet50为例:
Trick:
使用

for i in model._modules.items():

可以获得模型名称和对应层。

# coding: utf-8
import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as pltimport torch
import torch.autograd as autograd
import torchvision.transforms as transformsimport torchvision.models as models# 训练过的模型路径
#resume_path = r"D:\TJU\GBDB\set113\cross_validation\test1\epoch_0257_checkpoint.pth.tar"
# 输入图像路径
single_img_path = r'bicycle.jpg'
# 绘制的热力图存储路径
save_path = r'heatmap/bicycle_layer4.jpg'# 网络层的层名列表, 需要根据实际使用网络进行修改
layers_names = ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']
# 指定层名
out_layer_name = "layer4"features_grad = 0# 为了读取模型中间参数变量的梯度而定义的辅助函数
def extract(g):global features_gradfeatures_grad = gdef draw_CAM(model, img_path, save_path, transform=None, visual_heatmap=False, out_layer=None):"""绘制 Class Activation Map:param model: 加载好权重的Pytorch model:param img_path: 测试图片路径:param save_path: CAM结果保存路径:param transform: 输入图像预处理方法:param visual_heatmap: 是否可视化原始heatmap(调用matplotlib):return:"""# 读取图像并预处理global layer2img = Image.open(img_path).convert('RGB')if transform:img = transform(img)img = img.unsqueeze(0)  # (1, 3, 448, 448)# model转为eval模式model.eval()# 获取模型层的字典layers_dict = {layers_names[i]: None for i in range(len(layers_names))}for name,module in model._modules.items():#print(i, (name, module))layers_dict[name] = module# 遍历模型的每一层, 获得指定层的输出特征图# features: 指定层输出的特征图, features_flatten: 为继续完成前端传播而设置的变量features = imgstart_flatten = Falsefeatures_flatten = Nonefor name, layer in layers_dict.items():if name != out_layer and start_flatten is False:    # 指定层之前features = layer(features)elif name == out_layer and start_flatten is False:  # 指定层features = layer(features)start_flatten = Trueelse:   # 指定层之后if name == "fc":breakif features_flatten is None:features_flatten = layer(features)else:features_flatten = layer(features_flatten)#print(features_flatten.shape)features_flatten = torch.flatten(features_flatten, 1)#print(features_flatten.shape)output = model.fc(features_flatten)# 预测得分最高的那一类对应的输出scorepred = torch.argmax(output, 1).item()pred_class = output[:, pred]# 求中间变量features的梯度# 方法1# features.register_hook(extract)# pred_class.backward()# 方法2features_grad = autograd.grad(pred_class, features, allow_unused=True)[0]grads = features_grad  # 获取梯度pooled_grads = torch.nn.functional.adaptive_avg_pool2d(grads, (1, 1))# 此处batch size默认为1,所以去掉了第0维(batch size维)pooled_grads = pooled_grads[0]features = features[0]print("pooled_grads:", pooled_grads.shape)print("features:", features.shape)# features.shape[0]是指定层feature的通道数for i in range(features.shape[0]):features[i, ...] *= pooled_grads[i, ...]# 计算heatmapheatmap = features.detach().cpu().numpy()heatmap = np.mean(heatmap, axis=0)heatmap = np.maximum(heatmap, 0)heatmap /= np.max(heatmap)# 可视化原始热力图if visual_heatmap:plt.matshow(heatmap)plt.show()img = cv2.imread(img_path)  # 用cv2加载原始图像heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))  # 将热力图的大小调整为与原始图像相同heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将热力图应用于原始图像superimposed_img = heatmap * 0.7 + img  # 这里的0.4是热力图强度因子cv2.imwrite(save_path, superimposed_img)  # 将图像保存到硬盘if __name__ == '__main__':model = models.resnet50(pretrained=True)#model.eval()transform = transforms.Compose([transforms.Resize(448),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])# 构建模型并加载预训练参数#seresnet50 = FineTuneSEResnet50(num_class=113).cuda()#checkpoint = torch.load(resume_path)#seresnet50.load_state_dict(checkpoint['state_dict'])draw_CAM(model, single_img_path, save_path, transform=transform, visual_heatmap=True, out_layer=out_layer_name)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/759048.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习500问——Chapter03:深度学习基础(4)

文章目录 3.7 预训练与微调(fine tuning) 3.7.1 为什么无监督预训练可以帮助深度学习 3.7.2 什么是模型微调 fine tuning 3.7.3 微调时候网络参数是否更新 3.7.4 fine-tuning模型的三种状态 3.8 权重偏差和初始化 3.8.1 全都初始化为0 3.8.2 全都初始化为…

Android Launcher开发注意事项

在开发Android Launcher时,需要关注性能、用户体验、权限管理、兼容性等方面,同时遵循相关的开发者政策和最佳实践。有几个重要的注意事项,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎…

选择word中的表格VBA

打开开发工具 选择Visual Basic插入代码 Sub 选择word中的表格() Dim t As Table an MsgBox("即将选择选区内所有表格,若无选区,则选择全文表格。", vbYesNo, "提示") If an - 6 Then Exit Sub Set rg IIf(Selection.Type wdSel…

[HFCTF 2021 Final]easyflask

[HFCTF 2021 Final]easyflask [[python反序列化]] 首先题目给了提示,有文件读取漏洞,读取源码 #!/usr/bin/python3.6 import os import picklefrom base64 import b64decode from flask import Flask, request, render_template, sessionapp Flask(_…

HarmonyOS NEXT应用开发之侧滑返回事件拦截案例

介绍 在编辑场景中,存在用户误触返回,导致内容未保存就退出编辑页的现象; 本示例介绍使用NavDestination组件的onBackPressed回调对返回事件进行拦截,提示用户保存编辑内容,并使用preferences实例持久化保存内容。 效果预览图 使…

C数据类型(C语言)---变量的类型决定了什么?

目录 数据类型(Data Type) 变量的类型决定了什么? (1)不同类型数据占用的内存大小不同 如何计算变量或类型占内存的大小 (2)不同数据类型的表数范围不同 (3)不同类型…

元素定位之xpath和css

元素定位 xpath绝对路径相对路径案例xpath策略(路径)案例xpath策略(层级、扩展)属性层级与属性层级与属性拓展层级与属性综合 csscss选择器(id、类、标签、属性)id选择器类选择器标签选择器属性选择器案例-…

Spark源码(一)-SparkRPC示例

一、何为SparkRPC RPC全称为远程过程调用(Remote Procedure Call),它是一种计算机通信协议,允许一个计算机程序调用另一个计算机上的子程序,而无需了解底层网络细节。通过RPC,一个计算机程序可以像调用本地…

谷歌Gemma大模型部署记录

谷歌Gemma大模型部署记录 配置信息 1.系统:Ubuntu20 2.显卡:RTX3060 6G 一、安装Ollama 官网地址:https://ollama.com/download/linux 按照指令安装 curl -fsSL https://ollama.com/install.sh | sh二、运行模型 输入指令:…

【Java】:类和对象

1.面向对象的初步认知 1.1 什么是面向对象 Java是一门面向对象的语言,在面向对象的世界里,一切皆为对象。面向对象是解决问题的一种思想,主要依靠对象之间的交互完成一件事情。用面向对象的思想来涉及程序,更符合人们对事物的认知…

【LeetCode-114.二叉树展开为链表】

题目详情: 给你二叉树的根结点 root ,请你将它展开为一个单链表: 展开后的单链表应该同样使用 TreeNode ,其中 right 子指针指向链表中下一个结点,而左子指针始终为 null 。展开后的单链表应该与二叉树 先序遍历 顺序…

seleniumUI自动化实例(CSDN发布文章)

1.CSDN登陆成功后,点击发布 源码: #点击首页中的发布按钮 CSDNconf.driver.find_element(By.LINK_TEXT,"发布").click() time.sleep(15) 2.输入标题 #输入文章标题,标题格式“selenium UI自动化测试实例今天的日期” CSDNconf.d…

POI和EasyExcel区别和操作Excel

POI和EasyExcel操作Excel 常用场景 1、将用户信息导出为excel表格(导出数据… ) 2、将Excel表中的信息录入到网站数据库(文件数据上传… ) 开发中经常会设计到excel的处理,如导出Excel,导入Excel到数据库…

springboot+itextpdf+thymeleaf+ognl根据静态模版文件实现动态生成pdf文件并导出demo

第一步&#xff1a;导入maven依赖 <!-- 导出为PDF依赖包 --><dependency><groupId>com.itextpdf</groupId><artifactId>itextpdf</artifactId></dependency><dependency><groupId>com.itextpdf</groupId><art…

HarmonyOS(鸿蒙)应用开发——(一)

目录 1 创建hellopro项目 2 了解ArkTS 3 了解ArkTS的组件 4 组件介绍 4.1 常用基础组件&#xff1a; 4.1.1 Text 4.1.2 Button 4.1.3 TextInput 4.2 容器组件 4.2.1 Column 4.2.2 Row 5 案例——实现一个简易登录页面 5.1 在实现预览效果之前&#xff0c;我们…

【机器学习】基于果蝇算法优化的BP神经网络分类预测(FOA-BP)

目录 1.原理与思路2.设计与实现3.结果预测4.代码获取 1.原理与思路 【智能算法应用】智能算法优化BP神经网络思路【智能算法】果蝇算法&#xff08;FOA&#xff09;原理及实现 2.设计与实现 数据集&#xff1a; 多输入多输出&#xff1a;样本特征24&#xff0c;标签类别4。…

【计算机视觉】三、图像处理——实验:图像去模糊和去噪、提取边缘特征

文章目录 0. 实验环境1. 理论基础1.1 滤波器&#xff08;卷积核&#xff09;1.2 PyTorch:卷积操作 2. 图像处理2.1 图像读取2.2 查看通道2.3 图像处理 3. 图像去模糊4. 图像去噪4.1 添加随机噪点4.2 图像去噪 0. 实验环境 本实验使用了PyTorch深度学习框架&#xff0c;相关操作…

bezier曲线拟合椭圆弧线

椭圆弧线用bezier曲线拟合 。 先计算出 椭圆中心 起始角度 旋转角度 S t e p 1 : C o m p u t e ( x 1 ′ , y 1 ′ ) Step 1: Compute(x_1, y_1) Step1:Compute(x1′​,y1′​) ( x 1 ′ y 1 ′ ) ( cos ⁡ φ sin ⁡ φ − sin ⁡ φ cos ⁡ φ ) ⋅ ( x 1 − x 2 2 y 1 −…

some/ip CAN CANFD

关于SOME/IP的理解 在CAN总线的车载网络中&#xff0c;通信过程是面向信号的 当ECU的信号的值发生了改变&#xff0c;或者发送周期到了&#xff0c;就会发送消息&#xff0c;而不考虑接收者是否需要&#xff0c;这样就会造成总线上出现不必要的信息&#xff0c;占用了带宽 …

RabbitMQ详细讲解

目录 4.0 AMQP协议的回顾 4.1 RabbitMQ支持的消息模型 4.2 引入依赖 4.3 第一种模型(直连) 1. 开发生产者 2. 开发消费者 3. 参数的说明 4.4 第二种模型(work quene) 1. 开发生产者 2.开发消费者-1 3.开发消费者-2 4.测试结果 5.消息自动确认机制 4.5 第三种模型(…