from transformers import AutoFeatureExtractor, AutoModel
import torch
from torchvision.transforms.functional import to_pil_image
from einops import rearrange, reduce
from skops import hub_utils
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
import os
import pickle

# Display names of the backbone setups. Each setup pairs one Hugging Face
# embedding model with one fitted Emb-GAM (same position in all three lists).
setups = ['ResNet-50', 'ViT', 'DINO-ResNet-50', 'DINO-ViT']
embedder_names = [
    'microsoft/resnet-50',
    'google/vit-base-patch16-224',
    'Ramos-Ramos/dino-resnet-50',
    'facebook/dino-vitb16',
]
gam_names = ['emb-gam-resnet', 'emb-gam-vit', 'emb-gam-dino-resnet', 'emb-gam-dino']
embedder_to_setup = dict(zip(embedder_names, setups))
gam_to_setup = dict(zip(gam_names, setups))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load each embedding model once at startup, keyed by setup name.
embedders = {}
for name in embedder_names:
    embedder = {}
    embedder['feature_extractor'] = AutoFeatureExtractor.from_pretrained(name)
    embedder['model'] = AutoModel.from_pretrained(name).eval().to(device)
    if 'resnet-50' in name:
        # The ResNet output is treated as a (b, d, h, w) feature map on a fixed
        # 7x7 grid (presumably 224 px input / stride 32); flatten the grid into
        # a patch axis -> (b, h*w, d).
        embedder['num_patches_side'] = 7
        embedder['embedding_postprocess'] = lambda x: rearrange(
            x.last_hidden_state, 'b d h w -> b (h w) d'
        )
    else:
        # ViT: patches per side follow from the model config. Drop the first
        # token (presumably [CLS]) to keep only patch tokens.
        embedder['num_patches_side'] = (
            embedder['model'].config.image_size // embedder['model'].config.patch_size
        )
        embedder['embedding_postprocess'] = lambda x: x.last_hidden_state[:, 1:]
    embedders[embedder_to_setup[name]] = embedder

# Download (first run only) and unpickle each fitted GAM, keyed by setup name.
gams = {}
for name in gam_names:
    if not os.path.exists(name):
        os.mkdir(name)
        hub_utils.download(repo_id=f'Ramos-Ramos/{name}', dst=name)
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. This is only acceptable because the repos above are trusted; never
    # point this at user-supplied repositories.
    with open(f'{name}/model.pkl', 'rb') as infile:
        gams[gam_to_setup[name]] = pickle.load(infile)

# Class names; labels[j] is displayed for row j of each GAM's contributions
# (presumably the 10 Imagenette classes).
labels = [
    'tench',
    'English springer',
    'cassette player',
    'chain saw',
    'church',
    'French horn',
    'garbage truck',
    'gas pump',
    'golf ball',
    'parachute',
]


def _patch_contributions(input_img, setup):
    '''Computes the per-label patch contributions of one visual Emb-GAM.

    Each patch embedding is scored with the GAM's linear weights, and the
    intercept is split evenly across all patches. Summing one label's map
    therefore gives coef_ @ (sum of patch embeddings) + intercept_ for that
    label.

    Args:
        input_img: image accepted by the setup's feature extractor.
        setup: key into `embedders` / `gams`.

    Returns:
        numpy array of shape (num_labels, num_patches_side, num_patches_side).
    '''
    embedder = embedders[setup]
    gam = gams[setup]
    num_patches_side = embedder['num_patches_side']

    inputs = {
        k: v.to(device)
        for k, v in embedder['feature_extractor'](input_img, return_tensors='pt').items()
    }
    with torch.no_grad():
        # [0]: the batch holds a single image -> (num_patches, dim)
        patch_embeddings = embedder['embedding_postprocess'](
            embedder['model'](**inputs)
        ).cpu()[0]

    scores = (
        gam.coef_ @ patch_embeddings.T.numpy()
        + gam.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
    )
    return scores.reshape(-1, num_patches_side, num_patches_side)


def visualize(input_img, visual_emb_gam_setups, show_scores, show_cbars):
    '''Visualizes the patch contributions to all labels of one or more visual Emb-GAMs

    Args:
        input_img: PIL image from the Gradio image input (None when empty).
        visual_emb_gam_setups: setup names from `setups`, plotted one per row.
        show_scores: if True, label each heatmap with the sum of its patch
            contributions.
        show_cbars: if True, draw a color bar next to each heatmap.

    Returns:
        A single matplotlib Figure for the interface's one gr.Plot output.
    '''
    # The interface has exactly one output, so return exactly one figure.
    # An empty figure is also returned when there is nothing to compute.
    if not visual_emb_gam_setups or input_img is None:
        return plt.Figure()

    patch_contributions = {
        setup: _patch_contributions(input_img, setup)
        for setup in visual_emb_gam_setups
    }

    num_rows = len(visual_emb_gam_setups)
    num_labels = len(labels)
    multiple_setups = num_rows > 1

    # Grid layout: column 0 shows the input image, followed by one column per label.
    fig, axs = plt.subplots(
        num_rows, num_labels + 1, figsize=(20, round(10 / 4 * num_rows))
    )
    # plt.subplots squeezes a single row into a 1-D array, hence the branches.
    gs = (axs[0, 0] if multiple_setups else axs[0]).get_gridspec()
    for ax in (axs[:, 0] if multiple_setups else [axs[0]]):
        ax.remove()
    # Replace the first column with one axis spanning every row.
    ax_orig_img = fig.add_subplot(gs[:, 0] if multiple_setups else gs[0])
    ax_orig_img.imshow(input_img)
    ax_orig_img.axis('off')

    axs_maps = axs[:, 1:] if multiple_setups else [axs[1:]]
    for i, setup in enumerate(visual_emb_gam_setups):
        contributions = patch_contributions[setup]
        # One shared color scale across all labels of the same setup.
        vmin = contributions.min()
        vmax = contributions.max()
        for j, label in enumerate(labels):
            ax = axs_maps[i][j]
            sns.heatmap(
                contributions[j],
                ax=ax,
                square=True,
                vmin=vmin,
                vmax=vmax,
                cbar=show_cbars,
            )
            if show_scores:
                ax.set_xlabel(f'{contributions[j].sum():.2f}')
            if j == 0:
                ax.set_ylabel(setup)
            if i == 0:
                ax.set_title(label)
            ax.set_xticks([])
            ax.set_yticks([])

    fig.tight_layout()
    # Unregister from pyplot so figures do not pile up across requests. The
    # Figure object itself can still be rendered (savefig) by Gradio.
    plt.close(fig)
    return fig


demo = gr.Interface(
    fn=visualize,
    inputs=[
        gr.Image(shape=(224, 224), type='pil', label='Input image'),
        gr.CheckboxGroup(setups, value=setups, label='Visual Emb-GAM'),
        gr.Checkbox(label='Show scores'),
        gr.Checkbox(label='Show color bars'),
    ],
    outputs=[
        gr.Plot(label='Patch contributions'),
    ],
)
demo.launch(debug=True)