Spaces:
Runtime error
Runtime error
Orpheous1
committed on
Commit
•
0028cfc
1
Parent(s):
5dc90b6
fix top 5
Browse files
app.py
CHANGED
@@ -9,8 +9,6 @@ from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image
|
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
import torch
|
12 |
-
import seaborn as sns
|
13 |
-
import matplotlib.pyplot as plt
|
14 |
|
15 |
# Load Vision Transformer
|
16 |
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
|
@@ -52,46 +50,49 @@ def draw_heatmap(image, mask):
|
|
52 |
# Define callable method for the demo
|
53 |
def get_mask(image, model_name: str):
|
54 |
if image is None:
|
55 |
-
return None, None
|
56 |
-
|
|
|
57 |
diffmask_model = diffmask
|
58 |
elif model_name == 'DiffMask-ImageNet':
|
59 |
diffmask_model = diffmask_imagenet
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
|
61 |
dm_image = feature_extractor(image).unsqueeze(0)
|
62 |
dm_out = diffmask_model.get_mask(dm_image)
|
63 |
-
mask = dm_out["mask"][0].detach()
|
64 |
-
logits = dm_out["logits"][0].detach().softmax(dim=-1)
|
65 |
-
logits_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)
|
66 |
-
# fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
|
67 |
-
# sns.displot(logits_orig.cpu().numpy().flatten(), kind="kde", label="Original", ax=ax)
|
68 |
-
top5logits_orig = logits_orig.topk(5, dim=-1)
|
69 |
-
idx = top5logits_orig.indices
|
70 |
-
# keep the top 5 classes from the indices of the top 5 logits
|
71 |
-
top5logits_orig = top5logits_orig.values
|
72 |
-
top5logits = logits[idx]
|
73 |
-
|
74 |
-
pred = dm_out["pred_class"][0].detach()
|
75 |
-
pred = diffmask_model.model.config.id2label[pred.item()]
|
76 |
|
|
|
|
|
77 |
masked_img = draw_mask(image, mask)
|
78 |
heatmap = draw_heatmap(image, mask)
|
79 |
-
orig_probs = {diffmask_model.model.config.id2label[i]: top5logits_orig[i].item() for i in range(5)}
|
80 |
-
pred_probs = {diffmask_model.model.config.id2label[i]: top5logits[i].item() for i in range(5)}
|
81 |
-
|
82 |
-
return np.hstack((masked_img, heatmap)), pred, orig_probs, pred_probs
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
|
|
85 |
|
86 |
|
87 |
# Launch demo interface
|
88 |
gr.Interface(
|
89 |
get_mask,
|
90 |
-
inputs=[
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
94 |
title="Vision DiffMask Demo",
|
95 |
live=True,
|
96 |
).launch()
|
97 |
-
|
|
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
import torch
|
|
|
|
|
12 |
|
13 |
# Load Vision Transformer
|
14 |
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
|
|
|
50 |
# Define callable method for the demo
|
51 |
def get_mask(image, model_name: str):
    """Run Vision DiffMask on an image and return its visualisation and class probabilities.

    Args:
        image: Input image as a NumPy array, or ``None`` when no image has
            been provided. It is permuted with ``(2, 0, 1)`` and divided by
            255, so it is presumably channel-last with values in 0-255
            -- TODO confirm against the Gradio image input.
        model_name: Name of the DiffMask model to use, either
            ``'DiffMask-CIFAR-10'`` or ``'DiffMask-ImageNet'``. ``None`` is
            treated as "nothing selected yet".

    Returns:
        A 3-tuple ``(visualisation, orig_probs, mask_probs)``:
        the masked image and the heatmap stacked side by side with
        ``np.hstack``, followed by two dicts mapping class name to softmax
        probability, computed from ``dm_out["logits_orig"]`` and
        ``dm_out["logits"]`` respectively. All three entries are ``None``
        when ``image`` or ``model_name`` is ``None``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Nothing to compute until both inputs are set. The interface is created
    # with live=True, so presumably this can be called before a model is chosen.
    if image is None or model_name is None:
        return None, None, None

    if model_name == 'DiffMask-CIFAR-10':
        diffmask_model = diffmask
    elif model_name == 'DiffMask-ImageNet':
        diffmask_model = diffmask_imagenet
    else:
        # Without this branch an unrecognised name left `diffmask_model`
        # unbound and failed later with an opaque UnboundLocalError.
        raise ValueError(f"Unknown model name: {model_name!r}")

    # Mapping from class index to human-readable class name
    id2label = diffmask_model.model.config.id2label

    def class_probs(logits):
        """Turn a vector of logits into a {class name: probability} dict."""
        probs = logits.detach().softmax(dim=-1)
        return {id2label[i]: probs[i].item() for i in range(len(id2label))}

    # Prepare image and pass through Vision DiffMask
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
    dm_image = feature_extractor(image).unsqueeze(0)
    dm_out = diffmask_model.get_mask(dm_image)

    # Get mask and apply on image
    mask = dm_out["mask"][0].detach()
    masked_img = draw_mask(image, mask)
    heatmap = draw_heatmap(image, mask)

    # Get logits and map to predictions with class names
    orig_probs = class_probs(dm_out["logits_orig"][0])
    mask_probs = class_probs(dm_out["logits"][0])

    return np.hstack((masked_img, heatmap)), orig_probs, mask_probs
|
82 |
|
83 |
|
84 |
# Launch demo interface
|
85 |
# Build the interface first, then start it: `get_mask` receives the uploaded
# image and the selected model name, and its three return values fill the
# output image and the two top-5 label widgets, in that order.
demo = gr.Interface(
    fn=get_mask,
    inputs=[
        gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
        gr.inputs.Dropdown(label="Model Name", choices=["DiffMask-CIFAR-10", "DiffMask-ImageNet"]),
    ],
    outputs=[
        gr.outputs.Image(label="Output"),
        gr.outputs.Label(label="Original Prediction", num_top_classes=5),
        gr.outputs.Label(label="Masked Prediction", num_top_classes=5),
    ],
    title="Vision DiffMask Demo",
    live=True,
)
demo.launch()
|
|