Spaces:
Runtime error
Runtime error
import matplotlib.pyplot as plt | |
import numpy as np | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
def convert_back_image(
    image,
    mean=(0.4914, 0.4822, 0.4471),
    std=(0.2469, 0.2433, 0.2615),
):
    """Undo per-channel mean/std normalization so an image can be displayed.

    The default ``mean``/``std`` are the CIFAR-10 channel statistics that were
    previously hard-coded here, so existing calls behave as before.

    Args:
        image: Channel-first image of shape (C, H, W), either a torch tensor or
            a NumPy array. Tensors are detached and copied to host memory
            first, so grad-tracking or CUDA tensors are accepted.
        mean: Per-channel means used for normalization; needs >= C entries.
        std: Per-channel standard deviations used for normalization; needs
            >= C entries.

    Returns:
        A new float32 NumPy array of shape (H, W, C) with values clipped to
        [0, 1], suitable for ``plt.imshow`` / ``show_cam_on_image``.

    Raises:
        ValueError: If ``mean`` or ``std`` has fewer entries than ``image``
            has channels.
    """
    if hasattr(image, "detach"):
        # torch.Tensor: .numpy() raises for tensors that require grad or live
        # on a GPU, so strip autograd and move to the CPU before converting.
        image = image.detach().cpu().numpy()
    image = np.asarray(image, dtype=np.float32)

    num_channels = image.shape[0]
    if len(mean) < num_channels or len(std) < num_channels:
        raise ValueError(
            f"Need mean/std for {num_channels} channels, "
            f"got {len(mean)} means and {len(std)} stds"
        )

    # Shape the stats as (C, 1, 1) so they broadcast across H and W.
    channel_mean = np.asarray(mean[:num_channels], dtype=np.float32)
    channel_std = np.asarray(std[:num_channels], dtype=np.float32)
    # The arithmetic builds a new array, so the caller's data is never mutated.
    image = image * channel_std.reshape(-1, 1, 1) + channel_mean.reshape(-1, 1, 1)

    # To stop imshow warning that pixel values exceed the valid [0, 1] range.
    image = image.clip(0, 1)
    return np.transpose(image, (1, 2, 0))
def plot_sample_training_images(batch_data, batch_label, class_label, num_images=30):
    """Plot a grid of sample training images titled with their class names.

    Args:
        batch_data: Indexable batch of normalized channel-first image tensors
            (presumably shape (N, C, H, W) from a DataLoader -- confirm); each
            image is moved to the CPU and de-normalized via
            ``convert_back_image``.
        batch_label: Labels aligned with ``batch_data``; each entry must be
            convertible with ``int`` (0-d tensor, NumPy integer or int).
        class_label: Sequence/mapping from class index to display name.
        num_images: Maximum number of images to plot; capped at the batch size.

    Returns:
        ``(fig, axs)`` as returned by ``plt.subplots`` for a 5-column grid.
    """
    images, labels = batch_data, batch_label
    num_images = min(num_images, len(images))

    num_cols = 5
    # max(1, ...) keeps plt.subplots valid when there is nothing to plot.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10))

    # Draw on the Axes we were given instead of re-selecting them through
    # pyplot's global state, so the returned fig/axs are what was drawn on.
    for idx, ax in enumerate(np.ravel(axs)):
        # Turning the axis off also hides the frames of unused trailing cells.
        ax.axis("off")
        if idx >= num_images:
            continue
        ax.imshow(convert_back_image(images[idx].cpu()))
        ax.set_title(class_label[int(labels[idx])])

    # One layout pass once every title exists (previously ran per image).
    fig.tight_layout()
    return fig, axs
def plot_train_test_metrics(results):
    """Draw side-by-side loss and accuracy curves comparing train and test.

    Args:
        results: Dict holding the series "train_loss", "train_acc",
            "test_loss" and "test_acc" (anything ``Axes.plot`` accepts).

    Returns:
        ``(fig, axs)`` for a 1x2 figure: ``axs[0]`` shows the losses and
        ``axs[1]`` the accuracies.
    """
    # Read every series before creating the figure so a missing key fails
    # early (same key order as before).
    train_loss, train_acc, test_loss, test_acc = (
        results[key] for key in ("train_loss", "train_acc", "test_loss", "test_acc")
    )

    fig, axs = plt.subplots(1, 2, figsize=(16, 8))
    panels = (
        (axs[0], "Loss", train_loss, test_loss),
        (axs[1], "Accuracy", train_acc, test_acc),
    )
    for ax, title, train_curve, test_curve in panels:
        ax.plot(train_curve, label="Train")
        ax.plot(test_curve, label="Test")
        ax.set_title(title)
        ax.legend(loc="upper right")
    return fig, axs
