"""Grad-CAM helpers for a classifier built from original / white-matter / gray-matter sub-networks."""
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import cv2
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget


# Maps the argmax of the combined model's output to a human-readable label.
CLASS_LABELS = {0: 'Moderate Demented', 1: 'Mild Demented', 2: 'Very Mild Demented', 3: 'Non Demented'}


class split_white_and_gray:
    """Split an image tensor into "white matter" and "gray matter" by intensity.

    Meant to be the last step of a torchvision ``Compose`` after ``ToTensor``.
    Pixels are mapped back to 0-255 levels. Levels ``>= threshold`` are kept in
    the white-matter tensor and levels ``< threshold`` in the gray-matter tensor;
    every other pixel is zeroed in each output.
    """

    def __init__(self, threshold=120) -> None:
        """
        Initialize the class with a threshold value.

        Args:
            threshold (int, optional): Intensity cut-off on the 0-255 scale.
                Levels ``>= threshold`` go to white matter, the rest to gray
                matter. Defaults to 120.
        """
        self.threshold = threshold

    def __repr__(self) -> str:
        # Makes ``repr(Compose([...]))`` show the threshold instead of an address.
        return f"{type(self).__name__}(threshold={self.threshold})"

    def __call__(self, tensor):
        """
        Split ``tensor`` into white-matter, gray-matter and original parts.

        Parameters:
            tensor (torch.Tensor): Image tensor, expected in [0, 1] (e.g. the
                output of ``transforms.ToTensor``).

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Three float64
            tensors with the same shape as ``tensor``:
            - white matter: levels ``>= threshold``, other pixels 0;
            - gray matter: levels ``< threshold``, other pixels 0;
            - the original image re-quantised to 1/255 steps.
        """
        # Round before the integer cast. ``ToTensor`` stores ``v / 255`` in
        # float32, and multiplying back by 255 can give ``v - 1e-5``. A bare
        # ``.to(torch.int64)`` would truncate that to ``v - 1`` and could push
        # the pixel to the wrong side of the threshold.
        levels = torch.round(tensor * 255).to(torch.int64)

        white_matter = torch.where(levels >= self.threshold, levels, 0)
        gray_matter = torch.where(levels < self.threshold, levels, 0)

        return (
            (white_matter / 255).to(torch.float64),
            (gray_matter / 255).to(torch.float64),
            (levels / 255).to(torch.float64),
        )


def showcam_withoutmask(original_image, grayscale_cam, image_title='Original Image'):
    """Plot an image, its CAM mask, and the image multiplied by the mask.

    :param original_image: Image tensor. After ``torch.squeeze`` it must be
        (C, H, W), e.g. (3, 150, 150).
    :param grayscale_cam: CAM tensor whose first index selects the mask,
        e.g. (1, H, W).
    :param image_title: Title of the first panel.
    :return: Matplotlib Figure with three panels (image, CAM mask, normalised
        overlay).

    NOTE(review): the figure is created through pyplot and stays registered
    until ``plt.close(fig)``. Callers that produce many figures should close them.
    """
    image_np = torch.squeeze(original_image).cpu().numpy()
    cam_np = grayscale_cam.cpu().numpy()

    # (C, H, W) * (1, H, W) broadcasts the single mask over every channel.
    masked = image_np * cam_np

    # Min-max normalise for display. If the product is constant (e.g. the CAM
    # is all zeros), the denominator is 0 and the plot would be all NaN, so
    # fall back to a black image in that case.
    low, high = np.min(masked), np.max(masked)
    if high > low:
        masked_norm = (masked - low) / (high - low)
    else:
        masked_norm = np.zeros_like(masked)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    axes[0].imshow(image_np.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)
    axes[0].set_title(image_title)

    axes[1].imshow(cam_np[0], cmap='jet')
    axes[1].set_title('CAM Mask')

    axes[2].imshow(masked_norm.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)
    axes[2].set_title('Overlay (Normalized)')

    return fig


