# Spaces: Runtime error (hosting-page status captured together with the source)
import sys | |
sys.path.insert(0, './code') | |
from datamodules.transformations import UnNest | |
from models.interpretation import ImageInterpretationNet | |
from transformers import ViTFeatureExtractor, ViTForImageClassification | |
from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image | |
import gradio as gr | |
import numpy as np | |
import torch | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
# Model setup: for each of the two checkpoints (CIFAR-10 fine-tune and
# ImageNet) load the ViT classifier, its feature extractor, and the
# Vision DiffMask interpretation network that is attached to that ViT.

# --- CIFAR-10 ---
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
vit = ViTForImageClassification.from_pretrained(hf_model)
vit.eval()
# NOTE(review): UnNest is a project helper; presumably it unwraps the
# extractor's output so calling it yields the tensor directly -- confirm
# in datamodules.transformations.
feature_extractor = UnNest(
    ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
)
diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
diffmask.set_vision_transformer(vit)
diffmask.eval()

# --- ImageNet ---
hf_model_imagenet = "google/vit-base-patch16-224"
vit_imagenet = ViTForImageClassification.from_pretrained(hf_model_imagenet)
vit_imagenet.eval()
feature_extractor_imagenet = UnNest(
    ViTFeatureExtractor.from_pretrained(hf_model_imagenet, return_tensors="pt")
)
diffmask_imagenet = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask_imagenet.ckpt')
diffmask_imagenet.set_vision_transformer(vit_imagenet)
diffmask_imagenet.eval()
# Define mask plotting functions
def draw_mask(image, mask):
    """Overlay ``mask`` on ``image`` and return the result as a NumPy array.

    The mask is first passed through ``smoothen`` and then combined with the
    image by ``draw_mask_on_image``. The resulting tensor has its axes
    reordered from (0, 1, 2) to (1, 2, 0) -- presumably channels-first to
    channels-last for display; confirm against ``utils.plot`` -- is clipped
    to the range [0, 1], and is converted with ``.numpy()``.
    """
    smoothed = smoothen(mask)
    overlay = draw_mask_on_image(image, smoothed)
    overlay = overlay.permute(1, 2, 0)
    return overlay.clip(0, 1).numpy()
def draw_heatmap(image, mask):
    """Draw ``mask`` as a heatmap on ``image`` and return a NumPy array.

    The mask is first passed through ``smoothen`` and then rendered on the
    image by ``draw_heatmap_on_image``. The resulting tensor has its axes
    reordered from (0, 1, 2) to (1, 2, 0) -- presumably channels-first to
    channels-last for display; confirm against ``utils.plot`` -- is clipped
    to the range [0, 1], and is converted with ``.numpy()``.
    """
    smoothed = smoothen(mask)
    heat = draw_heatmap_on_image(image, smoothed)
    heat = heat.permute(1, 2, 0)
    return heat.clip(0, 1).numpy()
# Define callable method for the demo
def get_mask(image, model_name: str):
    """Run Vision DiffMask on one image and build the four demo outputs.

    Args:
        image: Image array from the Gradio input, or ``None`` when the input
            is empty. The code treats it as (H, W, C) with values in 0-255:
            it is permuted to channels-first and divided by 255.
        model_name: ``'DiffMask-CiFAR-10'`` or ``'DiffMask-ImageNet'``;
            ``None`` is treated as "nothing selected".

    Returns:
        A 4-tuple matching the interface outputs, in order:
        ``(masked image and heatmap stacked side by side, predicted label,
        {label: probability} for the top classes of the original logits,
        {label: probability} of those same classes from ``dm_out["logits"]``)``.
        All four entries are ``None`` when ``image`` or ``model_name`` is
        ``None``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # The interface declares four outputs, so the "nothing to do" path must
    # also return four values.
    if image is None or model_name is None:
        return None, None, None, None

    # Pair each DiffMask model with the feature extractor loaded from the same
    # Hugging Face checkpoint.
    if model_name == 'DiffMask-CiFAR-10':
        diffmask_model, extractor = diffmask, feature_extractor
    elif model_name == 'DiffMask-ImageNet':
        diffmask_model, extractor = diffmask_imagenet, feature_extractor_imagenet
    else:
        raise ValueError(f"Unknown model name: {model_name!r}")

    # HWC array in 0-255 -> CHW float tensor in 0-1.
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
    dm_image = extractor(image).unsqueeze(0)
    dm_out = diffmask_model.get_mask(dm_image)

    # Take the first (only) item of the batch for every output.
    mask = dm_out["mask"][0].detach()
    probs = dm_out["logits"][0].detach().softmax(dim=-1)
    probs_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)

    id2label = diffmask_model.model.config.id2label

    # Top classes according to the original logits; cap k so models with
    # fewer than five classes do not make topk fail.
    k = min(5, probs_orig.shape[-1])
    top = probs_orig.topk(k, dim=-1)
    class_ids = top.indices.tolist()

    # Key both dicts by the real class index of each top entry (not by its
    # rank 0..k-1), so each label matches the probability shown next to it.
    orig_probs = {
        id2label[class_id]: prob
        for class_id, prob in zip(class_ids, top.values.tolist())
    }
    pred_probs = {
        id2label[class_id]: probs[class_id].item() for class_id in class_ids
    }

    pred = id2label[dm_out["pred_class"][0].detach().item()]

    masked_img = draw_mask(image, mask)
    heatmap = draw_heatmap(image, mask)

    return np.hstack((masked_img, heatmap)), pred, orig_probs, pred_probs
# Launch demo interface
# Inputs are passed to get_mask positionally: the uploaded image, then the
# name of the DiffMask model chosen in the dropdown.
demo_inputs = [
    gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
    gr.inputs.Dropdown(["DiffMask-CiFAR-10", "DiffMask-ImageNet"]),
]

# Output components, in the same order as the values returned by get_mask.
demo_outputs = [
    gr.outputs.Image(label="Output"),
    gr.outputs.Label(label="Prediction"),
    gr.Label(label="Original Probabilities"),
    gr.Label(label="Predicted Probabilities"),
]

demo = gr.Interface(
    get_mask,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title="Vision DiffMask Demo",
    live=True,
)
demo.launch()