Grad_CAM图

我们要将网络学习到的特征进行可视化。

import os
import cv2
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms
from utils import GradCAM, show_cam_on_image, center_crop_img
from lianxi import *
import torch.nn as nn
from dataset import CESM
from torch.utils.data import DataLoader
def main():net = CustomResNet50( in_channels=1,num_classes=2,chunk=1)path1 = r'E:\pycharmproject\CR-DLcode\CR_fussion\MF\checkpoint\resnet50\Thursday_14_March_2024_00h_12m_45s\resnet50-3-best.pth'net.load_state_dict(torch.load(path1))model = nettarget_layers = [net.model.layer4]CESMdata2 = CESM(base_dir=r'F:\CR的均值化数据\test',transform=transforms.Compose([transforms.ToTensor(),]))CESM_10_test_l = DataLoader(CESMdata2, batch_size=1, shuffle=False, drop_last=True,pin_memory=torch.cuda.is_available())for i, x in enumerate(CESM_10_test_l):input_tensor  = x['LOW_ENERGY']data= input_tensor.squeeze(0).numpy()cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)target_category = x['label']  # tabby, tabby catgrayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)grayscale_cam = grayscale_cam[0, :]visualization = show_cam_on_image(data / 255.,grayscale_cam,use_rgb=True)# plt.imshow(visualization)# plt.show()cv2.imshow('Image', visualization)cv2.waitKey(0)cv2.destroyAllWindows()if __name__ == '__main__':main()

 导入我们训练好的模型参数和模型。导入数据。注意我们这里导入的是一维的灰度图像。就可以画出Grade_CAM图。

我们的utils文件

