File size: 3,659 Bytes
d4ab5ac
 
 
 
 
 
 
 
 
 
 
 
 
 
9d9aad0
d4ab5ac
9d9aad0
d4ab5ac
9d9aad0
d4ab5ac
 
 
9d9aad0
d4ab5ac
9d9aad0
d4ab5ac
 
 
 
8fd2935
9d9aad0
 
 
d4ab5ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0422277
8fd2935
523f190
d4ab5ac
0028cfc
 
 
8fd2935
 
 
0028cfc
 
 
 
 
 
d4ab5ac
 
56ec0e7
d4ab5ac
0028cfc
 
d4ab5ac
 
5dc90b6
0028cfc
 
 
 
 
 
5dc90b6
0028cfc
5dc90b6
d4ab5ac
 
 
 
0028cfc
 
523f190
0028cfc
 
 
 
 
 
523f190
 
d4ab5ac
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import sys
sys.path.insert(0, './code')

from datamodules.transformations import UnNest
from models.interpretation import ImageInterpretationNet
from transformers import ViTFeatureExtractor, ViTForImageClassification
from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image

import gradio as gr
import numpy as np
import torch

# --- CIFAR-10 pipeline: fine-tuned ViT, its feature extractor and DiffMask ---
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
vit = ViTForImageClassification.from_pretrained(hf_model)
vit.eval()
# UnNest wraps the Hugging Face extractor (project helper from
# datamodules.transformations); it is later called directly on an image tensor.
feature_extractor = UnNest(
    ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
)
diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
diffmask.set_vision_transformer(vit)
diffmask.eval()

# --- ImageNet pipeline: same three pieces for the ImageNet ViT ---
hf_model_imagenet = "google/vit-base-patch16-224"
vit_imagenet = ViTForImageClassification.from_pretrained(hf_model_imagenet)
vit_imagenet.eval()
feature_extractor_imagenet = UnNest(
    ViTFeatureExtractor.from_pretrained(hf_model_imagenet, return_tensors="pt")
)
diffmask_imagenet = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask_imagenet.ckpt')
diffmask_imagenet.set_vision_transformer(vit_imagenet)
diffmask_imagenet.eval()

# Define mask plotting functions
def draw_mask(image, mask):
    """Draw the smoothed mask on the image and return it as a NumPy array.

    The mask is passed through ``smoothen`` before being handed to
    ``draw_mask_on_image``. The result has its axes reordered with
    ``permute(1, 2, 0)`` (presumably channels-first to channels-last --
    confirm against ``utils.plot``), is clipped to ``[0, 1]`` and converted
    with ``.numpy()``.
    """
    overlay = draw_mask_on_image(image, smoothen(mask))
    # Move the leading axis to the end before handing the array to the UI.
    overlay = overlay.permute(1, 2, 0)
    overlay = overlay.clip(0, 1)
    return overlay.numpy()


def draw_heatmap(image, mask):
    """Draw the smoothed mask as a heatmap on the image; return a NumPy array.

    The mask is passed through ``smoothen`` before being handed to
    ``draw_heatmap_on_image``. The result has its axes reordered with
    ``permute(1, 2, 0)`` (presumably channels-first to channels-last --
    confirm against ``utils.plot``), is clipped to ``[0, 1]`` and converted
    with ``.numpy()``.
    """
    rendered = draw_heatmap_on_image(image, smoothen(mask))
    # Move the leading axis to the end, then clamp values into [0, 1].
    channels_last = rendered.permute(1, 2, 0)
    return channels_last.clip(0, 1).numpy()


# Define callable method for the demo
@torch.no_grad()
def get_mask(image, model_name: str):
    """Run Vision DiffMask on an image and build the demo outputs.

    Args:
        image: Input image from the Gradio component, or ``None`` when no
            image has been provided. It is converted with
            ``torch.from_numpy`` and permuted with ``(2, 0, 1)``, so it is
            presumably an ``(H, W, C)`` array with values in 0-255 --
            TODO confirm against the Gradio version in use.
        model_name: Either ``'DiffMask-CIFAR-10'`` or ``'DiffMask-ImageNet'``.

    Returns:
        A 3-tuple ``(figure, orig_probs, mask_probs)``. ``figure`` is the
        masked image and the heatmap stacked side by side with
        ``np.hstack``; the two dicts map class names to softmax
        probabilities taken from ``"logits_orig"`` and ``"logits"`` of the
        DiffMask output. Returns ``(None, None, None)`` when ``image`` is
        ``None``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    torch.manual_seed(seed=0)
    if image is None:
        return None, None, None

    # Select the DiffMask model together with the feature extractor that was
    # loaded for the same backbone, so preprocessing matches the model.
    if model_name == 'DiffMask-CIFAR-10':
        diffmask_model, extractor = diffmask, feature_extractor
    elif model_name == 'DiffMask-ImageNet':
        diffmask_model, extractor = diffmask_imagenet, feature_extractor_imagenet
    else:
        # Fail with a clear message instead of an UnboundLocalError later on.
        raise ValueError(f"Unknown model name: {model_name!r}")

    # Class index -> class name mapping of the selected classifier
    id2label = diffmask_model.model.config.id2label

    # Prepare image and pass through Vision DiffMask
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
    dm_image = extractor(image).unsqueeze(0)
    dm_out = diffmask_model.get_mask(dm_image)

    # Get mask and apply on image
    mask = dm_out["mask"][0].detach()
    masked_img = draw_mask(image, mask)
    heatmap = draw_heatmap(image, mask)

    # Turn logits into probabilities and key them by class name
    n_classes = len(id2label)
    probs_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)
    probs_mask = dm_out["logits"][0].detach().softmax(dim=-1)
    orig_probs = {id2label[i]: probs_orig[i].item() for i in range(n_classes)}
    mask_probs = {id2label[i]: probs_mask[i].item() for i in range(n_classes)}

    return np.hstack((masked_img, heatmap)), orig_probs, mask_probs


# Assemble the demo components, then build and launch the interface
demo_inputs = [
    gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
    gr.inputs.Dropdown(label="Model Name", choices=["DiffMask-ImageNet", "DiffMask-CIFAR-10"]),
]

# One output per element of the tuple returned by get_mask
demo_outputs = [
    gr.outputs.Image(label="Output"),
    gr.outputs.Label(label="Original Prediction", num_top_classes=5),
    gr.outputs.Label(label="Masked Prediction", num_top_classes=5),
]

# Each example row is [image file, model name], matching the input order
demo_examples = [
    ["dogcat.jpeg", "DiffMask-ImageNet"],
    ["elephant-zebra.jpg", "DiffMask-ImageNet"],
    ["finch.jpeg", "DiffMask-ImageNet"],
]

demo = gr.Interface(
    get_mask,
    inputs=demo_inputs,
    outputs=demo_outputs,
    examples=demo_examples,
    title="Vision DiffMask Demo",
    live=True,
)
demo.launch()