# AlzheimerDetection / utils.py
# Jiranuwat's picture
# Upload 10 files
# 201936b verified
# raw
# history blame
# 7.49 kB
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import cv2
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
class split_white_and_gray():
    """Callable transform that splits an image tensor into two intensity bands.

    The input is mapped back to integer levels (``tensor * 255`` truncated to
    ``int64``).  Levels at or above ``threshold`` make up the first output
    (named "white matter" in this module), levels strictly below it make up
    the second ("gray matter"); pixels outside a band are set to 0.

    Args:
        threshold (int, optional): Integer level on the 0-255 scale that
            separates the two bands. Defaults to 120.
    """

    def __init__(self, threshold=120) -> None:
        """Remember the level used by ``__call__`` to separate the bands.

        Args:
            threshold (int, optional): The threshold value to be set. Defaults to 120.
        """
        self.threshold = threshold

    def __call__(self, tensor):
        """Split ``tensor`` into its above-threshold and below-threshold parts.

        Parameters:
            tensor (torch.Tensor): Image tensor, presumably with values in
                [0, 1] (e.g. from ``transforms.ToTensor``); the ``* 255``
                rescaling relies on that.

        Returns:
            tuple: ``(white_matter, gray_matter, original)`` tensors, each
            divided by 255 and cast to ``torch.float64``.
        """
        levels = (tensor * 255).to(torch.int64)
        is_bright = levels >= self.threshold

        def _rescale(values):
            # Divide in the default float dtype first, then widen to float64.
            return (values / 255).to(torch.float64)

        bright_part = torch.where(is_bright, levels, 0)
        dark_part = torch.where(~is_bright, levels, 0)
        return _rescale(bright_part), _rescale(dark_part), _rescale(levels)
def showcam_withoutmask(original_image, grayscale_cam, image_title='Original Image'):
    """Multiply an image by its CAM mask and plot image, mask and product side by side.

    :param original_image: Image tensor. After ``torch.squeeze`` it must be
        channel-first ``(C, H, W)``, because it is transposed to ``(H, W, C)``
        for plotting.
    :param grayscale_cam: CAM mask tensor. Its first axis is indexed with
        ``[0]`` for the mask panel (e.g. ``(1, H, W)``), and it must broadcast
        against the squeezed image.
    :param image_title: Title of the first (original image) panel.
    :return: Matplotlib Figure with three panels: the original image, the CAM
        mask, and the min-max normalised product of the two.
    """
    # Move both tensors to the CPU so they can be converted to NumPy.
    original_image_np = torch.squeeze(original_image).cpu().numpy()
    cam_mask_np = grayscale_cam.cpu().numpy()

    # Element-wise weighting of the image by the CAM (broadcast over channels).
    masked_image = original_image_np * cam_mask_np

    # Min-max normalise to [0, 1] for display. A constant product (e.g. an
    # all-zero CAM) would otherwise divide by zero and render as NaNs.
    low = np.min(masked_image)
    value_range = np.max(masked_image) - low
    if value_range > 0:
        masked_image_norm = (masked_image - low) / value_range
    else:
        masked_image_norm = np.zeros_like(masked_image)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    # Panel 1: original image, (C, H, W) -> (H, W, C) for imshow.
    axes[0].imshow(original_image_np.transpose(1, 2, 0))
    axes[0].set_title(image_title)
    # Panel 2: the CAM mask itself.
    axes[1].imshow(cam_mask_np[0], cmap='jet')
    axes[1].set_title('CAM Mask')
    # Panel 3: normalised image * mask product.
    axes[2].imshow(masked_image_norm.transpose(1, 2, 0))
    axes[2].set_title('Overlay (Normalized)')
    return fig