import cv2
import numpy as npclass ActivationsAndGradients:""" Class for extracting activations andregistering gradients from targeted intermediate layers """def \__init__(self, model, target_layers, reshape_transform):self.model = modelself.gradients = []self.activations = []self.reshape_transform = reshape_transformself.handles = []for target_layer in target_layers:self.handles.append(target_layer.register_forward_hook(self.save_activation))# Backward compatibility with older pytorch versions:if hasattr(target_layer, 'register_full_backward_hook'):self.handles.append(target_layer.register_full_backward_hook(self.save_gradient))else:self.handles.append(target_layer.register_backward_hook(self.save_gradient))def save_activation(self, module, input, output):activation = outputif self.reshape_transform is not None:activation = self.reshape_transform(activation)self.activations.append(activation.cpu().detach())def save_gradient(self, module, grad_input, grad_output):# Gradients are computed in reverse ordergrad = grad_output[0]if self.reshape_transform is not None:grad = self.reshape_transform(grad)self.gradients = [grad.cpu().detach()] + self.gradientsdef __call__(self, x):self.gradients = []self.activations = []return self.model(x)def release(self):for handle in self.handles:handle.remove()class GradCAM:def __init__(self,model,target_layers,reshape_transform=None,use_cuda=False):self.model = model.eval()self.target_layers = target_layersself.reshape_transform = reshape_transformself.cuda = use_cudaif self.cuda:self.model = model.cuda()self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)""" Get a vector of weights for every channel in the target layer.Methods that return weights channels,will typically need to only implement this function. """@staticmethoddef get_cam_weights(grads):return np.mean(grads, axis=(2, 3), keepdims=True)@staticmethoddef get_loss(output, target_category):loss = 0for i in range(len(target_category)):loss = loss + output[i, target_category[i]]return lossdef get_cam_image(self, activations, grads):weights = self.get_cam_weights(grads)weighted_activations = weights * activationscam = weighted_activations.sum(axis=1)return cam@staticmethoddef get_target_width_height(input_tensor):width, height = input_tensor.size(-1), input_tensor.size(-2)return width, heightdef compute_cam_per_layer(self, input_tensor):activations_list = [a.cpu().data.numpy()for a in self.activations_and_grads.activations]grads_list = [g.cpu().data.numpy()for g in self.activations_and_grads.gradients]target_size = self.get_target_width_height(input_tensor)cam_per_target_layer = []# Loop over the saliency image from every layerfor layer_activations, layer_grads in zip(activations_list, grads_list):cam = self.get_cam_image(layer_activations, layer_grads)cam[cam < 0] = 0  # works like mute the min-max scale in the function of scale_cam_imagescaled = self.scale_cam_image(cam, target_size)cam_per_target_layer.append(scaled[:, None, :])return cam_per_target_layerdef aggregate_multi_layers(self, cam_per_target_layer):cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)cam_per_target_layer = np.maximum(cam_per_target_layer, 0)result = np.mean(cam_per_target_layer, axis=1)return self.scale_cam_image(result)@staticmethoddef scale_cam_image(cam, target_size=None):result = []for img in cam:img = img - np.min(img)img = img / (1e-7 + np.max(img))if target_size is not None:img = cv2.resize(img, target_size)result.append(img)result = np.float32(result)return resultdef __call__(self, input_tensor, target_category=None):if self.cuda:input_tensor = input_tensor.cuda()# 正向传播得到网络输出logits(未经过softmax)output = self.activations_and_grads(input_tensor)if isinstance(target_category, int):target_category = [target_category] * input_tensor.size(0)if target_category is None:target_category = np.argmax(output.cpu().data.numpy(), axis=-1)print(f"category id: {target_category}")else:assert (len(target_category) == input_tensor.size(0))self.model.zero_grad()loss = self.get_loss(output, target_category)loss.backward(retain_graph=True)# In most of the saliency attribution papers, the saliency is# computed with a single target layer.# Commonly it is the last convolutional layer.# Here we support passing a list with multiple target layers.# It will compute the saliency image for every image,# and then aggregate them (with a default mean aggregation).# This gives you more flexibility in case you just want to# use all conv layers for example, all Batchnorm layers,# or something else.cam_per_layer = self.compute_cam_per_layer(input_tensor)return self.aggregate_multi_layers(cam_per_layer)def __del__(self):self.activations_and_grads.release()def __enter__(self):return selfdef __exit__(self, exc_type, exc_value, exc_tb):self.activations_and_grads.release()if isinstance(exc_value, IndexError):# Handle IndexError here...print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")return Truedef show_cam_on_image(img: np.ndarray,mask: np.ndarray,use_rgb: bool = False,colormap: int = cv2.COLORMAP_JET) -> np.ndarray:""" This function overlays the cam mask on the image as an heatmap.By default the heatmap is in BGR format.:param img: The base image in RGB or BGR format.:param mask: The cam mask.:param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.:param colormap: The OpenCV colormap to be used.:returns: The default image with the cam overlay."""# 将其转换为 NumPy 数组并复制为三通道img = cv2.cvtColor(img.squeeze(0), cv2.COLOR_GRAY2BGR)  # 将单通道灰度图像转换为三通道彩色图像heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)if use_rgb:heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)heatmap = np.float32(heatmap) / 255if np.max(img) > 1:raise Exception("The input image should np.float32 in the range [0, 1]")cam = heatmap*0.002 + imgcam = cam / np.max(cam)return np.uint8(255 * cam)def center_crop_img(img: np.ndarray, size: int):h, w, c = img.shapeif w == h == size:return imgif w < h:ratio = size / wnew_w = sizenew_h = int(h * ratio)else:ratio = size / hnew_h = sizenew_w = int(w * ratio)img = cv2.resize(img, dsize=(new_w, new_h))if new_w == size:h = (new_h - size) // 2img = img[h: h+size]else:w = (new_w - size) // 2img = img[:, w: w+size]return img

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/743048.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ES6(三):Iterator、Generator、类的用法、类的继承

一、迭代器Iterator 迭代器是访问数据的一个接口&#xff0c;是用于遍历数据结构的一个指针&#xff0c;迭代器就是遍历器 const items[one,two,three];//创建新的迭代器const ititems[Symbol.iterator]();console.log(it.next()); done&#xff1a;返回false表示遍历继续&a…

