yolov8通过训练完成的模型生成图片热力图--论文需要

源代码来自于网络

使用pytorch_grad_cam,对特定图片生成热力图结果。
请添加图片描述

安装热力图工具

pip install pytorch_grad_cam
pip install grad-cam
# get_params中的参数:
# weight:
#         模型权重文件,代码默认是yolov8m.pt
# cfg:
#         模型文件,代码默认是yolov8m.yaml,需要注意的是需要跟weight中的预训练文件的配置是一样的,不然会报错
# device:
#         选择使用GPU还是CPU
# method:
#         选择grad-cam方法,默认是GradCAM,这里是提供了几种,可能对效果有点不一样,大家大胆尝试。
# layer::
#         选择需要可视化的层数,只需要修改数字即可,比如想用第9层,也就是model.model[9]。
# backward_type:
#         反向传播的方式,可以是以conf的loss传播,也可以class的loss传播,一般选用all,效果比较好一点。
# conf_threshold:
#         置信度,默认是0.6。
# ratio:
#         默认是0.02,就是用来筛选置信度高的结果,低的就舍弃,0.02则是筛选置信度最高的前2%的图像来进行热力图。![请添加图片描述](https://img-blog.csdnimg.cn/direct/4403f71e29314c68909ca28c037bd2b2.png)
import warningswarnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import torch, cv2, os, shutil
import numpy as npnp.random.seed(0)
import matplotlib.pyplot as plt
from tqdm import trange
from PIL import Image
from ultralytics.nn.tasks import DetectionModel as Model
from ultralytics.utils.torch_utils import intersect_dicts
from ultralytics.utils.ops import xywh2xyxy
from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradientsdef letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):# Resize and pad image while meeting stride-multiple constraintsshape = im.shape[:2]  # current shape [height, width]if isinstance(new_shape, int):new_shape = (new_shape, new_shape)# Scale ratio (new / old)r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])if not scaleup:  # only scale down, do not scale up (for better val mAP)r = min(r, 1.0)# Compute paddingratio = r, r  # width, height ratiosnew_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh paddingif auto:  # minimum rectangledw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh paddingelif scaleFill:  # stretchdw, dh = 0.0, 0.0new_unpad = (new_shape[1], new_shape[0])ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratiosdw /= 2  # divide padding into 2 sidesdh /= 2if shape[::-1] != new_unpad:  # resizeim = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))left, right = int(round(dw - 0.1)), int(round(dw + 0.1))im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add borderreturn im, ratio, (dw, dh)class yolov8_heatmap:def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):device = torch.device(device)ckpt = torch.load(weight)model_names = ckpt['model'].namescsd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32model = Model(cfg, ch=3, nc=len(model_names)).to(device)csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor'])  # intersectmodel.load_state_dict(csd, strict=False)  # loadmodel.eval()print(f'Transferred {len(csd)}/{len(model.state_dict())} items')target_layers = [eval(layer)]method = eval(method)colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int32)self.__dict__.update(locals())def post_process(self, result):logits_ = result[:, 4:]boxes_ = result[:, :4]sorted, indices = torch.sort(logits_.max(1)[0], descending=True)return torch.transpose(logits_[0], dim0=0, dim1=1)[indices[0]], torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]], xywh2xyxy(torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]]).cpu().detach().numpy()def draw_detections(self, box, color, name, img):xmin, ymin, xmax, ymax = list(map(int, list(box)))cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2,lineType=cv2.LINE_AA)return imgdef __call__(self, img_path, save_path):# remove dir if existif os.path.exists(save_path):shutil.rmtree(save_path)# make dir if not existos.makedirs(save_path, exist_ok=True)# img processimg = cv2.imread(img_path)img = letterbox(img)[0]img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = np.float32(img) / 255.0tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)# init ActivationsAndGradientsgrads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)# get ActivationsAndResultresult = grads(tensor)activations = grads.activations[0].cpu().detach().numpy()# postprocess to yolo outputpost_result, pre_post_boxes, post_boxes = self.post_process(result[0])print(post_result.size(0))for i in trange(int(post_result.size(0) * self.ratio)):if float(post_result[i].max()) < self.conf_threshold:breakself.model.zero_grad()# get max probability for this predictionif self.backward_type == 'class' or self.backward_type == 'all':score = post_result[i].max()score.backward(retain_graph=True)if self.backward_type == 'box' or self.backward_type == 'all':for j in range(4):score = pre_post_boxes[i, j]score.backward(retain_graph=True)# process heatmapif self.backward_type == 'class':gradients = grads.gradients[0]elif self.backward_type == 'box':gradients = grads.gradients[0] + grads.gradients[1] + grads.gradients[2] + grads.gradients[3]else:gradients = grads.gradients[0] + grads.gradients[1] + grads.gradients[2] + grads.gradients[3] + \grads.gradients[4]b, k, u, v = gradients.size()weights = self.method.get_cam_weights(self.method, None, None, None, activations,gradients.detach().numpy())weights = weights.reshape((b, k, 1, 1))saliency_map = np.sum(weights * activations, axis=1)saliency_map = np.squeeze(np.maximum(saliency_map, 0))saliency_map = cv2.resize(saliency_map, (tensor.size(3), tensor.size(2)))saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()if (saliency_map_max - saliency_map_min) == 0:continuesaliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)# add heatmap and box to imagecam_image = show_cam_on_image(img.copy(), saliency_map, use_rgb=True)cam_image = Image.fromarray(cam_image)cam_image.save(f'{save_path}/{i}.png')def get_params():params = {'weight': './weights/bz-yolov8-aspp-s-100.pt', # 这选择想要热力可视化的模型权重路径'cfg': './ultralytics/cfg/models/cfg2024/YOLOv8-金字塔结构改进/YOLOv8-ASPP.yaml', # 这里选择与训练上面模型权重相对应的.yaml文件路径'device': 'cpu', # 选择设备,其中0表示0号显卡。如果使用CPU可视化 # 'device': 'cpu' cuda:0'method': 'GradCAM', # GradCAMPlusPlus, GradCAM, XGradCAM'layer': 'model.model[6]',   # 选择特征层'backward_type': 'all', # class, box, all'conf_threshold': 0.65, # 置信度阈值默认0.65, 可根据情况调节'ratio': 0.02 # 取前多少数据,默认是0.02,可根据情况调节}return paramsif __name__ == '__main__':model = yolov8_heatmap(**get_params()) # 初始化model('output_002.jpg', './result') # 第一个参数是图片的路径,第二个参数是保存路径,比如是result的话,其会创建一个名字为result的文件夹,如果result文件夹不为空,其会先清空文件夹。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/26537.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【设计模式】行为型-模板方法模式