def showcam_withmask(img_tensor: torch.Tensor,
                     mask_tensor: torch.Tensor,
                     use_rgb: bool = False,
                     colormap: int = cv2.COLORMAP_JET,
                     image_weight: float = 0.5,
                     image_title: str = 'Original Image') -> plt.Figure:
    """Blend a colour-mapped CAM heatmap onto an image and plot it.

    :param img_tensor: Image tensor of shape (C, H, W) with values in [0, 1].
    :param mask_tensor: CAM tensor whose first index selects the mask,
        e.g. (1, H, W), with values in [0, 1].
    :param use_rgb: If True, convert OpenCV's BGR heatmap to RGB before
        blending. Matplotlib expects RGB, so pass True when displaying with
        ``imshow``.
    :param colormap: OpenCV colormap applied to the mask.
    :param image_weight: Result is ``image_weight * img + (1 - image_weight) * heatmap``.
    :param image_title: Title of the first panel.
    :return: Matplotlib Figure with three panels (image, CAM mask, overlay).
    :raises ValueError: if the image exceeds 1 or ``image_weight`` is outside [0, 1].
    """
    img = img_tensor.cpu().numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    mask = mask_tensor.cpu().numpy()

    # cv2.applyColorMap needs a single-channel uint8 image.
    heatmap = cv2.applyColorMap(np.uint8(255 * mask[0]), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise ValueError("The input image should be in the range [0, 1]")
    if image_weight < 0 or image_weight > 1:
        raise ValueError(f"image_weight should be in the range [0, 1]. Got: {image_weight}")

    cam = (1 - image_weight) * heatmap + image_weight * img
    # Rescale so the brightest value is 1. This is skipped for an all-black
    # blend, which would otherwise be 0/0.
    peak = np.max(cam)
    if peak > 0:
        cam = cam / peak

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(img)
    axes[0].set_title(image_title)

    axes[1].imshow(mask[0], cmap='jet')
    axes[1].set_title('CAM Mask')

    axes[2].imshow(cam)
    axes[2].set_title('Overlay')

    return fig


def _preprocess(pil_image):
    """Return batched (white, gray, original) float32 tensors of shape (1, 3, 150, 150)."""
    transform = transforms.Compose([
        transforms.Resize((150, 150)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        split_white_and_gray(120),
    ])
    white, gray, original = transform(pil_image)
    return tuple(t.unsqueeze(0).to(torch.float32) for t in (white, gray, original))


def _gradcam_mask(cam_model, input_tensor, target):
    """Return the Grad-CAM mask of ``cam_model.layer4[-1]`` for class ``target`` as a tensor."""
    gradcam = GradCAM(model=cam_model, target_layers=[cam_model.layer4[-1]])
    grayscale_cam = gradcam(input_tensor=input_tensor,
                            targets=[ClassifierOutputTarget(target)],
                            aug_smooth=True,
                            eigen_smooth=True)
    return torch.tensor(grayscale_cam)


def predict_and_gradcam(pil_image, model, target=100, plot_type='withmask'):
    """Classify ``pil_image`` and plot Grad-CAM maps for the three sub-networks.

    The image is resized to 150x150 and converted to 3-channel grayscale. It is
    then split at level 120 into white-matter and gray-matter images.

    :param pil_image: Input PIL image.
    :param model: Combined model, called as ``model(white, gray, original)``.
        It must expose ``resnet18_model``, ``whitematter_resnet18_model`` and
        ``graymatter_resnet18_model``, each with a ``layer4`` sequence.
    :param target: Class index passed to ``ClassifierOutputTarget`` for every
        sub-network. NOTE(review): the default 100 lies outside the 4 labels in
        ``CLASS_LABELS``; presumably the sub-networks have a wider output head.
        Confirm.
    :param plot_type: ``'withmask'`` (heatmap blended onto the image) or
        ``'withoutmask'`` (image multiplied by the mask).
    :return: ``(label, fig_original, fig_white_matter, fig_gray_matter)``.
    :raises ValueError: if ``plot_type`` is not one of the two accepted values.
    """
    # Validate first so a bad argument does not cost three Grad-CAM passes.
    if plot_type not in ('withmask', 'withoutmask'):
        raise ValueError("plot_type must be either 'withmask' or 'withoutmask'")

    white_matter_tensor, gray_matter_tensor, origin_tensor = _preprocess(pil_image)

    # Grad-CAM needs autograd, so these passes run outside ``no_grad``.
    origin_cam = _gradcam_mask(model.resnet18_model, origin_tensor, target)
    white_cam = _gradcam_mask(model.whitematter_resnet18_model, white_matter_tensor, target)
    gray_cam = _gradcam_mask(model.graymatter_resnet18_model, gray_matter_tensor, target)

    # The plain prediction needs no gradients.
    with torch.no_grad():
        prediction = model(white_matter_tensor, gray_matter_tensor, origin_tensor)
    predicted_class_label = CLASS_LABELS[torch.argmax(prediction).item()]

    def plot(image, cam, title):
        image = torch.squeeze(image)
        if plot_type == 'withmask':
            # The figures are drawn by matplotlib, which expects RGB. Without
            # this conversion the OpenCV heatmap would show red/blue swapped
            # compared with the 'jet' mask panel.
            return showcam_withmask(image, cam, use_rgb=True, image_title=title)
        return showcam_withoutmask(image, cam, image_title=title)

    return (predicted_class_label,
            plot(origin_tensor, origin_cam, 'Original Image'),
            plot(white_matter_tensor, white_cam, 'White Matter'),
            plot(gray_matter_tensor, gray_cam, 'Gray Matter'))