import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import timm
import torch
import types
import albumentations as A
from PIL import Image
from tqdm import tqdm
from sklearn.decomposition import PCA
from torch_kmeans import KMeans, CosineSimilarity

cmap = plt.get_cmap("tab20")

# ImageNet mean/std, rescaled to [0, 1].
MEAN = np.array([123.675, 116.280, 103.530]) / 255
STD = np.array([58.395, 57.120, 57.375]) / 255

transforms = A.Compose([
    A.Normalize(mean=list(MEAN), std=list(STD)),
])


def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n=1,
    reshape: bool = False,
    return_prefix_tokens: bool = False,
    return_class_token: bool = False,
    norm: bool = True,
):
    """Return features from intermediate ViT blocks, optionally reshaped to
    (B, C, H', W') feature maps. Bound to timm models via types.MethodType."""
    outputs = self._intermediate_layers(x, n)
    if norm:
        outputs = [self.norm(out) for out in outputs]
    if return_class_token:
        prefix_tokens = [out[:, 0] for out in outputs]
    else:
        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
    outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

    if reshape:
        B, C, H, W = x.shape
        grid_size = (
            (H - self.patch_embed.patch_size[0]) // self.patch_embed.proj.stride[0] + 1,
            (W - self.patch_embed.patch_size[1]) // self.patch_embed.proj.stride[1] + 1,
        )
        outputs = [
            out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
            .permute(0, 3, 1, 2)
            .contiguous()
            for out in outputs
        ]

    if return_prefix_tokens or return_class_token:
        return tuple(zip(outputs, prefix_tokens))
    return tuple(outputs)


def viz_feat(feat):
    """Project a (1, D, h, w) feature map to an RGB image via 3-component PCA."""
    _, _, h, w = feat.shape
    feat = feat.squeeze(0).permute((1, 2, 0))
    projected_featmap = feat.reshape(-1, feat.shape[-1]).cpu()

    pca = PCA(n_components=3)
    pca.fit(projected_featmap)
    pca_features = pca.transform(projected_featmap)
    pca_features = (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min())
    pca_features = pca_features * 255

    return Image.fromarray(pca_features.reshape(h, w, 3).astype(np.uint8))


def plot_feats(model_option, ori_feats, fine_feats, ori_labels=None, fine_labels=None):
    """2x2 grid: PCA feature maps (top row) and K-Means label maps (bottom row),
    original model on the left, fine-tuned model on the right."""
    ori_feats_map = viz_feat(ori_feats)
    fine_feats_map = viz_feat(fine_feats)

    fig, ax = plt.subplots(2, 2, figsize=(6, 5))
    ax[0][0].imshow(ori_feats_map)
    ax[0][0].set_title("Original " + model_option, fontsize=15)
    ax[0][1].imshow(fine_feats_map)
    ax[0][1].set_title("Ours", fontsize=15)
    ax[1][0].imshow(ori_labels)
    ax[1][1].imshow(fine_labels)
    for row in ax:
        for a in row:
            a.xaxis.set_major_formatter(plt.NullFormatter())
            a.yaxis.set_major_formatter(plt.NullFormatter())
            a.set_xticks([])
            a.set_yticks([])
            a.axis('off')
    plt.tight_layout()
    plt.close(fig)
    return fig


def download_image(url, save_path):
    response = requests.get(url)
    with open(save_path, 'wb') as file:
        file.write(response.content)


def process_image(image, stride, transforms):
    """Normalize a PIL image and bilinearly resize it so both sides are
    multiples of the model's patch stride."""
    transformed = transforms(image=np.array(image))
    image_tensor = torch.tensor(transformed['image'])
    image_tensor = image_tensor.permute(2, 0, 1)
    image_tensor = image_tensor.unsqueeze(0).to(device)

    h, w = image_tensor.shape[2:]
    height_int = (h // stride) * stride
    width_int = (w // stride) * stride
    image_resized = torch.nn.functional.interpolate(
        image_tensor, size=(height_int, width_int), mode='bilinear'
    )
    return image_resized
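
# A minimal, self-contained sketch of how the monkey-patched
# `get_intermediate_layers` above is meant to be used (not called by the demo).
# The model name and input size are illustrative; it assumes timm's
# VisionTransformer exposes `_intermediate_layers`, `num_prefix_tokens`, and
# `patch_embed` as used above. With `reshape=True`, a 518x518 input to a
# patch-14 ViT yields a 37x37 feature map per requested block.
def _example_intermediate_layers():
    model = timm.create_model(
        "vit_small_patch14_dinov2.lvd142m",
        pretrained=False,  # skip the weight download for this sketch
        num_classes=0,
        dynamic_img_size=True,
    )
    model.get_intermediate_layers = types.MethodType(get_intermediate_layers, model)
    x = torch.randn(1, 3, 518, 518)
    with torch.no_grad():
        (feats,) = model.get_intermediate_layers(x, n=[11], reshape=True)
    print(feats.shape)  # torch.Size([1, 384, 37, 37])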
def kmeans_clustering(feats_map, n_clusters=20):
    """Cluster a (B, D, h, w) feature map with cosine-similarity K-Means and
    color the resulting labels with the tab20 colormap."""
    if n_clusters is None:
        n_clusters = 20
    n_clusters = int(n_clusters)  # gr.Number delivers floats by default
    print('num clusters:', n_clusters)

    B, D, h, w = feats_map.shape
    feats_map_flattened = feats_map.permute((0, 2, 3, 1)).reshape(B, -1, D)

    kmeans_engine = KMeans(n_clusters=n_clusters, distance=CosineSimilarity)
    kmeans_engine.fit(feats_map_flattened)
    labels = kmeans_engine.predict(feats_map_flattened)
    labels = labels.reshape(B, h, w).float()
    labels = labels[0].cpu().numpy()

    label_map = cmap(labels / n_clusters)[..., :3]
    label_map = np.uint8(label_map * 255)
    return Image.fromarray(label_map)


def load_model(options):
    """Load each timm baseline and its FiT3D fine-tuned counterpart, binding
    the custom get_intermediate_layers to both."""
    original_models = {}
    fine_models = {}
    for option in tqdm(options):
        print('Please wait ...')
        print('loading weights of', option)
        original_models[option] = timm.create_model(
            timm_model_card[option],
            pretrained=True,
            num_classes=0,
            dynamic_img_size=True,
            dynamic_img_pad=False,
        ).to(device)
        original_models[option].get_intermediate_layers = types.MethodType(
            get_intermediate_layers, original_models[option]
        )
        fine_models[option] = torch.hub.load("ywyue/FiT3D", our_model_card[option]).to(device)
        fine_models[option].get_intermediate_layers = types.MethodType(
            get_intermediate_layers, fine_models[option]
        )
    print('Done! Now play the demo :)')
    return original_models, fine_models


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("device:", device)

    example_urls = {
        "library.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/library.jpg",
        "livingroom.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/livingroom.jpg",
        "airplane.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/airplane.jpg",
        "ship.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/ship.jpg",
        "chair.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/chair.jpg",
    }
    example_dir = "/tmp/examples"
    os.makedirs(example_dir, exist_ok=True)
    for name, url in example_urls.items():
        save_path = os.path.join(example_dir, name)
        if not os.path.exists(save_path):
            print(f"Downloading to {save_path}...")
            download_image(url, save_path)
        else:
            print(f"{save_path} already exists.")

    image_input = gr.Image(
        label="Choose an image:",
        height=500,
        type="pil",
        image_mode='RGB',
        sources=['upload', 'webcam', 'clipboard'],
    )
    options = ['DINOv2', 'DINOv2-reg', 'CLIP', 'MAE', 'DeiT-III']
    model_option = gr.Radio(options, value="DINOv2", label='Choose a 2D foundation model')
    kmeans_num = gr.Number(label="number of K-Means clusters", value=20)

    timm_model_card = {
        "DINOv2": "vit_small_patch14_dinov2.lvd142m",
        "DINOv2-reg": "vit_small_patch14_reg4_dinov2.lvd142m",
        "CLIP": "vit_base_patch16_clip_384.laion2b_ft_in12k_in1k",
        "MAE": "vit_base_patch16_224.mae",
        "DeiT-III": "deit3_base_patch16_224.fb_in1k",
    }
    our_model_card = {
        "DINOv2": "dinov2_small_fine",
        "DINOv2-reg": "dinov2_reg_small_fine",
        "CLIP": "clip_base_fine",
        "MAE": "mae_base_fine",
        "DeiT-III": "deit3_base_fine",
    }

    os.environ['TORCH_HOME'] = '/tmp/.cache'
    os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'

    # Pre-load all models
    original_models, fine_models = load_model(options)

    def fit3d(image, model_option, kmeans_num):
        # Select model
        original_model = original_models[model_option]
        fine_model = fine_models[model_option]

        # Data preprocessing
        p = original_model.patch_embed.patch_size
        stride = p if isinstance(p, int) else p[0]
        image_resized = process_image(image, stride, transforms)

        with torch.no_grad():
            ori_feats = original_model.get_intermediate_layers(
                image_resized, n=[8, 9, 10, 11], reshape=True,
                return_prefix_tokens=False, return_class_token=False, norm=True,
            )
            fine_feats = fine_model.get_intermediate_layers(
                image_resized, n=[8, 9, 10, 11], reshape=True,
                return_prefix_tokens=False, return_class_token=False, norm=True,
            )

        # Keep only the last requested block.
        ori_feats = ori_feats[-1]
        fine_feats = fine_feats[-1]

        ori_labels = kmeans_clustering(ori_feats, kmeans_num)
        fine_labels = kmeans_clustering(fine_feats, kmeans_num)
        return plot_feats(model_option, ori_feats, fine_feats, ori_labels, fine_labels)
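
    # Illustrative sketch (not wired into the demo): `fit3d` can also be driven
    # programmatically once the models are pre-loaded. The example file name and
    # output path below are assumptions for illustration.
    def _example_offline_run():
        image = Image.open(os.path.join(example_dir, "library.jpg"))
        fig = fit3d(image, "DINOv2", 20)
        fig.savefig("/tmp/fit3d_library.png", dpi=150)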
    demo = gr.Interface(
        title="FiT3D: Improving 2D Feature Representations by 3D-Aware Fine-Tuning (ECCV 2024)",
        description="[Project Page](https://github.com/ywyue/FiT3D)",
        fn=fit3d,
        inputs=[image_input, model_option, kmeans_num],
        outputs="plot",
        examples=[
            ["/tmp/examples/library.jpg", "DINOv2", 20],
            ["/tmp/examples/livingroom.jpg", "DINOv2", 20],
            ["/tmp/examples/airplane.jpg", "DINOv2", 20],
            ["/tmp/examples/ship.jpg", "DINOv2", 20],
            ["/tmp/examples/chair.jpg", "DINOv2", 20],
        ],
    )
    demo.launch()