"""Gradio demo for Vision DiffMask.

Loads two ViT image classifiers (a CIFAR-10 fine-tune and an ImageNet model),
attaches a pretrained DiffMask interpretation network to each, and serves a
Gradio interface showing the predicted class plus the learned mask drawn over
the input image and as a heatmap.
"""
import sys

# Make the project's ``code`` directory importable before the local imports
# below. The path is resolved against the current working directory, so the
# demo must be launched from the directory containing ``code/`` and
# ``checkpoints/``.
sys.path.insert(0, './code')

from datamodules.transformations import UnNest
from models.interpretation import ImageInterpretationNet
from transformers import ViTFeatureExtractor, ViTForImageClassification
from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image

import gradio as gr
import numpy as np
import torch

# Load Vision Transformers
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
hf_model_imagenet = "google/vit-base-patch16-224"

vit = ViTForImageClassification.from_pretrained(hf_model)
vit_imagenet = ViTForImageClassification.from_pretrained(hf_model_imagenet)
vit.eval()
vit_imagenet.eval()

# Load Feature Extractors. Each one is wrapped in the project's UnNest helper;
# get_mask calls `.unsqueeze(0)` on the wrapped output, so it presumably
# returns a single unbatched tensor (TODO confirm against UnNest).
feature_extractor = ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
feature_extractor_imagenet = ViTFeatureExtractor.from_pretrained(hf_model_imagenet, return_tensors="pt")
feature_extractor = UnNest(feature_extractor)
feature_extractor_imagenet = UnNest(feature_extractor_imagenet)

# Load Vision DiffMask and attach each mask network to its classifier.
diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
diffmask.set_vision_transformer(vit)

diffmask_imagenet = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask_imagenet.ckpt')
diffmask_imagenet.set_vision_transformer(vit_imagenet)

diffmask.eval()
diffmask_imagenet.eval()

# Dropdown label -> (DiffMask model, the feature extractor matching its ViT).
# Insertion order defines the order of the dropdown choices.
MODELS = {
    "DiffMask-CiFAR-10": (diffmask, feature_extractor),
    "DiffMask-ImageNet": (diffmask_imagenet, feature_extractor_imagenet),
}
DEFAULT_MODEL = "DiffMask-CiFAR-10"


# Define mask plotting functions
def _to_display_array(image_tensor):
    """Turn a 3-D image tensor (assumed channel-first) into a channel-last
    numpy array with values clipped to [0, 1], as Gradio expects."""
    return image_tensor.permute(1, 2, 0).clip(0, 1).numpy()


def draw_mask(image, mask):
    """Return `image` with the smoothed `mask` drawn on it, as a numpy array."""
    return _to_display_array(draw_mask_on_image(image, smoothen(mask)))


def draw_heatmap(image, mask):
    """Return `image` with the smoothed `mask` overlaid as a heatmap, as a numpy array."""
    return _to_display_array(draw_heatmap_on_image(image, smoothen(mask)))


# Define callable method for the demo
def get_mask(image, model_name: str):
    """Run the selected DiffMask model on an uploaded image.

    Args:
        image: Image array from the Gradio input in height x width x channel
            layout (assumed uint8 in 0-255, given the division by 255 below),
            or None when nothing has been uploaded yet.
        model_name: One of the keys of MODELS, i.e. the dropdown choices.

    Returns:
        A pair ``(visualization, label)``. ``visualization`` is the masked
        image and the heatmap stacked side by side. ``label`` is the predicted
        class name. Returns ``(None, None)`` when there is no image or
        ``model_name`` is not a known model (e.g. nothing selected yet, which
        can happen because the interface is live).
    """
    if image is None or model_name not in MODELS:
        return None, None
    diffmask_model, extractor = MODELS[model_name]

    # HWC array -> CHW float tensor in [0, 1].
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        dm_image = extractor(image).unsqueeze(0)  # add a batch dimension
        dm_out = diffmask_model.get_mask(dm_image)

    # Take the first (only) element of the batch.
    mask = dm_out["mask"][0].detach()
    pred = dm_out["pred_class"][0].detach()
    pred = diffmask_model.model.config.id2label[pred.item()]

    masked_img = draw_mask(image, mask)
    heatmap = draw_heatmap(image, mask)
    return np.hstack((masked_img, heatmap)), pred


# Build the demo interface; launch only when run as a script.
demo = gr.Interface(
    get_mask,
    inputs=[
        gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
        # A default avoids calling get_mask with no model selected (live=True).
        gr.inputs.Dropdown(list(MODELS), default=DEFAULT_MODEL),
    ],
    outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction")],
    title="Vision DiffMask Demo",
    live=True,
)

if __name__ == "__main__":
    demo.launch()