"""Gradio app: CIFAR-10 inference with a Custom ResNet model plus GradCAM visualisations.

Outline:
    1. Import packages and project modules
    2. Load the trained model and the saved misclassified-image data
    3. Helpers: target-layer lookup and single-image inference
    4. Gradio callback, examples and Interface definition
    5. Launch the app
"""
# Import packages required for the app
import gradio as gr
import numpy as np
import torch
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import transforms

# Import custom modules
import modules.config as config
from modules.custom_resnet import CustomResNet
from modules.visualize import plot_gradcam_images, plot_misclassified_images

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# Everything runs on CPU.
cpu = torch.device("cpu")

model = CustomResNet()
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
# strict=False silently ignores missing/unexpected keys — confirm the checkpoint
# actually matches CustomResNet, otherwise the model may run with random weights.
model.load_state_dict(torch.load(config.MODEL_PATH, map_location=cpu), strict=False)
model.to(cpu)
# Evaluation mode for inference (affects layers such as BatchNorm/Dropout).
model.eval()
print(f"Model Device: {next(model.parameters()).device}")

# Pre-computed misclassified samples shown in the "misclassified" plots.
misclassified_image_data = torch.load(config.MISCLASSIFIED_PATH, map_location=cpu)

# Class names, indexed by the model's output position.
classes = list(config.CIFAR_CLASSES)

# Model sub-modules that may be used as GradCAM target layers (UI dropdown choices).
model_layer_names = ["prep", "layer1_x", "layer1_r1", "layer2", "layer3_x", "layer3_r2"]

# Input preprocessing, built once: HWC image -> CHW float tensor, normalised with CIFAR stats.
_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(list(config.CIFAR_MEAN), list(config.CIFAR_STD))]
)


def get_target_layer(layer_name):
    """Return the GradCAM target layer list for ``layer_name``.

    The target is the last element of the named model sub-module, i.e.
    ``[model.<layer_name>[-1]]``.

    Args:
        layer_name: One of ``model_layer_names``.

    Returns:
        A one-element list with the target module, or ``None`` if
        ``layer_name`` is not an allowed layer name.
    """
    if layer_name not in model_layer_names:
        return None
    return [getattr(model, layer_name)[-1]]


