Orpheous1 commited on
Commit
0028cfc
1 Parent(s): 5dc90b6
Files changed (1) hide show
  1. app.py +27 -26
app.py CHANGED
@@ -9,8 +9,6 @@ from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image
9
  import gradio as gr
10
  import numpy as np
11
  import torch
12
- import seaborn as sns
13
- import matplotlib.pyplot as plt
14
 
15
  # Load Vision Transformer
16
  hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
@@ -52,46 +50,49 @@ def draw_heatmap(image, mask):
52
  # Define callable method for the demo
53
  def get_mask(image, model_name: str):
54
  if image is None:
55
- return None, None
56
- if model_name == 'DiffMask-CiFAR-10':
 
57
  diffmask_model = diffmask
58
  elif model_name == 'DiffMask-ImageNet':
59
  diffmask_model = diffmask_imagenet
 
 
 
 
 
 
60
  image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
61
  dm_image = feature_extractor(image).unsqueeze(0)
62
  dm_out = diffmask_model.get_mask(dm_image)
63
- mask = dm_out["mask"][0].detach()
64
- logits = dm_out["logits"][0].detach().softmax(dim=-1)
65
- logits_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)
66
- # fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
67
- # sns.displot(logits_orig.cpu().numpy().flatten(), kind="kde", label="Original", ax=ax)
68
- top5logits_orig = logits_orig.topk(5, dim=-1)
69
- idx = top5logits_orig.indices
70
- # keep the top 5 classes from the indices of the top 5 logits
71
- top5logits_orig = top5logits_orig.values
72
- top5logits = logits[idx]
73
-
74
- pred = dm_out["pred_class"][0].detach()
75
- pred = diffmask_model.model.config.id2label[pred.item()]
76
 
 
 
77
  masked_img = draw_mask(image, mask)
78
  heatmap = draw_heatmap(image, mask)
79
- orig_probs = {diffmask_model.model.config.id2label[i]: top5logits_orig[i].item() for i in range(5)}
80
- pred_probs = {diffmask_model.model.config.id2label[i]: top5logits[i].item() for i in range(5)}
81
-
82
- return np.hstack((masked_img, heatmap)), pred, orig_probs, pred_probs
83
 
 
 
 
 
 
 
84
 
 
85
 
86
 
87
  # Launch demo interface
88
  gr.Interface(
89
  get_mask,
90
- inputs=[gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
91
- gr.inputs.Dropdown(["DiffMask-CiFAR-10", "DiffMask-ImageNet"])],
92
- outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction"),
93
- gr.Label(label="Original Probabilities"), gr.Label(label="Predicted Probabilities")],
 
 
 
 
 
94
  title="Vision DiffMask Demo",
95
  live=True,
96
  ).launch()
97
-
 
9
  import gradio as gr
10
  import numpy as np
11
  import torch
 
 
12
 
13
  # Load Vision Transformer
14
  hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
 
50
  # Define callable method for the demo
51
  def get_mask(image, model_name: str):
52
  if image is None:
53
+ return None, None, None
54
+
55
+ if model_name == 'DiffMask-CIFAR-10':
56
  diffmask_model = diffmask
57
  elif model_name == 'DiffMask-ImageNet':
58
  diffmask_model = diffmask_imagenet
59
+
60
+ # Helper function to convert class index to name
61
+ def idx2cname(idx):
62
+ return diffmask_model.model.config.id2label[idx]
63
+
64
+ # Prepare image and pass through Vision DiffMask
65
  image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
66
  dm_image = feature_extractor(image).unsqueeze(0)
67
  dm_out = diffmask_model.get_mask(dm_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ # Get mask and apply on image
70
+ mask = dm_out["mask"][0].detach()
71
  masked_img = draw_mask(image, mask)
72
  heatmap = draw_heatmap(image, mask)
 
 
 
 
73
 
74
+ # Get logits and map to predictions with class names
75
+ n_classes = len(diffmask_model.model.config.id2label)
76
+ logits_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)
77
+ logits_mask = dm_out["logits"][0].detach().softmax(dim=-1)
78
+ orig_probs = {idx2cname(i): logits_orig[i].item() for i in range(n_classes)}
79
+ mask_probs = {idx2cname(i): logits_mask[i].item() for i in range(n_classes)}
80
 
81
+ return np.hstack((masked_img, heatmap)), orig_probs, mask_probs
82
 
83
 
84
  # Launch demo interface
85
  gr.Interface(
86
  get_mask,
87
+ inputs=[
88
+ gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
89
+ gr.inputs.Dropdown(label="Model Name", choices=["DiffMask-CIFAR-10", "DiffMask-ImageNet"]),
90
+ ],
91
+ outputs=[
92
+ gr.outputs.Image(label="Output"),
93
+ gr.outputs.Label(label="Original Prediction", num_top_classes=5),
94
+ gr.outputs.Label(label="Masked Prediction", num_top_classes=5),
95
+ ],
96
  title="Vision DiffMask Demo",
97
  live=True,
98
  ).launch()