raedinkhaled committed
Commit
5174b1f
1 Parent(s): b6c245f

Create app.py

Files changed (1)
app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
+ import json
+ from pathlib import Path
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ from torchvision import models, transforms
+ from torchvision.models.feature_extraction import create_feature_extractor
+ from transformers import ViTForImageClassification
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
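+ # Class-index -> label mapping stored alongside the app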
+ labels = json.loads(Path("labels.json").read_text())
+
+ # Load ResNet-50
+ resnet50 = models.resnet50(pretrained=True).to(device)
+ resnet50.eval()
+
+ # Create ResNet feature extractor
+ feature_extractor = create_feature_extractor(resnet50, return_nodes=["layer4", "fc"])
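+ # The fc weight matrix holds one 2048-dim weight vector per class; it is
+ # used below to weight the conv feature maps for CAM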
+ fc_layer_weights = resnet50.fc.weight
+
+ # Load ViT
+ vit = ViTForImageClassification.from_pretrained("raedinkhaled/vit-base-mri").to(
+     device
+ )
+ vit.eval()
+
+
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ preprocess = transforms.Compose(
+     [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
+ )
+
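+ # Example images bundled with the Space, shown in the Gradio demo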
+ examples = sorted([f.as_posix() for f in Path("examples").glob("*")])
+
+
+ def get_cam(img_tensor):
+     output = feature_extractor(img_tensor)
+     cnn_features = output["layer4"].squeeze()
+     class_id = output["fc"].argmax()
+
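+     # CAM: weight the layer4 feature maps by the fc weights of the predicted
+     # class and sum over channels, giving a coarse class activation map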
+     cam = fc_layer_weights[class_id].matmul(cnn_features.flatten(1))
+     cam = cam.reshape(cnn_features.shape[1], cnn_features.shape[2])
+
+     return cam.cpu().numpy(), labels[class_id]
+
+
+ def get_attention_mask(img_tensor):
+     result = vit(img_tensor, output_attentions=True)
+     class_id = result[0].argmax()
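+     # Stack the per-layer attentions and drop the batch dim:
+     # (num_layers, num_heads, tokens, tokens)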
+     attention_probs = torch.stack(result[1]).squeeze(1)
+
+     # Average the attention at each layer over all heads
+     attention_probs = torch.mean(attention_probs, dim=1)
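+     # Account for the residual (skip) connections: each token keeps half of
+     # its own representation and mixes the other half through attention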
+     residual = torch.eye(attention_probs.size(-1)).to(device)
+     attention_probs = 0.5 * attention_probs + 0.5 * residual
+
+     # Renormalize so each row is again a probability distribution
+     attention_probs = attention_probs / attention_probs.sum(dim=-1).unsqueeze(-1)
+
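+     # Attention rollout: multiply the attention matrices through the layers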
+     attention_rollout = attention_probs[0]
+
+     for i in range(1, attention_probs.size(0)):
+         attention_rollout = torch.matmul(attention_probs[i], attention_rollout)
+
+     # Attention between cls token and patches
+     mask = attention_rollout[0, 1:]
+     mask_size = np.sqrt(mask.size(0)).astype(int)
+     mask = mask.reshape(mask_size, mask_size)
+
+     return mask.cpu().numpy(), labels[class_id]
+
+
+ def plot_mask_on_image(image, mask):
+     # min-max normalization
+     mask = (mask - mask.min()) / (mask.max() - mask.min())
+     mask = (255 * mask).astype(np.uint8)
+     mask = cv2.resize(mask, image.size)
+
+     heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
+     # OpenCV colormaps come out in BGR order; convert to RGB for display
+     heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
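+     # Blend heatmap over the original image (0.3 / 0.5 keeps both visible)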
+     result = heatmap * 0.3 + np.array(image) * 0.5
+     return result.astype(np.uint8)
+
+
+ def inference(img):
+     img_tensor = preprocess(img).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         cam, resnet_label = get_cam(img_tensor)
+         attention_mask, vit_label = get_attention_mask(img_tensor)
+
+     cam_result = plot_mask_on_image(img, cam)
+     rollout_result = plot_mask_on_image(img, attention_mask)
+
+     return resnet_label, cam_result, vit_label, rollout_result
+
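+ # Wire everything into a Gradio interface: one input image, then a label and
+ # a heatmap for each of the two models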
+ if __name__ == "__main__":
+     interface = gr.Interface(
+         fn=inference,
+         inputs=gr.inputs.Image(type="pil", label="Input Image"),
+         outputs=[
+             gr.outputs.Label(num_top_classes=1, type="auto", label="ResNet Label"),
+             gr.outputs.Image(type="auto", label="ResNet CAM"),
+             gr.outputs.Label(num_top_classes=1, type="auto", label="ViT Label"),
+             gr.outputs.Image(type="auto", label="raedinkhaled/vit-base-mri Attention Rollout"),
+         ],
+         examples=examples,
+         title="Transformer Explainability on Our Pretrained Model",
+         live=True,
+     )
+     interface.launch()