pytorch生成CAM热力图-单张图像

利用ImageNet预训练模型生成CAM热力图-单张图像

  • 一、环境搭建
  • 二、主要代码
  • 三、结果展示

代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄
讲解视频:CAM可解释性分析-算法讲解

一、环境搭建

1,安装所需的包

pip install numpy pandas matplotlib requests tqdm opencv-python pillow -i https://pypi.tuna.tsinghua.edu.cn/simple

2,安装 Pytorch

pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

3,安装 mmcv-full

pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html

4,下载中文字体文件(用于显示和打印汉字文字)

wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf

5,下载 ImageNet 1000类别信息

wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/meta_data/imagenet_class_index.csv

6,创建 test_img 文件夹,并下载测试图像到该文件夹

import os
os.mkdir('test_img')wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/border-collie.jpg -P test_img
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/cat_dog.jpg -P test_img
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/0818/room_video.mp4 -P test_img

7,下载安装 torchcam

git clone https://github.com/frgfm/torch-cam.git
pip install -e torch-cam/.

二、主要代码

from PIL import Imageimport torch
# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)# 导入ImageNet预训练模型
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval().to(device)# 导入自己训练的模型
# model = torch.load('自己训练的模型.pth')
# model = model.eval().to(device)# 可解释性分析方法有:CAM GradCAM GradCAMpp ISCAM LayerCAM SSCAM ScoreCAM SmoothGradCAMpp XGradCAM# 方法一:导入可解释性分析方法SmoothGradCAMpp
# from torchcam.methods import SmoothGradCAMpp 
# cam_extractor = SmoothGradCAMpp(model)# 方法二:导入可解释性分析方法GradCAM
from torchcam.methods import GradCAM
target_layer = model.layer4[-1]    # 选择目标层
cam_extractor = GradCAM(model, target_layer)# 图片预处理
from torchvision import transforms
# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 图片分类预测
img_path = 'test_img/border-collie.jpg'
img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
pred_logits = model(input_tensor)
# topk()方法用于返回输入数据中特定维度上的前k个最大的元素
pred_top1 = torch.topk(pred_logits, 1)
# pred_id 为图片所属分类对应的索引号,分类和索引号存储在imagenet_class_index.csv
pred_id = pred_top1[1].detach().cpu().numpy().squeeze().item()# 生成可解释性分析热力图
activation_map = cam_extractor(pred_id, pred_logits)
activation_map = activation_map[0][0].detach().cpu().numpy()# 可视化
from torchcam.utils import overlay_mask# overlay_mask 用于构建透明的叠加层
# fromarray 实现array到image的转换
result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.7)# 为图片添加中文类别显示# 载入ImageNet 1000 类别中文释义
import pandas as pd
df = pd.read_csv('imagenet_class_index.csv')
idx_to_labels = {}
idx_to_labels_cn = {}
for idx, row in df.iterrows():idx_to_labels[row['ID']] = row['class']idx_to_labels_cn[row['ID']] = row['Chinese']# 显示所有中文类别
# idx_to_labels_cn# 可视化热力图的类别ID,如果为 None,则为置信度最高的预测类别ID
# show_class_id = 231		# 例如 牧羊犬:231 虎猫:281
show_class_id = None# 可视化热力图的类别ID,如果不指定,则为置信度最高的预测类别ID
if show_class_id:show_id = show_class_id
else:show_id = pred_idshow_class_id = pred_id# 是否显示中文类别
Chinese = True
# Chinese = Falsefrom PIL import ImageDraw
# 在图像上写字
draw = ImageDraw.Draw(result)if Chinese:# 在图像上写中文text_pred = 'Pred Class: {}'.format(idx_to_labels_cn[pred_id])text_show = 'Show Class: {}'.format(idx_to_labels_cn[show_class_id])
else:# 在图像上写英文text_pred = 'Pred Class: {}'.format(idx_to_labels[pred_id])text_show = 'Show Class: {}'.format(idx_to_labels[show_class_id])from PIL import ImageFont, ImageDraw
# 导入中文字体,指定字体大小
font = ImageFont.truetype('SimHei.ttf', 30)# 文字坐标,中文字符串,字体,rgba颜色
draw.text((10, 10), text_pred, font=font, fill=(255, 0, 0, 1))
draw.text((10, 50), text_show, font=font, fill=(255, 0, 0, 1))#输出结果图
result

注意:

  1. 可解释性方法的选择有多种,代码中提供了 SmoothGradCAMpp 和 GradCAM 两种方法;
  2. 模型选择也有pytorch预训练模型和自己训练的模型两种,代码中演示了 ImageNet图像分类 模型,图片类别文件为 imagenet_class_index.csv;若为自己的模型则还需要修改 “为图片载入类别的部分代码”

三、结果展示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/79714.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

原生微信小程序中进行 API 请求

原生微信小程序中进行 API 请求 当在原生微信小程序中进行 API 请求时,封装请求可以提高代码的可维护性和可扩展性。在本篇博客中,我们将一步步介绍如何进一步封装请求,并添加请求超时、拦截器和请求取消功能。 第一步:基本请求封…

混淆矩阵细致理解

1、什么是混淆矩阵 混淆矩阵(Confusion Matrix)是深度学习和机器学习领域中的一个重要工具,用于评估分类模型的性能。它提供了一个清晰的视觉方式来展示模型的预测结果与真实标签之间的关系,尤其在分类任务中,帮助我们…

【Unity基础】2.网格材质贴图与资源打包

【Unity基础】2.网格材质贴图与资源打包 大家好,我是Lampard~~ 欢迎来到Unity基础系列博客,所学知识来自B站阿发老师~感谢 (一)网格材质纹理 第一次接触3D物体的话,会觉得好神奇啊,这个物体究竟是由什么组…