方法千变万化&#xff0c;心灵如潮&#xff0c;模板如画&#xff0c;画出生活的韵味。 文章目录 一、茶与咖啡二、模板方法模式三、模板方法模式的核心组成四、运用模板方法模式五、模板方法模式的应用场景六、小结推荐阅读 一、茶与咖啡 场景假设&#xff1a;我们需要完成茶…

基于Unet++在kaggle—2018dsb数据集上实现图像分割

目录 1. 作者介绍2. 理论知识介绍2.1 Unet模型介绍 3. 实验过程3.1 数据集介绍3.2 代码实现3.3 结果 4. 参考链接 1. 作者介绍 郭冠群&#xff0c;男&#xff0c;西安工程大学电子信息学院&#xff0c;2023级研究生 研究方向&#xff1a;机器视觉与人工智能 电子邮件&#xff…

Go变量作用域精讲及代码实战

1. 变量的作用域概述 在编程中&#xff0c;变量的作用域&#xff08;Scope&#xff09;定义了变量在程序中的可见性和生命周期。理解变量的作用域对于编写健壮且可维护的代码至关重要。Go语言&#xff08;简称Go&#xff09;提供了几种不同的作用域类型&#xff0c;使得开发者可…

在大数据时代:为何硬盘仍是数据中心存储的核心

在云计算和人工智能应用场景不断涌现的时代背景下&#xff0c;数据集的价值急剧上升&#xff0c;硬盘对于数据中心运营商来说变得比以往任何时候都更为关键。硬盘存储了全球大部分的艾字节&#xff08;EB&#xff09;数据&#xff0c;行业分析师预计&#xff0c;在艾字节持续增…

Oracle数据库面试题-10

1. 描述Oracle数据库体系结构的主要组件。 Oracle数据库体系结构由多个组件组成&#xff0c;这些组件协同工作以确保数据的存储、处理和安全性。以下是Oracle数据库的一些主要组件&#xff1a; 数据库实例&#xff08;Database Instance&#xff09;&#xff1a;Oracle数据库的…

华为手机USB调试调过登录

【抓包工具】配置&#xff1a;绕过华为手机打开 USB 调试需要先登录华为账号问题 参考上面的文章。但是可能因为没有登录账号&#xff0c;没法切到生产模式。 登录荣耀账号&#xff0c;再试就可以了&#xff0c;记得默认允许电脑调试&#xff0c;然后退出荣耀账号

C++:十大排序

目录 时间复杂度分析 选择排序 引言 算法思想 动图展示 代码实现 (升序) 优化 代码实现 分析 冒泡排序 引言 算法思想 动图展示 代码实现 插入排序 引言 算法思想 动图展示 代码实现 计数排序 引言 算法思想 动图展示 代码实现 桶排序 引言 算法思…

利安科技上市首日股价大涨:2023营收净利润下滑,募资金额大幅缩水

《港湾商业观察》施子夫 6月7日&#xff0c;宁波利安科技股份有限公司&#xff08;以下简称&#xff0c;利安科技&#xff09;正式在深交所创业板挂牌上市&#xff0c;股票简称为利安科技&#xff0c;股票代码300784。 上市当天&#xff0c;利安科技股价大涨348.76%。 2022年…