def plot_misclassified_images(data, class_label, num_images=10):
    """Plot misclassified test images titled with actual and predicted class.

    Args:
        data: Dict with parallel sequences "images" (normalized channel-first
            image tensors), "ground_truths" and "predicted_vals" (label
            tensors; ``.cpu().item()`` is called on each entry).
        class_label: Sequence/mapping from class index to display name.
        num_images: Maximum number of images to plot; capped at
            ``len(data["ground_truths"])``.

    Returns:
        ``(fig, axs)`` as returned by ``plt.subplots`` for a 5-column grid.
    """
    num_images = min(num_images, len(data["ground_truths"]))

    num_cols = 5
    # max(1, ...) keeps plt.subplots valid when there is nothing to plot.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))

    # Draw on the returned Axes rather than re-selecting via plt.subplot.
    for idx, ax in enumerate(np.ravel(axs)):
        # Turning the axis off also hides the frames of unused trailing cells.
        ax.axis("off")
        if idx >= num_images:
            continue
        label = data["ground_truths"][idx].cpu().item()
        pred = data["predicted_vals"][idx].cpu().item()
        image = data["images"][idx].cpu()

        ax.imshow(convert_back_image(image))
        ax.set_title(f"ACT: {class_label[label]} \nPRED: {class_label[pred]}")

    # One layout pass once every title exists (previously ran per image).
    fig.tight_layout()
    return fig, axs
# Function to plot gradcam for misclassified images using pytorch_grad_cam | |
def plot_gradcam_images(
    model,
    data,
    class_label,
    target_layers,
    targets=None,
    num_images=10,
    image_weight=0.25,
    aug_smooth=True,
    eigen_smooth=True,
):
    """Show Grad-CAM heatmaps overlaid on misclassified images in a grid.

    Args:
        model: Torch model to explain. Inputs are moved to the CPU here, so
            the model is presumably expected on the CPU too (e.g. in the
            Gradio app) -- confirm for other deployments.
        data: Dict with parallel sequences "images" (normalized channel-first
            image tensors), "ground_truths" and "predicted_vals" (label
            tensors; ``.cpu().item()`` is called on each entry).
        class_label: Sequence/mapping from class index to display name.
        target_layers: Layers whose activations/gradients Grad-CAM uses.
        targets: pytorch-grad-cam target objects, or None to explain the
            highest-scoring class. The same value is passed to every
            single-image CAM call.
        num_images: Maximum number of images to plot; capped at
            ``len(data["ground_truths"])``.
        image_weight: Weight of the original image in the overlay blend,
            forwarded to ``show_cam_on_image``.
        aug_smooth: Forwarded to the CAM call (test-time augmentation smoothing).
        eigen_smooth: Forwarded to the CAM call (principal-component smoothing).

    Returns:
        ``(fig, axs)`` as returned by ``plt.subplots`` for a 5-column grid.
    """
    num_images = min(num_images, len(data["ground_truths"]))

    num_cols = 5
    # max(1, ...) keeps plt.subplots valid when there is nothing to plot.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))

    # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/grad_cam.py
    # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/base_cam.py
    # Do not pass `use_cuda`: newer pytorch-grad-cam releases (1.5+) removed it
    # and raise TypeError, while older releases default it to False, so
    # omitting it keeps the CPU behavior the Gradio app relies on.
    cam = GradCAM(model=model, target_layers=target_layers)
    try:
        # Draw on the returned Axes rather than re-selecting via plt.subplot.
        for idx, ax in enumerate(np.ravel(axs)):
            # Turning the axis off also hides the frames of unused cells.
            ax.axis("off")
            if idx >= num_images:
                continue
            label = data["ground_truths"][idx].cpu().item()
            pred = data["predicted_vals"][idx].cpu().item()
            image = data["images"][idx].cpu()

            # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py
            # The CAM call takes a batch: add a leading batch dimension and
            # keep the single resulting heatmap.
            grad_cam_output = cam(
                input_tensor=image.unsqueeze(0),
                targets=targets,
                aug_smooth=aug_smooth,
                eigen_smooth=eigen_smooth,
            )[0, :]

            # show_cam_on_image needs an RGB float image in [0, 1], which
            # convert_back_image provides.
            overlayed_image = show_cam_on_image(
                convert_back_image(image),
                grad_cam_output,
                use_rgb=True,
                image_weight=image_weight,
            )
            ax.imshow(overlayed_image)
            ax.set_title(f"ACT: {class_label[label]} \nPRED: {class_label[pred]}")
    finally:
        # GradCAM registers hooks on target_layers; remove them so the model
        # is left as it was once plotting finishes (or fails).
        cam.activations_and_grads.release()

    # One layout pass once every title exists (previously ran per image).
    fig.tight_layout()
    return fig, axs