dawn17 commited on
Commit
ae11e16
1 Parent(s): 7683bed

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from models.david_page import DavidPageNet
9
+ from PIL import Image
10
+ from pytorch_grad_cam import GradCAM
11
+ from pytorch_grad_cam.utils.image import show_cam_on_image
12
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
13
+ from torchvision import transforms
14
+
15
+
16
+ # imagenet mean and std
17
+ mean = [0.485, 0.456, 0.406]
18
+ std = [0.229, 0.224, 0.225]
19
+
20
+ inv_mean = [-mean / std for mean, std in zip(mean, std)]
21
+ inv_std = [1 / s for s in std]
22
+
23
+ # transforms
24
+ transform = transforms.Compose([
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(mean=mean, std=std),
27
+
28
+ ]
29
+ )
30
+
31
+ inv_normalize = transforms.Normalize(mean=inv_mean, std=inv_std)
32
+
33
+ classes = [
34
+ "plane",
35
+ "car",
36
+ "bird",
37
+ "cat",
38
+ "deer",
39
+ "dog",
40
+ "frog",
41
+ "horse",
42
+ "ship",
43
+ "truck",
44
+ ]
45
+
46
+
47
+ class Gradio:
48
+ def __init__(self, model_path: str):
49
+ use_cuda = torch.cuda.is_available()
50
+ self.device = torch.device("cuda" if use_cuda else "cpu")
51
+ self.model = self.load_model(model_path)
52
+ self.temperature = 2
53
+
54
+ def load_model(self, model_path: str):
55
+ model = DavidPageNet().to(self.device)
56
+
57
+ if os.path.isfile(model_path):
58
+ model.load_state_dict(
59
+ torch.load(model_path)["model_state_dict"], strict=False
60
+ )
61
+
62
+ return model
63
+
64
+ def cam(
65
+ self,
66
+ input_tensor: torch.Tensor,
67
+ target_class_id: int,
68
+ layer_nums: List,
69
+ transparency: float = 0.7,
70
+ ):
71
+ targets = [ClassifierOutputTarget(target_class_id)]
72
+ target_layers = [getattr(self.model, f"block{layer-1}") for layer in layer_nums]
73
+
74
+ with GradCAM(
75
+ model=self.model,
76
+ target_layers=target_layers,
77
+ use_cuda=self.device == torch.device("cuda"),
78
+ ) as cam:
79
+ grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
80
+ grayscale_cam = grayscale_cam[0, :]
81
+
82
+ img = inv_normalize(input_tensor)
83
+ rgb_img = img[0].permute(1, 2, 0).cpu().numpy()
84
+
85
+ visualization = show_cam_on_image(
86
+ rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency
87
+ )
88
+ return visualization
89
+
90
+ def inference(
91
+ self,
92
+ input_img: np.array,
93
+ transparency: float,
94
+ ntop_classes: int,
95
+ layer_nums: List,
96
+ cam_for_class: str,
97
+ ):
98
+ self.model.eval()
99
+ input_img = transform(input_img)
100
+
101
+ input_img = input_img.to(self.device)
102
+ input_img = input_img.unsqueeze(0)
103
+
104
+ with torch.no_grad():
105
+ outputs = self.model(input_img).squeeze(0)
106
+ outputs = F.softmax(outputs / self.temperature, dim=-1)
107
+
108
+ probability, prediction = torch.sort(outputs, descending=True)
109
+ prediction = list(zip(prediction.tolist(), probability.tolist()))
110
+
111
+ class_id = (
112
+ prediction[0][0]
113
+ if cam_for_class in ["default", ""]
114
+ else classes.index(cam_for_class)
115
+ )
116
+ visualization = self.cam(
117
+ input_tensor=input_img,
118
+ target_class_id=class_id,
119
+ layer_nums=layer_nums,
120
+ transparency=transparency,
121
+ )
122
+ top_nclass_result = [
123
+ (classes[class_id], round(score, 2))
124
+ for class_id, score in prediction[:ntop_classes]
125
+ ]
126
+ return visualization, dict(top_nclass_result)
127
+
128
+
129
+ method = Gradio(model_path="./model.pt")
130
+ demo = gr.Interface(
131
+ method.inference,
132
+ [
133
+ gr.Image(shape=(32, 32), label="Input Image", value="./samples/dog_cat.jpeg"),
134
+ gr.Slider(
135
+ minimum=0,
136
+ maximum=1,
137
+ value=0.5,
138
+ label="Transparency",
139
+ info="Transparency of the CAM-Attention Output",
140
+ ),
141
+ gr.Slider(
142
+ minimum=1,
143
+ maximum=10,
144
+ step=1,
145
+ value=2,
146
+ label="Top Classes",
147
+ info="Number of Top Predicted Classes",
148
+ ),
149
+ gr.CheckboxGroup(
150
+ choices=[1, 2, 3, 4],
151
+ value=[3, 4],
152
+ label="Network Layers",
153
+ info="Network Layers for CAM-Attention Extraction",
154
+ ),
155
+ gr.Dropdown(
156
+ choices=["default"] + classes,
157
+ multiselect=False,
158
+ value="default",
159
+ label="Class Activation Map (CAM) Focus Visualization",
160
+ info="This section showcases the specific region of interest within the input image that the Class Activation Map (CAM) algorithm emphasizes to make predictions based on the selected class from the dropdown menu. The 'default' value serves as the default choice, representing the top class predicted by the model.",
161
+ ),
162
+ ],
163
+ [
164
+ gr.Image(shape=(32, 32)).style(width=128, height=128),
165
+ gr.Label(label="Top Classes"),
166
+ ],
167
+ )
168
+
169
+
170
+ demo.launch()