"""Gradio demo for Vision DiffMask on CIFAR-10 and ImageNet Vision Transformers."""
import sys

# Project modules live under ./code, so it must be on the path before the
# project imports below are executed.
sys.path.insert(0, './code')

from datamodules.transformations import UnNest
from models.interpretation import ImageInterpretationNet
from transformers import ViTFeatureExtractor, ViTForImageClassification
from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image

import gradio as gr
import numpy as np
import torch

# Load Vision Transformer
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
hf_model_imagenet = "google/vit-base-patch16-224"
vit = ViTForImageClassification.from_pretrained(hf_model)
vit_imagenet = ViTForImageClassification.from_pretrained(hf_model_imagenet)
vit.eval()
vit_imagenet.eval()

# Load Feature Extractor
feature_extractor = ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
feature_extractor_imagenet = ViTFeatureExtractor.from_pretrained(hf_model_imagenet, return_tensors="pt")
feature_extractor = UnNest(feature_extractor)
feature_extractor_imagenet = UnNest(feature_extractor_imagenet)

# Load Vision DiffMask
diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
diffmask.set_vision_transformer(vit)
diffmask_imagenet = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask_imagenet.ckpt')
diffmask_imagenet.set_vision_transformer(vit_imagenet)
diffmask.eval()
diffmask_imagenet.eval()

# Map each dropdown choice to its DiffMask model and the feature extractor
# that was loaded for the same underlying Vision Transformer.
_MODELS = {
    'DiffMask-CIFAR-10': (diffmask, feature_extractor),
    'DiffMask-ImageNet': (diffmask_imagenet, feature_extractor_imagenet),
}


# Define mask plotting functions
def _to_display_array(tensor):
    """Move the leading axis last, clip values to [0, 1] and return a NumPy array."""
    return tensor.permute(1, 2, 0).clip(0, 1).numpy()


def draw_mask(image, mask):
    """Return `image` with the smoothed `mask` drawn on it, as a NumPy array."""
    return _to_display_array(draw_mask_on_image(image, smoothen(mask)))


def draw_heatmap(image, mask):
    """Return `image` with the smoothed `mask` drawn as a heatmap, as a NumPy array."""
    return _to_display_array(draw_heatmap_on_image(image, smoothen(mask)))


# Define callable method for the demo
def get_mask(image, model_name: str):
    """Run Vision DiffMask on an uploaded image and build the demo outputs.

    Args:
        image: NumPy image supplied by the Gradio input (indexed as
            height, width, channel with values in 0-255), or None when
            nothing has been uploaded yet.
        model_name: One of the keys of `_MODELS`; None is treated as
            "no model selected yet".

    Returns:
        A tuple `(visualisation, orig_probs, mask_probs)` where
        `visualisation` is the masked image and the heatmap stacked side by
        side, and the two dicts map class names to softmax probabilities for
        the original and the masked input. Returns `(None, None, None)` when
        there is no image or no model selection.

    Raises:
        ValueError: If `model_name` is not a known model.
    """
    torch.manual_seed(seed=0)

    if image is None or model_name is None:
        return None, None, None

    # Look up the model together with its own feature extractor, so the
    # ImageNet model is not fed through the CIFAR-10 preprocessing.
    try:
        diffmask_model, extractor = _MODELS[model_name]
    except KeyError:
        raise ValueError(f"Unknown model name: {model_name!r}") from None

    # Class index -> class name mapping of the selected classifier
    id2label = diffmask_model.model.config.id2label

    # Prepare image and pass through Vision DiffMask
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
    dm_image = extractor(image).unsqueeze(0)
    dm_out = diffmask_model.get_mask(dm_image)

    # Get mask and apply on image
    mask = dm_out["mask"][0].detach()
    masked_img = draw_mask(image, mask)
    heatmap = draw_heatmap(image, mask)

    # Get logits and map to predictions with class names
    n_classes = len(id2label)
    probs_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)
    probs_mask = dm_out["logits"][0].detach().softmax(dim=-1)
    orig_probs = {id2label[i]: probs_orig[i].item() for i in range(n_classes)}
    mask_probs = {id2label[i]: probs_mask[i].item() for i in range(n_classes)}

    return np.hstack((masked_img, heatmap)), orig_probs, mask_probs


# Launch demo interface
gr.Interface(
    get_mask,
    inputs=[
        gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
        gr.inputs.Dropdown(label="Model Name", choices=["DiffMask-ImageNet", "DiffMask-CIFAR-10"]),
    ],
    outputs=[
        gr.outputs.Image(label="Output"),
        gr.outputs.Label(label="Original Prediction", num_top_classes=5),
        gr.outputs.Label(label="Masked Prediction", num_top_classes=5),
    ],
    examples=[
        ["dogcat.jpeg", "DiffMask-ImageNet"],
        ["elephant-zebra.jpg", "DiffMask-ImageNet"],
        ["finch.jpeg", "DiffMask-ImageNet"],
    ],
    title="Vision DiffMask Demo",
    live=True,
).launch()