Sijuade commited on
Commit
00e8269
1 Parent(s): 04bbcce

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -0
app.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torchvision
2
+ from torchvision import transforms
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from pytorch_grad_cam import GradCAM
7
+ from pytorch_grad_cam.utils.image import show_cam_on_image
8
+ from model.network import ResNet18
9
+ import matplotlib.pyplot as plt
10
+ import PIL
11
+ import io
12
+ from PIL import Image
13
+
14
+ from model.network import *
15
+ from utils.gradio_utils import *
16
+ from augment.augment import *
17
+ from dataset.dataset import *
18
+
19
+
20
+
21
+ model = ResNet18(20, None)
22
+ model = model.load_from_checkpoint("/content/S12/logs/lightning_logs/version_1/checkpoints/epoch=119-step=23520.ckpt", map_location=torch.device("cpu"))
23
+ model.eval()
24
+
25
+ dataloader_args = dict(shuffle=True, batch_size=64)
26
+ _, test_transforms = get_transforms(mu, std)
27
+
28
+ test = CIFAR10Dataset(transform=test_transforms, train=False)
29
+ test_loader = torch.utils.data.DataLoader(test, **dataloader_args)
30
+
31
+ target_layers = [model.res_block2.conv[-1]]
32
+ targets = None
33
+ device = torch.device("cpu")
34
+
35
+ examples = get_examples()
36
+
37
+ def upload_image_inference(input_img, n_top_classes, transparency):
38
+
39
+ org_img = input_img.copy()
40
+
41
+ input_img = transform(input_img)
42
+ input_img = input_img.unsqueeze(0)
43
+
44
+ outputs = model(input_img)
45
+
46
+ softmax = torch.nn.Softmax(dim=0)
47
+ o = softmax(outputs.flatten())
48
+ confidences = {classes[i]: float(o[i]) for i in range(n_top_classes)}
49
+ _, prediction = torch.max(outputs, 1)
50
+
51
+ cam = GradCAM(model=model, target_layers=target_layers)
52
+
53
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
54
+ grayscale_cam = grayscale_cam[0, :]
55
+ img = input_img.squeeze(0)
56
+ img = inv_normalize(img)
57
+
58
+ rgb_img = np.transpose(img.cpu(), (1, 2, 0))
59
+ rgb_img = rgb_img.numpy()
60
+ visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
61
+
62
+ return([confidences, [org_img, grayscale_cam, visualization]])
63
+
64
+
65
+ def misclass_gr(num_images, layer_val, transparency):
66
+ images_list = misclassified_data[:num_images]
67
+
68
+ images_list = [image_to_array(img, layer_val, transparency) for img in images_list]
69
+ return(images_list)
70
+
71
+
72
+ def class_gr(num_images, layer_val, transparency):
73
+ images_list = classified_data[:num_images]
74
+
75
+ images_list = [image_to_array(img, layer_val, transparency) for img in images_list]
76
+ return(images_list)
77
+
78
+
79
+ def image_to_array(input_img, layer_val, transparency=0.6):
80
+ input_tensor = input_img[0]
81
+
82
+ cam = GradCAM(model=model, target_layers=[model.res_block2.conv[-layer_val]])
83
+ grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
84
+ grayscale_cam = grayscale_cam[0, :]
85
+
86
+ img = input_tensor.squeeze(0)
87
+ img = inv_normalize(img)
88
+ rgb_img = np.transpose(img, (1, 2, 0))
89
+ rgb_img = rgb_img.numpy()
90
+
91
+ visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True,
92
+ image_weight=transparency)
93
+
94
+ plt.imshow(visualization)
95
+ plt.title(r"Correct: " + classes[input_img[1].item()] + '\n' + 'Output: ' + classes[input_img[2].item()])
96
+
97
+ with io.BytesIO() as buffer:
98
+ plt.savefig(buffer, format = "png")
99
+ buffer.seek(0)
100
+ image = Image.open(buffer)
101
+ ar = np.asarray(image)
102
+
103
+ return(ar)
104
+
105
+
106
+ def get_misclassified_data(model, device, test_loader):
107
+ """
108
+ Function to run the model on test set and return misclassified images
109
+ :param model: Network Architecture
110
+ :param device: CPU/GPU
111
+ :param test_loader: DataLoader for test set
112
+ """
113
+ mis_count = 0
114
+ correct_count = 0
115
+
116
+ # Prepare the model for evaluation i.e. drop the dropout layer
117
+ model.eval()
118
+ # List to store misclassified Images
119
+ misclassified_data, classified_data = [], []
120
+ # Reset the gradients
121
+ with torch.no_grad():
122
+ # Extract images, labels in a batch
123
+ for data, target in test_loader:
124
+ # Migrate the data to the device
125
+ data, target = data.to(device), target.to(device)
126
+ # Extract single image, label from the batch
127
+ for image, label in zip(data, target):
128
+ # Add batch dimension to the image
129
+ image = image.unsqueeze(0)
130
+ # Get the model prediction on the image
131
+ output = model(image)
132
+ # Convert the output from one-hot encoding to a value
133
+ pred = output.argmax(dim=1, keepdim=True)
134
+ # If prediction is incorrect, append the data
135
+ if pred != label:
136
+ misclassified_data.append((image, label, pred))
137
+ mis_count += 1
138
+ else:
139
+ classified_data.append((image, label, pred))
140
+ correct_count += 1
141
+
142
+ if ((mis_count>=20) and (correct_count>=20)):
143
+ return ((classified_data, misclassified_data))
144
+
145
+
146
+ title = "CIFAR10 trained on ResNet18 (Pytorch Lightning) Model with GradCAM"
147
+ description = "A simple Gradio interface to infer on ResNet model, get GradCAM results for existing & new Images"
148
+
149
+ with gr.Blocks() as gradcam:
150
+ classified_data, misclassified_data = get_misclassified_data(model, device, test_loader)
151
+
152
+ gr.Markdown("Make Grad-Cam of uploaded image, or existing images.")
153
+ with gr.Tab("Upload New Image"):
154
+ upload_input = [gr.Image(shape=(32, 32)),
155
+ gr.Number(minimum=0, maximum=10, label='n Top Classes', value=3, precision=0),
156
+ gr.Slider(0, 1, label='Transparency', value=0.6)]
157
+
158
+ upload_output = [gr.Label(label='Top Classes'),
159
+ gr.Gallery(label="Image | CAM | Image+CAM",
160
+ show_label=True, elem_id="gallery1").style(columns=[3],
161
+ rows=[1],
162
+ object_fit="contain",
163
+ height="auto")]
164
+ button1 = gr.Button("Perform Inference")
165
+ gr.Examples(
166
+ examples=examples,
167
+ inputs=upload_input,
168
+ outputs=upload_output,
169
+ fn=upload_image_inference,
170
+ cache_examples=True,
171
+ )
172
+
173
+
174
+ with gr.Tab("View Class Activate Maps"):
175
+ with gr.Row():
176
+ with gr.Column():
177
+ cam_input21 = [gr.Number(minimum=1, maximum=20, precision=0, value=3, label='View Correctly Classified CAM | Num Images'),
178
+ gr.Number(minimum=1, maximum=3, precision=0, value=1, label='(-) Target Layer'),
179
+ gr.Slider(0, 1, value=0.6, label='Transparency')]
180
+
181
+ image_output21 = gr.Gallery(label="Images - Grad-CAM (correct)",
182
+ show_label=True, elem_id="gallery21")
183
+ button21 = gr.Button("View Images")
184
+
185
+ with gr.Column():
186
+ cam_input22 = [gr.Number(minimum=1, maximum=20, precision=0, value=3, label='View Misclassified CAM | Num Images'),
187
+ gr.Number(minimum=1, maximum=3, precision=0, value=1, label='(-) Target Layer'),
188
+ gr.Slider(0, 1, value=0.6, label='Transparency')]
189
+
190
+ image_output22 = gr.Gallery(label="Images - Grad-CAM (Misclassified)",
191
+ show_label=True, elem_id="gallery22")
192
+ button22 = gr.Button("View Images")
193
+
194
+ button1.click(upload_image_inference, inputs=upload_input, outputs=upload_output)
195
+ button21.click(class_gr, inputs=cam_input21, outputs=image_output21)
196
+ button22.click(misclass_gr, inputs=cam_input22, outputs=image_output22)
197
+
198
+
199
+
200
+ gradcam.launch()