File size: 3,600 Bytes
0fb76bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
853bc3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df0fc51
47f1ab8
853bc3f
47f1ab8
 
853bc3f
 
 
 
0fb76bb
853bc3f
 
 
0fb76bb
853bc3f
 
 
 
 
 
 
 
 
 
 
0fb76bb
 
853bc3f
 
63c77de
853bc3f
 
0fb76bb
 
 
 
 
 
853bc3f
 
 
 
 
 
 
 
 
0fb76bb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import torchvision
from torchvision import transforms
import gradio as gr
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet import ResNet18

# Build the network and restore the trained weights onto the CPU.
# NOTE(review): strict=False silently ignores missing/unexpected checkpoint
# keys -- confirm this is intentional and not hiding a mismatched checkpoint.
model = ResNet18()
model.load_state_dict(
    torch.load("model.pth", map_location=torch.device("cpu")),
    strict=False,
)
# torch modules are created in training mode; switch to eval mode so the
# forward pass used for predictions runs with inference-time layer behaviour
# (e.g. normalisation/dropout layers, if the network has any).
model.eval()


# Per-channel statistics whose normalisation `inv_normalize` undoes:
# Normalize(mean=-m/s, std=1/s) maps x to x * s + m.
# In the visible part of this file it is only referenced from the
# commented-out legacy `inference`.
_NORM_MEAN = 0.50
_NORM_STD = 0.23

inv_normalize = transforms.Normalize(
    mean=[-_NORM_MEAN / _NORM_STD] * 3,
    std=[1 / _NORM_STD] * 3,
)

# Class labels, indexed by the position of the model's output score.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

def resize_image_pil(image, new_width, new_height):
    """Fit *image* inside a ``new_width`` x ``new_height`` canvas.

    The image is scaled (aspect ratio preserved, nearest-neighbour
    resampling) so that it fits entirely within the target size, anchored at
    the top-left corner. The result is then cropped to exactly
    ``(new_width, new_height)``; when the scaled image is smaller than the
    target along one axis, Pillow fills the remaining area of the crop box
    with zeros.

    Args:
        image: Anything ``np.array`` can turn into an image array accepted by
            ``Image.fromarray`` (e.g. a NumPy array or a PIL image).
        new_width: Target width in pixels.
        new_height: Target height in pixels.

    Returns:
        A PIL image of size ``(new_width, new_height)``.
    """
    img = Image.fromarray(np.array(image))
    width, height = img.size

    # Decide which axis limits the scale by cross-multiplying instead of
    # comparing float ratios, then derive the other axis with integer
    # division. This avoids float round-off such as int(31.999999) == 31,
    # which would leave an unintended one-pixel border.
    if new_width * height <= new_height * width:
        fit_width = new_width
        fit_height = (height * new_width) // width
    else:
        fit_width = (width * new_height) // height
        fit_height = new_height

    # A very elongated input could otherwise collapse to a zero-sized axis,
    # which Image.resize rejects.
    fit_size = (max(1, int(fit_width)), max(1, int(fit_height)))

    resized = img.resize(size=fit_size, resample=Image.NEAREST)

    # Crop (and, where the scaled image is smaller, pad) to the exact target.
    return resized.crop((0, 0, new_width, new_height))

# def inference(input_img, transparency):
#     transform = transforms.ToTensor()
#     input_img = transform(input_img)
#     input_img = input_img.to(device)
#     input_img = input_img.unsqueeze(0)
#     outputs = model(input_img)
#     _, prediction = torch.max(outputs, 1)
#     target_layers = [model.layer2[-2]]
#     cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
#     grayscale_cam = cam(input_tensor=input_img, targets=targets)
#     grayscale_cam = grayscale_cam[0, :]
#     img = input_img.squeeze(0).to('cpu')
#     img = inv_normalize(img)
#     rgb_img = np.transpose(img, (1, 2, 0))
#     rgb_img = rgb_img.numpy()
#     visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
#     return classes[prediction[0].item()], visualization

def inference(input_img, transparency=0.5, target_layer_number=-1):
    """Classify an image and overlay a GradCAM heat map on it.

    Args:
        input_img: Input image as delivered by the Gradio image component.
        transparency: Weight of the original image in the overlay, passed to
            ``show_cam_on_image`` as ``image_weight``.
        target_layer_number: Index into ``model.layer2`` selecting the layer
            GradCAM is computed on. Non-integer values are rounded.

    Returns:
        A tuple ``(label, visualization, confidences)``: the predicted class
        name, the 32x32 image with the CAM overlay, and a dict mapping every
        class name to its softmax probability.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_img is None:
        raise gr.Error("Please provide an input image.")

    # UI sliders may deliver floats (e.g. -2.0 or -1.5); indexing the layer
    # container requires a real int.
    layer_index = int(round(target_layer_number))

    # Force three channels so grayscale / RGBA inputs do not break the
    # (32, 32, 3) layout the rest of the function relies on.
    resized = resize_image_pil(input_img, 32, 32).convert("RGB")
    org_img = np.array(resized)

    # NOTE(review): only ToTensor is applied here, while `inv_normalize`
    # suggests a mean/std normalisation exists somewhere in the pipeline --
    # confirm this matches the preprocessing used for training.
    input_tensor = transforms.ToTensor()(org_img).unsqueeze(0)

    # The class scores need no gradients; GradCAM runs its own forward and
    # backward pass below.
    with torch.no_grad():
        outputs = model(input_tensor)

    probabilities = torch.softmax(outputs.flatten(), dim=0)
    confidences = dict(zip(classes, probabilities.tolist()))
    _, prediction = torch.max(outputs, 1)

    target_layers = [model.layer2[layer_index]]
    # The context manager releases the hooks GradCAM registers on the model,
    # so they do not pile up across requests.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        # targets=None lets GradCAM explain the highest-scoring class.
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

    visualization = show_cam_on_image(
        org_img / 255,  # show_cam_on_image expects pixel values in [0, 1]
        grayscale_cam,
        use_rgb=True,
        image_weight=transparency,
    )

    return classes[prediction[0].item()], visualization, confidences
    



# Gradio UI: an image plus two sliders go into `inference`; the predicted
# label, the GradCAM overlay and the top-3 class confidences come out.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Overall opacity value"),
        # step=1 restricts the slider to the integers -2 and -1; the value is
        # used as a layer index, so fractional positions would be invalid.
        gr.Slider(
            -2,
            -1,
            value=-2,
            step=1,
            label="Which model layer to use for GradCAM?",
        ),
    ],
    outputs=[
        "text",
        gr.Image(width=256, height=256, label="Output"),
        gr.Label(num_top_classes=3),
    ],
    title="CIFAR10 trained on ResNet18 with GradCAM",
    description="A simple Gradio interface to infer on ResNet model with GradCAM results shown on top.",
    examples=[
        ["cat.jpg", 0.5, -1],
        ["dog.jpg", 0.7, -2],
    ],
)

demo.launch()