Sijuade committed on
Commit
14a8b64
1 Parent(s): dd24cd4

Create app.py

Files changed (1)
  1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
+ import io
+ import torch, torchvision
+ from torchvision import transforms
+ import numpy as np
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+ from model.network import ResNet18
+
+ from model.network import *
+ from utils.gradio_utils import *
+ from augment.augment import *
+ from dataset.dataset import *
+
+
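+ # Model and data setup: load the Lightning checkpoint on CPU and build the
+ # CIFAR10 test loader used to collect correctly / incorrectly classified samples.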
+ model = ResNet18(20, None)
+ model = model.load_from_checkpoint("resnet18.ckpt", map_location=torch.device("cpu"))
+
+ dataloader_args = dict(shuffle=True, batch_size=64)
+ _, test_transforms = get_transforms(mu, std)
+
+ test = CIFAR10Dataset(transform=test_transforms, train=False)
+ test_loader = torch.utils.data.DataLoader(test, **dataloader_args)
+
+ target_layers = [model.res_block2.conv[-1]]
+ targets = None
+ device = torch.device("cpu")
+
+ examples = get_examples()
+
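+ # Run inference on a user-uploaded image and return the top-class confidences
+ # plus a gallery of (original image, Grad-CAM heatmap, overlay).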
+ def upload_image_inference(input_img, n_top_classes, transparency):
+     org_img = input_img.copy()
+
+     # Apply the test-time transforms and add a batch dimension
+     input_img = test_transforms(image=org_img)['image']
+     input_img = input_img.unsqueeze(0)
+
+     outputs = model(input_img)
+
+     # Top-k class confidences from the softmax over the logits
+     softmax = torch.nn.Softmax(dim=0)
+     o = softmax(outputs.flatten())
+     values, indices = torch.topk(o, int(n_top_classes))
+     confidences = {classes[int(i)]: float(v) for v, i in zip(values, indices)}
+
+     # Grad-CAM heatmap overlaid on the uploaded image
+     cam = GradCAM(model=model, target_layers=target_layers)
+     grayscale_cam = cam(input_tensor=input_img, targets=None)
+     grayscale_cam = grayscale_cam[0, :]
+     visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
+
+     return confidences, [org_img, grayscale_cam, visualization]
+
+
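+ # Gallery callbacks: build Grad-CAM overlays for the first `num_images`
+ # misclassified / correctly classified test samples collected at startup.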
+ def misclass_gr(num_images, layer_val, transparency):
+     images_list = misclassified_data[:num_images]
+     images_list = [image_to_array(img, layer_val, transparency) for img in images_list]
+     return images_list
+
+
+ def class_gr(num_images, layer_val, transparency):
+     images_list = classified_data[:num_images]
+     images_list = [image_to_array(img, layer_val, transparency) for img in images_list]
+     return images_list
+
+
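+ # Render one (image, label, prediction) sample as a Grad-CAM overlay,
+ # titled with its true and predicted class names.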
+ def image_to_array(input_img, layer_val, transparency=0.6):
+     input_tensor = input_img[0]
+
+     # Grad-CAM for the chosen target layer (counted from the end of the block)
+     cam = GradCAM(model=model, target_layers=[model.res_block2.conv[-layer_val]])
+     grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
+     grayscale_cam = grayscale_cam[0, :]
+
+     # Undo normalisation and convert CHW -> HWC for plotting
+     img = input_tensor.squeeze(0)
+     img = inv_normalize(img)
+     rgb_img = img.permute(1, 2, 0).cpu().numpy()
+     rgb_img = np.clip(rgb_img, 0, 1)
+
+     visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True,
+                                       image_weight=transparency)
+
+     fig = plt.figure()
+     plt.imshow(visualization)
+     plt.title("Correct: " + classes[input_img[1].item()] + '\n' + 'Output: ' + classes[input_img[2].item()])
+
+     # Render the figure to an in-memory PNG and return it as a numpy array
+     with io.BytesIO() as buffer:
+         plt.savefig(buffer, format="png")
+         buffer.seek(0)
+         image = Image.open(buffer)
+         ar = np.asarray(image)
+     plt.close(fig)
+
+     return ar
+
+
+ def get_misclassified_data(model, device, test_loader):
+     """
+     Run the model on the test set and collect both correctly classified and
+     misclassified samples as (image, label, prediction) tuples.
+     :param model: network to evaluate
+     :param device: CPU/GPU device to run on
+     :param test_loader: DataLoader for the test set
+     """
+     mis_count = 0
+     correct_count = 0
+
+     # Put the model in evaluation mode (disables dropout / batch-norm updates)
+     model.eval()
+     # Lists to store correctly classified and misclassified samples
+     misclassified_data, classified_data = [], []
+     # No gradients are needed for inference
+     with torch.no_grad():
+         for data, target in test_loader:
+             # Move the batch to the target device
+             data, target = data.to(device), target.to(device)
+             # Handle each image in the batch individually
+             for image, label in zip(data, target):
+                 # Add a batch dimension to the image
+                 image = image.unsqueeze(0)
+                 # Get the model prediction on the image
+                 output = model(image)
+                 # Take the class with the highest score
+                 pred = output.argmax(dim=1, keepdim=True)
+                 # Sort the sample into the correct bucket
+                 if pred != label:
+                     misclassified_data.append((image, label, pred))
+                     mis_count += 1
+                 else:
+                     classified_data.append((image, label, pred))
+                     correct_count += 1
+
+             # Stop once 20 examples of each kind have been collected
+             if (mis_count >= 20) and (correct_count >= 20):
+                 return classified_data, misclassified_data
+
+     # Fallback: loader exhausted before 20 of each were found
+     return classified_data, misclassified_data
+
+
+ title = "CIFAR10 trained on ResNet18 (PyTorch Lightning) Model with GradCAM"
+ description = "A simple Gradio interface to run inference on the ResNet model and view Grad-CAM results for existing and new images"
+
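+ # Gradio UI: one tab for uploading a new image, another for browsing Grad-CAM
+ # overlays of correctly classified and misclassified test images.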
+ with gr.Blocks() as gradcam:
+     classified_data, misclassified_data = get_misclassified_data(model, device, test_loader)
+
+     gr.Markdown("Make Grad-CAM of an uploaded image, or of existing images.")
+     with gr.Tab("Upload New Image"):
+         upload_input = [gr.Image(shape=(32, 32)),
+                         gr.Number(minimum=0, maximum=10, label='n Top Classes', value=3, precision=0),
+                         gr.Slider(0, 1, label='Transparency', value=0.6)]
+
+         upload_output = [gr.Label(label='Top Classes'),
+                          gr.Gallery(label="Image | CAM | Image+CAM",
+                                     show_label=True, min_width=80).style(columns=[3],
+                                                                          rows=[1],
+                                                                          object_fit="contain",
+                                                                          height="auto")]
+         button1 = gr.Button("Perform Inference")
+         gr.Examples(
+             examples=examples,
+             inputs=upload_input,
+             outputs=upload_output,
+             fn=upload_image_inference,
+             cache_examples=True,
+         )
+
+     with gr.Tab("View Class Activation Maps"):
+         with gr.Row():
+             with gr.Column():
+                 cam_input21 = [gr.Number(minimum=1, maximum=20, precision=0, value=3, label='View Correctly Classified CAM | Num Images'),
+                                gr.Number(minimum=1, maximum=3, precision=0, value=1, label='(-) Target Layer'),
+                                gr.Slider(0, 1, value=0.6, label='Transparency')]
+
+                 image_output21 = gr.Gallery(label="Images - Grad-CAM (Correct)",
+                                             show_label=True, min_width=80)
+                 button21 = gr.Button("View Images")
+
+             with gr.Column():
+                 cam_input22 = [gr.Number(minimum=1, maximum=20, precision=0, value=3, label='View Misclassified CAM | Num Images'),
+                                gr.Number(minimum=1, maximum=3, precision=0, value=1, label='(-) Target Layer'),
+                                gr.Slider(0, 1, value=0.6, label='Transparency')]
+
+                 image_output22 = gr.Gallery(label="Images - Grad-CAM (Misclassified)",
+                                             show_label=True, min_width=80)
+                 button22 = gr.Button("View Images")
+
+     # Wire the buttons to their callbacks
+     button1.click(upload_image_inference, inputs=upload_input, outputs=upload_output)
+     button21.click(class_gr, inputs=cam_input21, outputs=image_output21)
+     button22.click(misclass_gr, inputs=cam_input22, outputs=image_output22)
+
+
+ gradcam.launch()