Python 查找PDF中的指定文本并高亮显示

在处理大量PDF文档时&#xff0c;有时我们需要快速找到特定的文本信息。本文将提供以下三个Python示例来帮助你在PDF文件中快速查找并高亮指定的文本。 查找并高亮PDF中所有的指定文本查找并高亮PDF某个区域内的指定文本使用正则表达式搜索指定文本并高亮 本文将用到国产第三方…

linux安全--CentOS7安装Tomcat,远程管理ManagerApp

目录 1.Tomcat安装 2.Tomcat远程管理 1.Tomcat安装 下载安装包并解压 tar xf apache-tomcat-7.0.54.tar.gz -C /usr/local/apache-tomcat_7.0.54/tomcat启停 启动 ./startup.sh 停止 ./shutdown.sh 2.Tomcat远程管理 找到tomcat文件夹中webapps/manager/META-INF/contex…

人工智能(AI)-机器学习-深度学习-大语言模型LLM(chatgtp)

【一文读懂“大语言模型” - CSDN App】 国产大语言模型是指由中国公司或机构开发的大规模预训练语言模型。目前&#xff0c;国产大语言模型主要有以下几种&#xff1a; 中文GPT&#xff08;GPT-3&#xff09;&#xff1a;由华为公司开发&#xff0c;是一个基于Transformer架…

Linux系统---Haproxy高性能负载均衡软件

目录 一、Haproxy介绍 1.Haproxy定义 2.Haproxy主要特性 3.Haproxy调度算法原理 3.1RR&#xff08;Round Robin&#xff09; 3.2LC&#xff08;Least Connections&#xff09; 3.3SH&#xff08;Source Hashing&#xff09; 二、安装Haproxy 1.yum安装 2.第三方rpm包安…

Android中compile,implementation和api的区别,以及gradle-wrapper的详解

前些天发现了一个蛮有意思的人工智能学习网站,8个字形容一下"通俗易懂&#xff0c;风趣幽默"&#xff0c;感觉非常有意思,忍不住分享一下给大家。 &#x1f449;点击跳转到教程 前言&#xff1a; compile,implementation和api的区别和其作用 compile和api会进行传递…

【深度学习目标检测】二十三、基于深度学习的行人检测计数系统-含数据集、GUI和源码(python,yolov8)

行人检测计数系统是一种重要的智能交通监控系统&#xff0c;它能够通过图像处理技术对行人进行实时检测、跟踪和计数&#xff0c;为城市交通规划、人流控制和安全管理提供重要数据支持。本系统基于先进的YOLOv8目标检测算法和PyQt5图形界面框架开发&#xff0c;具有高效、准确、…

叶子分享站PHP源码

叶子网盘分享站PHP网站源码&#xff0c;创建无限级文件夹&#xff0c;上传文件&#xff0c;可进行删除&#xff0c;下载等能很好的兼容服务器。方便管理者操作&#xff0c;查看更多的下载资源以及文章&#xff0c;新增分享功能&#xff0c;异步上传文件/资源等 PHP网盘源码优势…

蓝桥杯 递增三元组

Problem: 蓝桥杯 递增三元组 文章目录 思路解题方法复杂度前缀和Code二分Code双指针Code 思路 这是一个关于数组的问题&#xff0c;我们需要找到一个递增的三元组。这个三元组由三个数组中的元素组成&#xff0c;每个数组提供一个元素&#xff0c;并且这三个元素满足递增的关系…

Unix环境高级编程-学习-05-TCP/IP协议与套接字

目录 一、概念 二、TCP/IP参考模型 三、客户端和服务端使用TCP通信过程 1、同一以太网下 四、函数介绍 1、socket &#xff08;1&#xff09;声明 &#xff08;2&#xff09;作用 &#xff08;3&#xff09;参数 &#xff08;4&#xff09;返回值 &#xff08;5&…