基于安卓Java试题库在线考试系统uniapp 微信小程序

本文首先分析了题库app应用程序的需求,从系统开发环境、系统目标、设计流程、功能设计等几个方面对系统进行了系统设计。开发出本题库app,主要实现了学生、教师、测试卷、试题、考试等。总体设计主要包括系统功能设计、该系统里充分综合应用Mysql数据库、…

Ae 效果:CC Particle Systems II

模拟/CC Particle Systems II Simulation/CC Particle Systems II CC Particle Systems II(CC 粒子系统 II)可用于生成和模拟各种类型的粒子系统,包括火焰、雨、雪、爆炸、烟雾等等。 与 CC Particle World 效果相比有许多类似的属性。最大的…

前端该了解的网络知识

网络 前端开发需要了解的网络知识 URL URL(uniform resource locator,统一资源定位符)用于定位网络服务. URL是一个固定格式的字符串 它表达了: 从网络中哪台计算机(domain)中的哪个服务(port),获取服务器上资源的路径(path),以及要用什么样的协议通信(schema). 注意: 当…

C# wpf 实现桌面放大镜

文章目录 前言一、如何实现?1、制作无边框窗口2、Viewbox放大3、截屏显示(1)、截屏(2)、转BitmapSource(3)、显示 4、定时截屏 二、完整代码三、效果预览总结 前言 做桌面截屏功能时需要放大镜…

卫星物联网生态建设全面加速,如何抓住机遇?

当前,卫星通信无疑是行业最热门的话题之一。近期发布的华为Mate 60 Pro“向上捅破天”技术再次升级,成为全球首款支持卫星通话的大众智能手机,支持拨打和接听卫星电话,还可自由编辑卫星消息。 据悉,华为手机的卫星通话…

【Unity每日一记】资源加载相关和检测相关

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:uni…

【计算机网络】Tcp详解

文章目录 前言Tcp协议段格式TCP的可靠性面向字节流应答机制超时重传流量控制滑动窗口(重要)拥塞控制延迟应答捎带应答标志位具体标志位三次握手四次挥手粘包问题TCP异常情况listen的第二个参数 前言 前面我们学习了传输层协议Udp,今天我们一…

使用FFmpeg+ubuntu系统转化flac无损音频为mp3

功能需求如上题,我们来具体的操作一下: 1.先在ubuntu上面安装FFmpeg:sudo apt install ffmpeg 2.进入有flac音频文件的目录使用下述命令: ffmpeg -i test.FLAC -c:a libmp3lame -q:a 2 output.mp3 3.如果没有什么意外的话,你就能看到你的文件夹里面已经有转化好的mp3文件了 批…

ubuntu中如何用docker下载华为opengauss数据库(超简单)

ubuntu中如何下载华为opengauss数据库 前言一、安装docker1.方法一:2.方法二 二、拉取openguass镜像三、创建容器四、连接数据库 ,切换到omm用户 ,用gsql连接到数据库五.最后用DateGrip远程连接测试(1)选择数据源(2)查看虚拟机ip地…

#循循渐进学51单片机#定时器与数码管#not.4

1、熟练掌握单片机定时器的原理和应用方法。 1)时钟周期:单片机时序中的最小单位,具体计算的方法就是时钟源分之一。 2)机器周期:我们的单片机完成一个操作的最短时间。 3)定时器:打开定时器“储存寄存器…

Python提取JSON数据中的键值对并保存为.csv文件

本文介绍基于Python,读取JSON文件数据,并将JSON文件中指定的键值对数据转换为.csv格式文件的方法。 在之前的文章Python提取JSON文件中的指定数据并保存在CSV或Excel表格文件内(https://blog.csdn.net/zhebushibiaoshifu/article/details/132…

Windows PostgreSql 创建多个数据库目录

1 使用默认用户Administrator 1.1初始化数据库目录 E:\Program Files\PostgreSQL\13> .\bin\initdb -D G:\DATA\pgsql\data3 -W -A md5 1.2连接数据库 这时User为Administrator,密码就是你刚才设置的,我设置的为123456,方便测试。 2 添加…

黑马JVM总结(九)

(1)StringTable_调优1 我们知道StringTable底层是一个哈希表,哈希表的性能是跟它的大小相关的,如果哈希表这个桶的个数比较多,元素相对分散,哈希碰撞的几率就会减少,查找的速度较快&#xff0c…

【微服务】六. Nacos配置管理

6.1 Nacos实现配置管理 配置更改热更新 在nacos左侧新建配置管理 Data ID:就是配置文件名称 一般命名规则:服务名称-环境名称.yaml 配置内容填写:需要热更新需求的配置 配置文件的id:[服务名称]-[profile].[后缀名] 分组&#…

Vuex详解:Vue.js的状态管理方案

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

启动微服务,提示驱动程序无法通过使用安全套接字层(SSL)加密与 SQL Server 建立安全连接

说明:启动一些微服务后,一直在报下面这个错误; com.microsoft.sqlserver.jdbc.SQLServerException: 驱动程序无法通过使用安全套接字层(SSL)加密与 SQL Server 建立安全连接。错误:“The server selected protocol version TLS10 is not acc…

uniapp抽取组件绑定事件中箭头函数含花括号无法解析

版本: "dcloudio/uni-ui": "^1.4.27", "vue": "> 2.6.14 < 2.7"... 箭头函数后含有花括号的时候, getData就拿不到val参数 , 解决办法就是去除花括号 // 错误代码: <SearchComp change"(val) > { getData({ val …