Spaces:
Sleeping
Sleeping
from transformers import AutoFeatureExtractor, AutoModel | |
import torch | |
from torchvision.transforms.functional import to_pil_image | |
from einops import rearrange, reduce | |
from skops import hub_utils | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import gradio as gr | |
import os | |
import pickle | |
# Display names of the supported visual Emb-GAM setups. The three lists below
# are parallel: index i of each list refers to the same setup.
setups = ['ResNet-50', 'ViT', 'DINO-ResNet-50', 'DINO-ViT']
# Hub identifiers of the embedding models, one per setup.
embedder_names = [
    'microsoft/resnet-50',
    'google/vit-base-patch16-224',
    'Ramos-Ramos/dino-resnet-50',
    'facebook/dino-vitb16',
]
# Repository names (under the 'Ramos-Ramos' namespace) of the fitted GAMs.
gam_names = ['emb-gam-resnet', 'emb-gam-vit', 'emb-gam-dino-resnet', 'emb-gam-dino']

# Lookup tables from hub/repo name to the display name used as dictionary key.
embedder_to_setup = dict(zip(embedder_names, setups))
gam_to_setup = dict(zip(gam_names, setups))

# Run the embedding models on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load every embedding model once at start-up. `embedders` maps a setup's
# display name to a dict with:
#   'feature_extractor'     - preprocessing object for the model
#   'model'                 - the model itself, in eval mode on `device`
#   'num_patches_side'      - number of patches along one side of the grid
#   'embedding_postprocess' - callable turning the model output into a
#                             (batch, patches, dim) tensor of patch embeddings
embedders = {}
for name in embedder_names:
    embedder = {}
    embedder['feature_extractor'] = AutoFeatureExtractor.from_pretrained(name)
    embedder['model'] = AutoModel.from_pretrained(name).eval().to(device)
    if 'resnet-50' in name:
        # Hard-coded grid size for the ResNet-50 variants
        # (presumably the 7x7 final feature map for 224x224 inputs - confirm).
        embedder['num_patches_side'] = 7
        # Flatten the spatial feature map into a sequence of patch embeddings.
        embedder['embedding_postprocess'] = lambda x: rearrange(
            x.last_hidden_state, 'b d h w -> b (h w) d'
        )
    else:
        # Patch grid size follows from the model configuration.
        embedder['num_patches_side'] = (
            embedder['model'].config.image_size
            // embedder['model'].config.patch_size
        )
        # Drop the first token (presumably the [CLS] token - confirm) so that
        # only patch tokens remain.
        embedder['embedding_postprocess'] = lambda x: x.last_hidden_state[:, 1:]
    embedders[embedder_to_setup[name]] = embedder
# Fetch (if needed) and load the fitted GAM for every setup. `gams` maps a
# setup's display name to the unpickled model object.
gams = {}
for name in gam_names:
    # Only download when the local directory does not exist yet, so that a
    # previously fetched copy is reused.
    if not os.path.exists(name):
        os.mkdir(name)
        hub_utils.download(repo_id=f'Ramos-Ramos/{name}', dst=name)

    # SECURITY NOTE: unpickling executes arbitrary code from the file. This is
    # only acceptable because the repositories above are trusted; do not point
    # this loader at repositories you do not control.
    with open(f'{name}/model.pkl', 'rb') as infile:
        gams[gam_to_setup[name]] = pickle.load(infile)