def showcam_withmask(img_tensor: torch.Tensor,
                     mask_tensor: torch.Tensor,
                     use_rgb: bool = False,
                     colormap: int = cv2.COLORMAP_JET,
                     image_weight: float = 0.5,
                     image_title: str = 'Original Image') -> plt.Figure:
    """Overlay a CAM mask on an image as a colour heatmap and plot the result.

    ``cv2.applyColorMap`` produces a heatmap in BGR channel order, and it is
    kept in BGR unless ``use_rgb`` is True. Matplotlib displays 3-channel
    arrays as RGB, so pass ``use_rgb=True`` to make the overlay colours match
    the 'jet' colormap of the mask panel.

    :param img_tensor: Channel-first ``(C, H, W)`` image tensor with values in [0, 1].
    :param mask_tensor: CAM mask tensor. ``mask[0]`` is used (e.g. shape
        ``(1, H, W)``). Values are clipped to [0, 1] before being scaled to
        uint8 for the colormap.
    :param use_rgb: Convert the heatmap from BGR to RGB before blending.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: The final result is image_weight * img + (1-image_weight) * heatmap.
    :param image_title: Title of the first (original image) panel.
    :return: Matplotlib Figure with the image, the CAM mask and the overlay.
    :raises ValueError: If the image has values above 1, or if image_weight
        is outside [0, 1].
    """
    # Convert PyTorch tensors to NumPy arrays; the image goes to (H, W, C) for plotting.
    img = img_tensor.cpu().numpy().transpose(1, 2, 0)
    mask = mask_tensor.cpu().numpy()

    # Validate inputs before doing any colormap work.
    if np.max(img) > 1:
        raise ValueError("The input image should be in the range [0, 1]")
    if image_weight < 0 or image_weight > 1:
        raise ValueError(f"image_weight should be in the range [0, 1]. Got: {image_weight}")

    # Clip before the uint8 cast: out-of-range floats would otherwise wrap around.
    mask_single_channel = np.uint8(255 * np.clip(mask[0], 0, 1))
    heatmap = cv2.applyColorMap(mask_single_channel, colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = (1 - image_weight) * heatmap + image_weight * img
    # Rescale so the brightest value is 1; guard the all-zero blend against 0/0.
    cam_max = np.max(cam)
    if cam_max > 0:
        cam = cam / cam_max

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(img)
    axes[0].set_title(image_title)
    axes[1].imshow(mask[0], cmap='jet')
    axes[1].set_title('CAM Mask')
    axes[2].imshow(cam)
    axes[2].set_title('Overlay')
    return fig
def predict_and_gradcam(pil_image, model, target=100, plot_type='withmask'):
    """Classify an image with the combined model and build Grad-CAM figures.

    The image is resized to 150x150 and converted to 3-channel grayscale.
    ``split_white_and_gray`` then splits it into above/below-threshold
    ("white"/"gray" matter) tensors. Grad-CAM is computed separately for each
    sub-network, and the combined model predicts one of the four labels in
    ``class_idx``.

    :param pil_image: Input PIL image.
    :param model: Combined model. It must expose ``resnet18_model``,
        ``whitematter_resnet18_model`` and ``graymatter_resnet18_model`` (each
        with a ``layer4`` used as the Grad-CAM target layer), and is called as
        ``model(white, gray, origin)``.
    :param target: Output index passed to ``ClassifierOutputTarget`` for every
        sub-network's Grad-CAM.
    :param plot_type: ``'withmask'`` for heatmap overlays, or ``'withoutmask'``
        for mask-multiplied images.
    :return: ``(predicted_label, origin_fig, white_fig, gray_fig)``.
    :raises ValueError: If ``plot_type`` is not one of the two supported values.
    """
    # Fail fast: checking here avoids three Grad-CAM passes plus inference on bad input.
    if plot_type not in ('withmask', 'withoutmask'):
        raise ValueError("plot_type must be either 'withmask' or 'withoutmask'")

    transform = transforms.Compose([
        transforms.Resize((150, 150)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        split_white_and_gray(120),
    ])
    white_matter, gray_matter, origin = transform(pil_image)
    # Add a batch dimension and cast back to float32 (the splitter returns float64).
    white_matter_tensor = white_matter.unsqueeze(0).to(torch.float32)
    gray_matter_tensor = gray_matter.unsqueeze(0).to(torch.float32)
    origin_tensor = origin.unsqueeze(0).to(torch.float32)

    def calculate_gradcammask(model_grad, input_tensor):
        # Grad-CAM on the last block of the sub-network's layer4.
        target_layer = [model_grad.layer4[-1]]
        gradcam = GradCAM(model=model_grad, target_layers=target_layer)
        targets = [ClassifierOutputTarget(target)]
        grayscale_cam = gradcam(input_tensor=input_tensor, targets=targets,
                                aug_smooth=True, eigen_smooth=True)
        return torch.tensor(grayscale_cam)

    origin_cam = calculate_gradcammask(model.resnet18_model, origin_tensor)
    white_cam = calculate_gradcammask(model.whitematter_resnet18_model, white_matter_tensor)
    gray_cam = calculate_gradcammask(model.graymatter_resnet18_model, gray_matter_tensor)

    class_idx = {0: 'Moderate Demented', 1: 'Mild Demented', 2: 'Very Mild Demented', 3: 'Non Demented'}
    # Plain inference: no autograd graph is needed for the prediction itself.
    with torch.no_grad():
        prediction = model(white_matter_tensor, gray_matter_tensor, origin_tensor)
    predicted_class_label = class_idx[torch.argmax(prediction).item()]

    origin_img = torch.squeeze(origin_tensor)
    white_img = torch.squeeze(white_matter_tensor)
    gray_img = torch.squeeze(gray_matter_tensor)

    if plot_type == 'withmask':
        # Matplotlib renders arrays as RGB. Convert the OpenCV (BGR) heatmap so
        # the overlay colours match the 'jet' mask panel.
        figures = (
            showcam_withmask(origin_img, origin_cam, use_rgb=True),
            showcam_withmask(white_img, white_cam, use_rgb=True, image_title='White Matter'),
            showcam_withmask(gray_img, gray_cam, use_rgb=True, image_title='Gray Matter'),
        )
    else:
        figures = (
            showcam_withoutmask(origin_img, origin_cam),
            showcam_withoutmask(white_img, white_cam, image_title='White Matter'),
            showcam_withoutmask(gray_img, gray_cam, image_title='Gray Matter'),
        )
    return (predicted_class_label, *figures)