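# Gradio demo: CIFAR-10 image classification with a ResNet18 model, with
# optional GradCAM heatmaps showing which regions drove each prediction.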
import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet import ResNet18

# Load the trained CIFAR-10 weights on CPU; strict=False tolerates checkpoint
# keys that do not match the current model definition. Switch to eval mode so
# BatchNorm (and any Dropout) behave correctly at inference time.
model = ResNet18()
model.load_state_dict(torch.load("cifar10_saved_model.pth", map_location=torch.device('cpu')), strict=False)
model.eval()

# Approximate inverse of the CIFAR-10 normalization used during training
# (per-channel mean ~0.50, std ~0.23); maps a normalized tensor back to a
# displayable [0, 1] image. Currently unused: the display path below works
# from the raw uint8 image instead.
inv_normalize = transforms.Normalize(
    mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
    std=[1/0.23, 1/0.23, 1/0.23]
)
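# CIFAR-10 class labels, indexed by the model's output positions.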
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

def resize_image_pil(image, new_width, new_height):
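    """Resize `image` to fit within new_width x new_height while preserving
    aspect ratio, then crop (padding with black if needed) to exactly
    (new_width, new_height)."""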

    # Convert to PIL image
    img = Image.fromarray(np.array(image))
    
    # Get original size
    width, height = img.size

    # Calculate scale
    width_scale = new_width / width
    height_scale = new_height / height 
    scale = min(width_scale, height_scale)

    # Resize
    resized = img.resize((int(width*scale), int(height*scale)), Image.NEAREST)
    
    # Crop to exact size
    resized = resized.crop((0, 0, new_width, new_height))

    return resized


def inference(input_img, transparency=0.5, target_layer_number=-1, grad_cam_option="Yes", top_classes=3):
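    """Classify an image and optionally overlay a GradCAM heatmap.

    transparency: image weight for the GradCAM overlay (0..1).
    target_layer_number: which block of model.layer2 to visualize (-1 or -2).
    grad_cam_option: "Yes" to compute and return the GradCAM visualization.
    top_classes: how many of the highest-confidence classes to report.
    """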
    input_img = resize_image_pil(input_img, 32, 32)
    input_img = np.array(input_img)
    org_img = input_img
    input_img = input_img.reshape((32, 32, 3))
    transform = transforms.ToTensor()
    input_img = transform(input_img)
    input_img = input_img.unsqueeze(0)

    # Classification pass; gradients are not needed here (GradCAM runs its
    # own forward/backward pass internally).
    with torch.no_grad():
        outputs = model(input_img)
    softmax = torch.nn.Softmax(dim=0)
    o = softmax(outputs.flatten())
    confidences = {classes[i]: float(o[i]) for i in range(10)}
    # Keep the top-N classes by confidence; gr.Number delivers a float,
    # so cast before slicing.
    confidences = dict(sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)[:int(top_classes)])
    _, prediction = torch.max(outputs, 1)
    print("Confidences:", confidences)

    if grad_cam_option == "Yes":
        # GradCAM over the requested block of model.layer2 (-1 or -2).
        target_layers = [model.layer2[target_layer_number]]
        cam = GradCAM(model=model, target_layers=target_layers)
        grayscale_cam = cam(input_tensor=input_img, targets=None)
        grayscale_cam = grayscale_cam[0, :]
        visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
        return classes[prediction[0].item()], visualization, confidences
    return classes[prediction[0].item()], None, confidences

title = "CIFAR10 trained on ResNet18 Model with GradCAM"
description = "A simple Gradio interface to run inference on a CIFAR-10 ResNet18 model and view GradCAM results"
# Each example row must supply a value for all five inputs, in order:
# image, transparency, layer, GradCAM option, top-class count.
examples = [["cat.jpg", 0.5, -1, "Yes", 3], ["dog.jpg", 0.5, -1, "Yes", 3], ["bird_1.jpg", 0.5, -1, "Yes", 3],
            ["cat_1.jpg", 0.5, -1, "Yes", 3], ["cat_2.jpg", 0.5, -1, "Yes", 3], ["dog_1.jpg", 0.5, -1, "Yes", 3],
            ["dog_2.jpg", 0.5, -1, "Yes", 3], ["dog_3.jpg", 0.5, -1, "Yes", 3], ["ship_1.jpg", 0.5, -1, "Yes", 3],
            ["ship_2.jpg", 0.5, -1, "Yes", 3]]

demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Opacity of the input image in the overlay"),
        gr.Slider(-2, -1, value=-1, step=1, label="Which block of layer2?"),
        gr.Dropdown(["Yes", "No"], value="Yes", label="Want to see GradCAM images?"),
        gr.Number(value=3, minimum=1, maximum=10, label="How many top classes?")
    ],
    outputs=[
        "text",
        gr.Image(width=256, height=256, label="Output"),
        gr.Label()
    ],
    title=title,
    description=description,
    examples=examples,
)

# share=True additionally exposes the demo through a temporary public Gradio
# link; drop it to serve on localhost only.
demo.launch(share=True)