def generate_prediction(input_image, num_classes=3, show_gradcam=True, transparency=0.6, layer_name="layer3_x"):
    """Predict the class of an image and optionally overlay a GradCAM heat-map.

    Args:
        input_image: Image array as delivered by ``gr.Image`` (assumed uint8
            HxWx3 in [0, 255]; it is divided by 255 for the overlay).
        num_classes: How many top classes to include in ``confidences``;
            clamped to ``[0, len(classes)]``.
        show_gradcam: If True, return the GradCAM overlay; otherwise the input image.
        transparency: ``image_weight`` passed to ``show_cam_on_image``.
        layer_name: GradCAM target layer, one of ``model_layer_names``.

    Returns:
        ``(final_class, confidences, visualization)``: the top class name, a
        ``{class_name: score}`` dict of the top ``num_classes`` classes in
        descending order, and the image to display.

    Raises:
        ValueError: If ``show_gradcam`` is True and ``layer_name`` is unknown.
    """
    original_img = input_image
    input_tensor = _transform(input_image).unsqueeze(0).to(cpu)

    # Plain inference needs no autograd graph.
    with torch.no_grad():
        outputs = model(input_tensor)
    # NOTE(review): torch.exp assumes the model returns log-probabilities
    # (e.g. log_softmax) — confirm against CustomResNet.forward.
    probs = np.squeeze(torch.exp(outputs).cpu().numpy())

    # Class indexes ordered by descending score.
    sorted_indexes = np.argsort(probs)[::-1]
    final_class = classes[probs.argmax()]

    # Clamp so a too-large request cannot index past the available classes.
    top_k = max(0, min(int(num_classes), len(classes)))
    confidences = {classes[idx]: float(probs[idx]) for idx in sorted_indexes[:top_k]}

    if not show_gradcam:
        return final_class, confidences, original_img

    target_layers = get_target_layer(layer_name)
    if target_layers is None:
        raise ValueError(f"Unknown layer name: {layer_name!r}")
    # GradCAM back-propagates through the model, so it must run OUTSIDE torch.no_grad().
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    # targets=None lets GradCAM explain the highest-scoring class.
    grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
    visualization = show_cam_on_image(original_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return final_class, confidences, visualization


def app_interface(
    input_image,
    num_classes,
    show_gradcam,
    layer_name,
    transparency,
    show_misclassified,
    num_misclassified,
    show_gradcam_misclassified,
    num_gradcam_misclassified,
):
    """Gradio callback: run inference and build the optional misclassified-image plots.

    Returns:
        ``(final_class, confidences, visualization, misclassified_fig, gradcam_fig)``;
        each figure is ``None`` when its checkbox is off.
    """
    if input_image is None:
        raise gr.Error("Please upload an image (or pick an example) before submitting.")

    final_class, confidences, visualization = generate_prediction(
        input_image, num_classes, show_gradcam, transparency, layer_name
    )

    misclassified_fig = None
    if show_misclassified:
        # Sliders may deliver floats; image counts must be integers.
        misclassified_fig, _ = plot_misclassified_images(
            data=misclassified_image_data, class_label=classes, num_images=int(num_misclassified)
        )

    gradcam_fig = None
    if show_gradcam_misclassified:
        gradcam_fig, _ = plot_gradcam_images(
            model=model,
            data=misclassified_image_data,
            class_label=classes,
            # Same user-selected layer as the single-image GradCAM.
            target_layers=get_target_layer(layer_name),
            targets=None,
            num_images=int(num_gradcam_misclassified),
            image_weight=transparency,
        )

    return final_class, confidences, visualization, misclassified_fig, gradcam_fig


TITLE = "CIFAR10 Image classification using a Custom ResNet Model"
DESCRIPTION = "Gradio App to infer using a Custom ResNet model and get GradCAM results"

# Each row matches app_interface's parameter order.
examples = [
    ["assets/images/airplane.jpg", 3, True, "layer3_x", 0.6, True, 5, True, 5],
    ["assets/images/bird.jpeg", 4, True, "layer3_x", 0.7, True, 10, True, 20],
    ["assets/images/car.jpg", 5, True, "layer3_x", 0.5, True, 15, True, 5],
    ["assets/images/cat.jpeg", 6, True, "layer3_x", 0.65, True, 20, True, 10],
    ["assets/images/deer.jpg", 7, False, "layer2", 0.75, True, 5, True, 5],
    ["assets/images/dog.jpg", 8, True, "layer2", 0.55, True, 10, True, 5],
    ["assets/images/frog.jpeg", 9, True, "layer2", 0.8, True, 15, True, 15],
    ["assets/images/horse.jpg", 10, False, "layer1_r1", 0.85, True, 20, True, 5],
    ["assets/images/ship.jpg", 3, True, "layer1_r1", 0.4, True, 5, True, 15],
    ["assets/images/truck.jpg", 4, True, "layer1_r1", 0.3, True, 5, True, 10],
]

inference_app = gr.Interface(
    app_interface,
    inputs=[
        # This accepts the image after resizing it to 32x32, which is what our model expects.
        gr.Image(shape=(32, 32)),
        gr.Number(value=3, maximum=10, minimum=1, step=1.0, precision=0, label="#Classes to show"),
        gr.Checkbox(True, label="Show GradCAM Image"),
        gr.Dropdown(model_layer_names, value="layer3_x", label="Visualization Layer from Model"),
        # How strongly the original image is weighted in the GradCAM overlay.
        gr.Slider(0, 1, 0.6, label="Image Overlay Factor"),
        gr.Checkbox(True, label="Show Misclassified Images?"),
        gr.Slider(value=10, maximum=25, minimum=5, step=5.0, precision=0, label="#Misclassified images to show"),
        gr.Checkbox(True, label="Visualize GradCAM for Misclassified images?"),
        gr.Slider(value=10, maximum=25, minimum=5, step=5.0, precision=0, label="#GradCAM images to show"),
    ],
    outputs=[
        gr.Textbox(label="Top Class", container=True),
        gr.Label(label="Confidences", container=True),
        gr.Image(shape=(32, 32), label="Grad CAM/ Input Image", container=True).style(width=256, height=256),
        gr.Plot(label="Misclassified images", container=True),
        gr.Plot(label="Grad CAM of Misclassified images"),
    ],
    title=TITLE,
    description=DESCRIPTION,
    examples=examples,
)

inference_app.launch()