46.Python-web框架-Django - 多语言配置

目录 1.Django 多语言基础知识 1.1什么是Django国际化和本地化&#xff1f; 1.2Django LANGUAGE_CODE 1.3关于languages 1.4RequestContext对象针对翻译的变量 2.windows系统下的依赖 3.django多语言配置 3.1settings.py配置 引用gettext_lazy 配置多语言中间件&#x…

深入理解Elasticsearch集群:节点与分片机制

Elasticsearch作为当下最流行的开源搜索引擎和数据分析引擎之一&#xff0c;其强大的分布式集群能力和可扩展性是其核心优势。在Elasticsearch集群中&#xff0c;节点&#xff08;Node&#xff09;和分片&#xff08;Shard&#xff09;是两个核心概念&#xff0c;它们共同构成了…

PyTorch -- 最常见损失函数 LOSS 的选择

损失函数&#xff1a;度量模型的预测结果与真实值之间的差异&#xff1b;通过最小化 loss -> 最大化模型表现代码实现框架&#xff1a;设有 模型预测值 f (x), 真实值 y 方法一&#xff1a; 步骤 1. criterion torch.nn.某个Loss()&#xff1b;步骤 2. loss criterion(f(x…

广州·2025全国眼睛健康产业博览会眼科医学大会|全国眼博会

广州2025全国眼睛健康产业博览会眼科医学大会&#xff0c;2025年4月10-12日&#xff0c;在广州南丰国际会展中心举办&#xff1b; ——随着时代的进步和科技的飞速发展&#xff0c;人们的眼睛健康问题日益受到关注。为了推动眼睛健康产业的持续发展&#xff0c;加强眼科医学的…

实施ISO 26262与ISO 21434的关键要素分析

随着汽车工业的快速发展和智能化水平的不断提升&#xff0c;汽车的功能性和安全性成为了消费者关注的重点。为了确保车辆的安全性和可靠性&#xff0c;国际标准化组织&#xff08;ISO&#xff09;制定了一系列与汽车安全相关的标准&#xff0c;其中ISO 26262&#xff08;道路车…

set与map的详细封装步骤

目录 一.set与map在STL中的源码 二.修改红黑树 1.插入与查找时的比较方式 2.插入时的返回值 3.补充成员函数 三.封装set与map 1.迭代器的实现 2.函数接口 3.map中的operator[] 四.完整代码 set.h map.h RBTree.h 一.set与map在STL中的源码 想要简单实现set与map 需…

短视频矩阵工具有哪些?如何辨别是否正规?

随着短视频平台的持续火爆&#xff0c;搭建短视频矩阵成为各大品牌商家提高营销效果和完成流量变现的主要方式之一&#xff0c;类似于短视频矩阵工具有哪些等问题也在多个社群有着不小的讨论度。 而就短视频矩阵工具的市场现状而言&#xff0c;其整体呈现出数量不断增长&#x…

使用神卓互联来访问单位内部web【内网穿透神器】

在现代工作环境中&#xff0c;有时我们需要从外部访问单位内部的 web 资源&#xff0c;而神卓互联这款内网穿透神器就能完美地满足这一需求。 使用神卓互联来访问单位内部 web 其实并不复杂&#xff0c;以下是大致的使用步骤和配置方法。 首先&#xff0c;我们需要在单位内部的…

Three.js做了一个网页版的我的世界

前言 笔者在前一阵子接触到 Three.js 后, 发现了它能为前端 3D 可视化 / 动画 / 游戏方向带来的无限可能, 正好最近在与朋友重温我的世界, 便有了用 Three.js 来仿制 MineCraft 的想法, 正好也可以通过一个有趣的项目来学习一下前端 3D 领域 介绍 游戏介绍 相信大家对我的世…

模式识别与机器学习复习题解析(2023春)

文章目录 一、判断题二、填空题三、单选题四、简答题relu激活h1layer2h2 h1w2b2relu激活h2outputout h2w3 b3 一、判断题 1 单层感知机的局限性&#xff0c;它仅对线性问题具有分类能力( )。T 2.多层感知机的问题是隐藏层的权值无法训练( )。T 3.ReLU和Batch Normalization都…

vue3+ Element-Plus 点击勾选框往input中动态添加多个tag

实现效果&#xff1a; template&#xff1a; <!--产品白名单--><div class"con-item" v-if"current 0"><el-form-item label"平台名称"><div class"contaion" click"onclick"><!-- 生成的标签 …

Unity HoloLens2 MRTK 空间锚点 基础教程

Unity HoloLens2 MRTK 空间锚点 基础教程 Unity HoloLens2 空间锚点MRTK 空间锚点 准备Unity 工程创建设置切换 UWP 平台UWP 平台设置 下载并安装混合现实功能工具导入混合现实工具包和 OpenXR 包 Unity 编辑器 UWP 设置Unity 2019.4.40 设置Unity 2022.3.0 设置Unity 2022.3.0…