Sijuade commited on
Commit
04bbcce
1 Parent(s): a1539a5

Create utils/gradio_utils.py

Browse files
Files changed (1) hide show
  1. utils/gradio_utils.py +64 -0
utils/gradio_utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from pytorch_grad_cam import GradCAM
4
+ from pytorch_grad_cam.utils.image import show_cam_on_image
5
+ import matplotlib.pyplot as plt
6
+ import PIL
7
+ import io
8
+ from PIL import Image
9
+ import numpy as np
10
+ import random
11
+
12
+ transform = transforms.ToTensor()
13
+ targets = None
14
+ device = torch.device("cpu")
15
+
16
+
17
+ mu = [0.49139968, 0.48215841, 0.44653091]
18
+ std = [0.24703223, 0.24348513, 0.26158784]
19
+
20
+
21
+ inv_normalize = transforms.Normalize(
22
+ mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
23
+ std=[1/0.23, 1/0.23, 1/0.23]
24
+ )
25
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
26
+ 'dog', 'frog', 'horse', 'ship', 'truck')
27
+
28
+ transform = transforms.ToTensor()
29
+
30
+
31
+ def get_examples():
32
+ example_images = [f'{c}.jpg' for c in classes]
33
+ example_top = [random.randint(0, 9) for r in range(10)]
34
+ example_transparency = [random.choice([0.6, 0.7, 0.8]) for r in range(10)]
35
+ examples = [[example_images[i], example_top[i], example_transparency[i]] for i in range(len(example_images))]
36
+ return(examples)
37
+
38
+
39
+ def image_to_array(input_img, model, layer_val, transparency=0.6):
40
+ input_tensor = input_img[0]
41
+ print(input_tensor.shape)
42
+
43
+ cam = GradCAM(model=model, target_layers=[model.res_block2.conv[-layer_val]])
44
+ grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
45
+ grayscale_cam = grayscale_cam[0, :]
46
+
47
+ img = input_tensor.squeeze(0)
48
+ img = inv_normalize(img)
49
+ rgb_img = np.transpose(img, (1, 2, 0))
50
+ rgb_img = rgb_img.numpy()
51
+
52
+ visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True,
53
+ image_weight=transparency)
54
+
55
+ plt.imshow(visualization)
56
+ plt.title(r"Correct: " + classes[input_img[1].item()] + '\n' + 'Output: ' + classes[input_img[2].item()])
57
+
58
+ with io.BytesIO() as buffer:
59
+ plt.savefig(buffer, format = "png")
60
+ buffer.seek(0)
61
+ image = Image.open(buffer)
62
+ ar = np.asarray(image)
63
+
64
+ return(ar)