# Class names used as heatmap column titles. The position in this list is used
# as the class index into each GAM's per-class patch contributions.
labels = [
    'tench',
    'English springer',
    'cassette player',
    'chain saw',
    'church',
    'French horn',
    'garbage truck',
    'gas pump',
    'golf ball',
    'parachute',
]
def visualize(input_img, visual_emb_gam_setups, show_scores, show_cbars):
    '''Visualizes the patch contributions to all labels of one or more visual
    Emb-GAMs.

    Args:
        input_img: image to explain; it is passed to each feature extractor
            and drawn in the first column of the figure.
        visual_emb_gam_setups: names of the setups to visualize; each must be
            a key of both `embedders` and `gams`.
        show_scores: if truthy, write the sum of a heatmap's patch
            contributions underneath it.
        show_cbars: if truthy, draw a color bar next to every heatmap.

    Returns:
        A single matplotlib figure with one row per selected setup and one
        heatmap per label. An empty figure is returned when no setup is
        selected.
    '''
    if not visual_emb_gam_setups:
        # Return exactly one figure, like the main path does, so the result
        # matches the single plot output it is wired to.
        return plt.Figure()

    num_labels = len(labels)

    # get patch contributions per Emb-GAM
    patch_contributions = {}
    for setup in visual_emb_gam_setups:
        # prepare embedding model
        embedder_setup = embedders[setup]
        feature_extractor = embedder_setup['feature_extractor']
        embedding_postprocess = embedder_setup['embedding_postprocess']
        num_patches_side = embedder_setup['num_patches_side']

        # prepare GAM
        gam = gams[setup]

        # get patch embeddings
        inputs = {
            k: v.to(device)
            for k, v
            in feature_extractor(input_img, return_tensors='pt').items()
        }
        with torch.no_grad():
            patch_embeddings = embedding_postprocess(
                embedder_setup['model'](**inputs)
            ).cpu()[0]

        # get patch contributions: project every patch embedding onto the
        # per-class coefficients and spread the intercept evenly over the
        # patches, so that summing one class map gives that class' score.
        # (assumes gam.coef_ is (classes, dim) and gam.intercept_ is
        # (classes,), as for a linear scikit-learn classifier - confirm)
        patch_contributions[setup] = (
            gam.coef_ @ patch_embeddings.numpy().T
            + gam.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
        ).reshape(-1, num_patches_side, num_patches_side)

    # plot heatmaps
    multiple_setups = len(visual_emb_gam_setups) > 1

    # set up figure: one extra leading column holds the original image
    fig, axs = plt.subplots(
        len(visual_emb_gam_setups),
        num_labels + 1,
        figsize=(20, round(10/4 * len(visual_emb_gam_setups)))
    )

    # Replace the first column's axes with a single axis spanning all rows.
    # `axs` is 1-D when only one setup is plotted, hence the branching.
    gs_ax = axs[0, 0] if multiple_setups else axs[0]
    gs = gs_ax.get_gridspec()
    ax_rm = axs[:, 0] if multiple_setups else [axs[0]]
    for ax in ax_rm:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0] if multiple_setups else gs[0])

    # plot original image
    ax_orig_img.imshow(input_img)
    ax_orig_img.axis('off')

    # plot patch contributions
    axs_maps = axs[:, 1:] if multiple_setups else [axs[1:]]
    for i, setup in enumerate(visual_emb_gam_setups):
        # Share one color scale across all labels of a setup so the maps in a
        # row are directly comparable.
        vmin = patch_contributions[setup].min()
        vmax = patch_contributions[setup].max()
        for j, label in enumerate(labels):
            ax = axs_maps[i][j]
            sns.heatmap(
                patch_contributions[setup][j],
                ax=ax,
                square=True,
                vmin=vmin,
                vmax=vmax,
                cbar=show_cbars
            )
            if show_scores:
                ax.set_xlabel(f'{patch_contributions[setup][j].sum():.2f}')
            if j == 0:
                ax.set_ylabel(setup)
            if i == 0:
                ax.set_title(label)
            ax.set_xticks([])
            ax.set_yticks([])

    # Lay out this figure explicitly instead of pyplot's implicit "current"
    # figure.
    fig.tight_layout()
    return fig
# Gradio UI: an image plus three controls are passed, in order, to the four
# parameters of `visualize`; its returned figure is shown in a single plot.
demo = gr.Interface(
    fn=visualize,
    inputs=[
        gr.Image(shape=(224, 224), type='pil', label='Input image'),
        # All setups are pre-selected by default.
        gr.CheckboxGroup(setups, value=setups, label='Visual Emb-GAM'),
        gr.Checkbox(label='Show scores'),
        gr.Checkbox(label='Show color bars')
    ],
    outputs=[
        gr.Plot(label='Patch contributions'),
    ]
)

# Start the app when this file is executed.
demo.launch(debug=True)