# FiT3D / app.py
# Gradio demo for "FiT3D: Improving 2D Feature Representations by 3D-Aware Fine-Tuning" (ECCV 2024).
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import spaces
import timm
import torch
import types
import albumentations as A
from PIL import Image
from tqdm import tqdm
from sklearn.decomposition import PCA
from torch_kmeans import KMeans, CosineSimilarity
cmap = plt.get_cmap("tab20")
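# ImageNet mean/std, given on a 0-255 scale and rescaled to [0, 1] for albumentations.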
MEAN = np.array([123.675, 116.280, 103.530]) / 255
STD = np.array([58.395, 57.120, 57.375]) / 255
transforms = A.Compose([
A.Normalize(mean=list(MEAN), std=list(STD)),
])
def get_intermediate_layers(
self,
x: torch.Tensor,
n=1,
reshape: bool = False,
return_prefix_tokens: bool = False,
return_class_token: bool = False,
norm: bool = True,
):
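    """DINOv2-style replacement for timm's get_intermediate_layers: returns
    features from the selected blocks, optionally reshaped to (B, C, H', W')
    maps and optionally paired with their class/prefix tokens. It is bound to
    each model below via types.MethodType."""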
outputs = self._intermediate_layers(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
if return_class_token:
prefix_tokens = [out[:, 0] for out in outputs]
else:
prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
if reshape:
B, C, H, W = x.shape
grid_size = (
(H - self.patch_embed.patch_size[0])
// self.patch_embed.proj.stride[0]
+ 1,
(W - self.patch_embed.patch_size[1])
// self.patch_embed.proj.stride[1]
+ 1,
)
outputs = [
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
.permute(0, 3, 1, 2)
.contiguous()
for out in outputs
]
if return_prefix_tokens or return_class_token:
return tuple(zip(outputs, prefix_tokens))
return tuple(outputs)
def viz_feat(feat):
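    """Project a (1, C, H, W) feature map to three channels with PCA and
    render it as an RGB image."""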
    _, _, h, w = feat.shape
    feat = feat.squeeze(0).permute(1, 2, 0)
projected_featmap = feat.reshape(-1, feat.shape[-1]).cpu()
pca = PCA(n_components=3)
pca.fit(projected_featmap)
pca_features = pca.transform(projected_featmap)
pca_features = (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min())
pca_features = pca_features * 255
res_pred = Image.fromarray(pca_features.reshape(h, w, 3).astype(np.uint8))
return res_pred
def plot_feats(model_option, ori_feats, fine_feats, ori_labels=None, fine_labels=None):
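    """Show PCA feature maps (top row) and K-Means cluster maps (bottom row)
    for the original and the fine-tuned model side by side."""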
ori_feats_map = viz_feat(ori_feats)
fine_feats_map = viz_feat(fine_feats)
fig, ax = plt.subplots(2, 2, figsize=(6, 5))
ax[0][0].imshow(ori_feats_map)
ax[0][0].set_title("Original " + model_option, fontsize=15)
ax[0][1].imshow(fine_feats_map)
ax[0][1].set_title("Fine-tuned", fontsize=15)
ax[1][0].imshow(ori_labels)
ax[1][1].imshow(fine_labels)
    for a in ax.flat:
        a.set_xticks([])
        a.set_yticks([])
        a.axis('off')
plt.tight_layout()
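    # Close the figure so pyplot does not render it; Gradio displays the returned object.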
plt.close(fig)
return fig
def download_image(url, save_path):
    response = requests.get(url, timeout=30)
    response.raise_for_status()  # fail loudly instead of silently saving an error page
    with open(save_path, 'wb') as file:
        file.write(response.content)
def process_image(image, stride, transforms):
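    """Normalize the image and bilinearly resize it so that height and width
    are integer multiples of the patch stride."""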
transformed = transforms(image=np.array(image))
image_tensor = torch.tensor(transformed['image'])
    image_tensor = image_tensor.permute(2, 0, 1)
image_tensor = image_tensor.unsqueeze(0).to(device)
h, w = image_tensor.shape[2:]
    height_int = (h // stride) * stride
    width_int = (w // stride) * stride
image_resized = torch.nn.functional.interpolate(image_tensor, size=(height_int, width_int), mode='bilinear')
return image_resized
def kmeans_clustering(feats_map, n_clusters=20):
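    """Cluster a (B, D, h, w) feature map with cosine-similarity K-Means and
    colorize the resulting label map with the tab20 colormap."""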
    if n_clusters is None:
        n_clusters = 20
    print('num clusters:', n_clusters)
B, D, h, w = feats_map.shape
    feats_map_flattened = feats_map.permute(0, 2, 3, 1).reshape(B, -1, D)
    kmeans_engine = KMeans(n_clusters=n_clusters, distance=CosineSimilarity)
    kmeans_engine.fit(feats_map_flattened)
    labels = kmeans_engine.predict(feats_map_flattened)
    labels = labels.reshape(B, h, w).float()
labels = labels[0].cpu().numpy()
label_map = cmap(labels / n_clusters)[..., :3]
label_map = np.uint8(label_map * 255)
label_map = Image.fromarray(label_map)
return label_map
def load_model(options):
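    """Load each timm baseline and its FiT3D fine-tuned counterpart (fetched
    via torch.hub from ywyue/FiT3D), and bind the custom
    get_intermediate_layers to every model."""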
original_models = {}
fine_models = {}
for option in tqdm(options):
        print(f'Loading weights for {option}, please wait ...')
original_models[option] = timm.create_model(
timm_model_card[option],
pretrained=True,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=False,
).to(device)
original_models[option].get_intermediate_layers = types.MethodType(
get_intermediate_layers,
original_models[option]
)
fine_models[option] = torch.hub.load("ywyue/FiT3D", our_model_card[option]).to(device)
fine_models[option].get_intermediate_layers = types.MethodType(
get_intermediate_layers,
fine_models[option]
)
print('Done! Now play the demo :)')
return original_models, fine_models
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')
example_urls = {
"library.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/library.jpg",
"livingroom.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/livingroom.jpg",
"airplane.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/airplane.jpg",
"ship.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/ship.jpg",
"chair.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/chair.jpg",
}
example_dir = "/tmp/examples"
os.makedirs(example_dir, exist_ok=True)
for name, url in example_urls.items():
save_path = os.path.join(example_dir, name)
if not os.path.exists(save_path):
print(f"Downloading to {save_path}...")
download_image(url, save_path)
else:
print(f"{save_path} already exists.")
image_input = gr.Image(label="Choose an image:",
height=500,
type="pil",
image_mode='RGB',
sources=['upload', 'webcam', 'clipboard']
)
options = ['DINOv2', 'DINOv2-reg', 'CLIP', 'MAE', 'DeiT-III']
model_option = gr.Radio(options, value="DINOv2", label='Choose a 2D foundation model')
kmeans_num = gr.Number(
label="Number of K-Means clusters", value=20
)
timm_model_card = {
"DINOv2": "vit_small_patch14_dinov2.lvd142m",
"DINOv2-reg": "vit_small_patch14_reg4_dinov2.lvd142m",
"CLIP": "vit_base_patch16_clip_384.laion2b_ft_in12k_in1k",
"MAE": "vit_base_patch16_224.mae",
"DeiT-III": "deit3_base_patch16_224.fb_in1k"
}
our_model_card = {
"DINOv2": "dinov2_small_fine",
"DINOv2-reg": "dinov2_reg_small_fine",
"CLIP": "clip_base_fine",
"MAE": "mae_base_fine",
"DeiT-III": "deit3_base_fine"
}
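    # Cache torch.hub checkpoints under /tmp, a writable location on Hugging Face Spaces.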
os.environ['TORCH_HOME'] = '/tmp/.cache'
# os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
# Pre-load all models
original_models, fine_models = load_model(options)
@spaces.GPU
def fit3d(image, model_option, kmeans_num):
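        """Extract features from blocks 8-11 of both the original and the
        FiT3D fine-tuned model, then visualize the last block with PCA and
        K-Means clustering."""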
# Select model
original_model = original_models[model_option]
fine_model = fine_models[model_option]
# Data preprocessing
p = original_model.patch_embed.patch_size
stride = p if isinstance(p, int) else p[0]
image_resized = process_image(image, stride, transforms)
        with torch.no_grad():
            ori_feats = original_model.get_intermediate_layers(
                image_resized, n=[8, 9, 10, 11], reshape=True,
                return_prefix_tokens=False, return_class_token=False, norm=True)
            fine_feats = fine_model.get_intermediate_layers(
                image_resized, n=[8, 9, 10, 11], reshape=True,
                return_prefix_tokens=False, return_class_token=False, norm=True)
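        # Keep only the features from the last requested block (block 11).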
ori_feats = ori_feats[-1]
fine_feats = fine_feats[-1]
ori_labels = kmeans_clustering(ori_feats, kmeans_num)
fine_labels = kmeans_clustering(fine_feats, kmeans_num)
return plot_feats(model_option, ori_feats, fine_feats, ori_labels, fine_labels)
demo = gr.Interface(
title="<div> \
<h1>FiT3D</h1> \
<h2>Improving 2D Feature Representations by 3D-Aware Fine-Tuning</h2> \
<h2>ECCV 2024</h2> \
</div>",
description="<div style='display: flex; justify-content: center; align-items: center; text-align: center;'> \
<a href='https://arxiv.org/abs/2407.20229'><img src='https://img.shields.io/badge/arXiv-2407.20229-red'></a> \
&nbsp; \
<a href='https://ywyue.github.io/FiT3D'><img src='https://img.shields.io/badge/Project_Page-FiT3D-green' alt='Project Page'></a> \
&nbsp; \
<a href='https://github.com/ywyue/FiT3D'><img src='https://img.shields.io/badge/Github-Code-blue'></a> \
</div>",
fn=fit3d,
inputs=[image_input, model_option, kmeans_num],
outputs="plot",
examples=[
["/tmp/examples/library.jpg", "DINOv2", 20],
["/tmp/examples/livingroom.jpg", "DINOv2", 20],
["/tmp/examples/airplane.jpg", "DINOv2", 20],
["/tmp/examples/ship.jpg", "DINOv2", 20],
["/tmp/examples/chair.jpg", "DINOv2", 20],
],
cache_examples=True)
demo.launch()