We want to visualize the features that the network has learned.
import os
import cv2
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms
from utils import GradCAM, show_cam_on_image, center_crop_img
from lianxi import *
import torch.nn as nn
from dataset import CESM
from torch.utils.data import DataLoader
def main():
    net = CustomResNet50(in_channels=1, num_classes=2, chunk=1)
    path1 = r'E:\pycharmproject\CR-DLcode\CR_fussion\MF\checkpoint\resnet50\Thursday_14_March_2024_00h_12m_45s\resnet50-3-best.pth'
    net.load_state_dict(torch.load(path1))
    model = net
    # Grad-CAM hooks onto the last convolutional stage of the backbone.
    target_layers = [net.model.layer4]

    CESMdata2 = CESM(base_dir=r'F:\CR的均值化数据\test',
                     transform=transforms.Compose([transforms.ToTensor()]))
    CESM_10_test_l = DataLoader(CESMdata2, batch_size=1, shuffle=False, drop_last=True,
                                pin_memory=torch.cuda.is_available())

    for i, x in enumerate(CESM_10_test_l):
        input_tensor = x['LOW_ENERGY']
        data = input_tensor.squeeze(0).numpy()
        cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
        # Use the ground-truth label from the dataset as the target class.
        target_category = x['label']
        grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)
        grayscale_cam = grayscale_cam[0, :]
        visualization = show_cam_on_image(data / 255., grayscale_cam, use_rgb=True)
        # plt.imshow(visualization)
        # plt.show()
        cv2.imshow('Image', visualization)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
The script above loads the model and its trained weights, then loads the test data. Note that the inputs here are single-channel grayscale images, which is why they are converted to three channels before the heatmap is overlaid. With this in place we can draw the Grad-CAM maps.
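`CustomResNet50` and `CESM` come from our own `lianxi` and `dataset` modules, which are not shown here. For readers without those files, here is a minimal sketch of what the model wrapper might look like, inferred only from the call `CustomResNet50(in_channels=1, num_classes=2, chunk=1)` and the `net.model.layer4` access above; the stem replacement and the handling of `chunk` are assumptions, not the actual implementation:

import torch.nn as nn
from torchvision import models

class CustomResNet50(nn.Module):
    # Hypothetical reconstruction of the class defined in lianxi.py.
    def __init__(self, in_channels=1, num_classes=2, chunk=1):
        super().__init__()
        self.chunk = chunk  # role unknown; kept only for signature compatibility
        self.model = models.resnet50(weights=None)
        # Swap the stem so the backbone accepts single-channel (grayscale) input.
        self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7,
                                     stride=2, padding=3, bias=False)
        # Two-class head for the CESM classification task.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

Whatever the real definition looks like, the point that matters for Grad-CAM is that the last convolutional stage is reachable as an attribute (`net.model.layer4` here), so the hooks can be registered on it.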
Our utils file:
import cv2
import numpy as np


class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targeted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Backward compatibility with older pytorch versions:
            if hasattr(target_layer, 'register_full_backward_hook'):
                self.handles.append(
                    target_layer.register_full_backward_hook(self.save_gradient))
            else:
                self.handles.append(
                    target_layer.register_backward_hook(self.save_gradient))

    def save_activation(self, module, input, output):
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, grad_input, grad_output):
        # Gradients are computed in reverse order
        grad = grad_output[0]
        if self.reshape_transform is not None:
            grad = self.reshape_transform(grad)
        self.gradients = [grad.cpu().detach()] + self.gradients

    def __call__(self, x):
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        for handle in self.handles:
            handle.remove()


class GradCAM:
    def __init__(self, model, target_layers, reshape_transform=None, use_cuda=False):
        self.model = model.eval()
        self.target_layers = target_layers
        self.reshape_transform = reshape_transform
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()
        self.activations_and_grads = ActivationsAndGradients(
            self.model, target_layers, reshape_transform)

    """ Get a vector of weights for every channel in the target layer.
    Methods that return weights channels,
    will typically need to only implement this function. """

    @staticmethod
    def get_cam_weights(grads):
        return np.mean(grads, axis=(2, 3), keepdims=True)

    @staticmethod
    def get_loss(output, target_category):
        loss = 0
        for i in range(len(target_category)):
            loss = loss + output[i, target_category[i]]
        return loss

    def get_cam_image(self, activations, grads):
        weights = self.get_cam_weights(grads)
        weighted_activations = weights * activations
        cam = weighted_activations.sum(axis=1)
        return cam

    @staticmethod
    def get_target_width_height(input_tensor):
        width, height = input_tensor.size(-1), input_tensor.size(-2)
        return width, height

    def compute_cam_per_layer(self, input_tensor):
        activations_list = [a.cpu().data.numpy()
                            for a in self.activations_and_grads.activations]
        grads_list = [g.cpu().data.numpy()
                      for g in self.activations_and_grads.gradients]
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over the saliency image from every layer
        for layer_activations, layer_grads in zip(activations_list, grads_list):
            cam = self.get_cam_image(layer_activations, layer_grads)
            cam[cam < 0] = 0  # drop negative values (ReLU) before min-max scaling in scale_cam_image
            scaled = self.scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])
        return cam_per_target_layer

    def aggregate_multi_layers(self, cam_per_target_layer):
        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
        result = np.mean(cam_per_target_layer, axis=1)
        return self.scale_cam_image(result)

    @staticmethod
    def scale_cam_image(cam, target_size=None):
        result = []
        for img in cam:
            img = img - np.min(img)
            img = img / (1e-7 + np.max(img))
            if target_size is not None:
                img = cv2.resize(img, target_size)
            result.append(img)
        result = np.float32(result)
        return result

    def __call__(self, input_tensor, target_category=None):
        if self.cuda:
            input_tensor = input_tensor.cuda()

        # Forward pass to get the raw network logits (no softmax applied)
        output = self.activations_and_grads(input_tensor)
        if isinstance(target_category, int):
            target_category = [target_category] * input_tensor.size(0)

        if target_category is None:
            target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
            print(f"category id: {target_category}")
        else:
            assert (len(target_category) == input_tensor.size(0))

        self.model.zero_grad()
        loss = self.get_loss(output, target_category)
        loss.backward(retain_graph=True)

        # In most of the saliency attribution papers, the saliency is
        # computed with a single target layer. Commonly it is the last
        # convolutional layer. Here we support passing a list with
        # multiple target layers: it will compute the saliency image for
        # every layer and then aggregate them (with a default mean
        # aggregation). This gives you more flexibility in case you want
        # to use, for example, all conv layers or all Batchnorm layers.
        cam_per_layer = self.compute_cam_per_layer(input_tensor)
        return self.aggregate_multi_layers(cam_per_layer)

    def __del__(self):
        self.activations_and_grads.release()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            # Suppress IndexError raised inside the with-block
            print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
            return True


def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format.
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The default image with the cam overlay.
    """
    # Convert the single-channel grayscale image to a three-channel color image
    img = cv2.cvtColor(img.squeeze(0), cv2.COLOR_GRAY2BGR)
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception("The input image should np.float32 in the range [0, 1]")

    # Note: the blend weight 0.002 keeps the heatmap overlay very faint;
    # increase it towards 1.0 for a stronger heatmap.
    cam = heatmap * 0.002 + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)


def center_crop_img(img: np.ndarray, size: int):
    h, w, c = img.shape

    if w == h == size:
        return img

    if w < h:
        ratio = size / w
        new_w = size
        new_h = int(h * ratio)
    else:
        ratio = size / h
        new_h = size
        new_w = int(w * ratio)

    img = cv2.resize(img, dsize=(new_w, new_h))

    if new_w == size:
        h = (new_h - size) // 2
        img = img[h: h + size]
    else:
        w = (new_w - size) // 2
        img = img[:, w: w + size]

    return img