三星泄露微软 Copilot 新功能:用自然语言操控各种功能

3 月 11 日消息&#xff0c;微软计划本月晚些时候发布新款 Surface 电脑和适用于 Windows 11 的 Copilot 新功能&#xff0c;但三星似乎等不及了&#xff0c;在其即将推出的 Galaxy Book4 系列产品宣传材料中泄露了一些即将到来的 Copilot 功能。 三星官网上发布的图片证实了此…

在centOS服务器安装docker,并使用docker配置nacos

遇到安装慢的情况可以优先选择阿里镜像 安装docker 更新yum版本 yum update安装所需软件包 yum install -y yum-utils device-mapper-persistent-data lvm2添加Docker仓库 yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.rep…

SQLite—免费开源数据库系列文章目录

SQLite系列相关文章较多特开本文为了便于读者阅读特写了本索引和目录之用本文将不断更新中有需要的读者可以收藏本文便于导航到各个专题( 持续更新中......)。收藏一篇等于收藏一个系列文章 简介类&#xff1a; SQLite——世界上部署最广泛的免费开源数据库&#xff08;简介&…

【海贼王的数据航海】探究二叉树的奥秘

目录 1 -> 树的概念及结构 1.1 -> 树的概念 1.2 -> 树的相关概念 1.3 -> 树的表示 1.4 -> 树在实际中的运用(表示文件系统的目录树结构) 2 -> 二叉树概念及结构 2.1 -> 二叉树的概念 2.2 -> 现实中的二叉树 2.3 -> 特殊的二叉树 2.4 ->…

Post请求出现Request header is too large

问题描述&#xff1a; 在做项目的时候&#xff0c;前端请求体太大的时候&#xff0c;出现Request header is too large问题&#xff0c;后端接口如下&#xff1a; 前端请求接口返回问题如下&#xff1a; 解决方案&#xff1a; 问题原因&#xff1a;这是因为我们在做Springboo…

旅游管理系统|基于SpringBoot+ Mysql+Java+Tomcat技术的旅游管理系统设计与实现(可运行源码+数据库+设计文档+部署说明+视频演示)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 目录 前台功能效果图 用户功能 管理员功能登录前台功能效果图 系统功能设计 数据库E-R图设计 lunwen参考 …

反爬虫技术:如何保护你的网站数据安全

在数字化时代&#xff0c;数据的价值日益凸显&#xff0c;而爬虫技术则成为了获取这些数据的重要手段之一。然而&#xff0c;对于网站运营者来说&#xff0c;非法爬虫不仅会导致数据泄露&#xff0c;还可能给网站带来巨大的流量压力和安全隐患。因此&#xff0c;本文将探讨如何…

您的 GStreamer 安装缺少插件

最近在学习QMLQT。 在弄一个多媒体播放的软件时&#xff0c;提示我系统缺少某些组件。我的系统是20.04.1-Ubuntu。 然后我看了很多帖子&#xff0c;大概思路就是&#xff0c;要装gstreamer 相关的组件。如果是比价低的ubunutu系统 就得装 gstreamer 0.10 的插件。如果是比价…

QT6.6 android下fftw-3.3.10库编译及调用so库方法

一.实现目标 fftw-3.3.10库在QT6.6的android环境下编译为so文件,然后在android项目中进行调用测试。 说明:编译的前提是要先部署好QT的android开发环境,具体可以参照本专栏文章《QT6.6 android开发环境搭建》,文章链接: https://blog.csdn.net/xieliru/article/detail…

【深度学习】YOLOv9继续训练——断点训练方法

YOLOv9继续训练主要分为两个情况&#xff1a; 其一、训练过程中意外中断&#xff0c;未完成训练预期的epoch数量&#xff1b; 其二、训练完了&#xff0c;但是未收敛&#xff0c;在这个基础上&#xff0c;还想用这个权重、学习率等参数继续训练多一些轮次 一、训练过程中意外…