# Spaces: Running on Zero
import gradio as gr | |
import spaces | |
import torch | |
import torchvision.transforms | |
import numpy as np | |
from transformers import AutoModel | |
from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything | |
def load_description(fp):
    """Read and return the entire contents of the UTF-8 text file at ``fp``.

    The demo page uses this to pull its HTML/Markdown header from disk.
    """
    with open(fp, encoding="utf-8") as handle:
        return handle.read()
# Teacher models whose feature statistics are loaded for decoding; the order is
# passed straight through to ``load_feature_stats``. Kept as a tuple so it
# cannot be mutated between requests (a fresh list is handed to the callee).
_TARGET_MODEL_NAMES = (
    "google/vit-huge-patch14-224-in21k",
    "facebook/dinov2-large",
    "openai/clip-vit-large-patch14",
    "facebook/sam-vit-huge",
    "LiheYoung/depth-anything-large-hf",
)
_DEPTH_ANYTHING_MODEL_NAME = "LiheYoung/depth-anything-large-hf"


def _split_panels(decoded):
    """Convert one decoded visualization strip to uint8 and cut out its panels.

    Returns ``(dino, sam, depth)``: the second, third and fourth quarters of
    the strip along axis 1. The first quarter is discarded. ``decoded`` must
    be 3-D. It is presumably (height, width, channels) with float values in
    [0, 1]. TODO: confirm against ``theia.decoding.decode_everything``.
    """
    _, width, _ = decoded.shape
    # Clip before casting. A float -> uint8 cast of out-of-range values wraps
    # around (or is undefined) and would show up as speckle artifacts.
    strip = np.clip(255.0 * decoded, 0.0, 255.0).astype(np.uint8)
    dino = strip[:, width // 4 : 2 * width // 4, :]
    sam = strip[:, 2 * width // 4 : 3 * width // 4, :]
    depth = strip[:, 3 * width // 4 :, :]
    return dino, sam, depth


# On ZeroGPU Spaces a CUDA device is only attached while a function decorated
# with ``spaces.GPU`` is running. NOTE(review): every model is loaded on each
# call. If requests exceed the default GPU time budget, raise it with
# ``@spaces.GPU(duration=...)``. TODO: measure.
@spaces.GPU
def run_theia(model_size, image, pred_iou_thresh, stability_score_thresh):
    """Decode Theia features for one image into DINOv2, SAM and Depth-Anything outputs.

    Args:
        model_size: Theia variant substituted into the checkpoint id
            ``theaiinstitute/theia-{model_size}-patch16-224-cddsv``. The UI
            offers "tiny", "small" and "base".
        image: Image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user has not uploaded one.
        pred_iou_thresh: Forwarded to ``decode_everything`` for SAM mask
            generation.
        stability_score_thresh: Forwarded to ``decode_everything`` for SAM
            mask generation.

    Returns:
        Three lists, one per output gallery (DINOv2, SAM, Depth-Anything).
        Each list holds ``(theia_panel, "Theia")`` and
        ``(ground_truth_panel, "Ground Truth")`` pairs of uint8 arrays.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        # Without this guard, torchvision's Resize below fails on ``None``
        # with an error message that means nothing to the user.
        raise gr.Error("Please upload an input image before clicking Submit.")

    device = "cuda"
    theia_model = AutoModel.from_pretrained(
        f"theaiinstitute/theia-{model_size}-patch16-224-cddsv",
        trust_remote_code=True,
    ).to(device)
    feature_means, feature_vars = load_feature_stats(
        list(_TARGET_MODEL_NAMES), stat_file_root="feature_stats"
    )
    mask_generator, sam_model = prepare_mask_generator(device)
    depth_anything_decoder, _ = prepare_depth_decoder(_DEPTH_ANYTHING_MODEL_NAME, device)

    # Theia checkpoints are "patch16-224", so the input is resized to 224x224.
    image = torchvision.transforms.Resize(size=(224, 224))(image)
    theia_decode_results, gt_results = decode_everything(
        theia_model=theia_model,
        feature_means=feature_means,
        feature_vars=feature_vars,
        images=[image],
        mask_generator=mask_generator,
        sam_model=sam_model,
        depth_anything_decoder=depth_anything_decoder,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        gt=True,
        device=device,
    )

    # Only one image was submitted, so only element 0 of each result is used.
    theia_dino, theia_sam, theia_depth = _split_panels(theia_decode_results[0])
    gt_dino, gt_sam, gt_depth = _split_panels(gt_results[0])

    dinov2_output = [(theia_dino, "Theia"), (gt_dino, "Ground Truth")]
    sam_output = [(theia_sam, "Theia"), (gt_sam, "Ground Truth")]
    depth_anything_output = [(theia_depth, "Theia"), (gt_depth, "Ground Truth")]
    return dinov2_output, sam_output, depth_anything_output
# Introductory blurb shown under the page title. It is a named constant so the
# layout code below stays readable.
_INTRO_MARKDOWN = (
    "This space demonstrates decoding Theia-predicted VFM representations to "
    "their original teacher model outputs. For DINOv2 we apply the PCA "
    "visualization, for SAM we use its decoder to generate segmentation masks "
    "(but with SAM's pipeline of prompting), and for Depth-Anything we use its "
    "decoder head to do depth prediction."
)

# Page layout:
# - left column: input image, collapsible advanced settings, submit button;
# - right column: one output gallery per teacher model.
with gr.Blocks() as demo:
    gr.HTML(load_description("gradio_title.md"))
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            with gr.Accordion("Advanced Settings", open=False):
                model_size = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="base",
                    label="Theia Model Size",
                )
                pred_iou_thresh = gr.Slider(
                    minimum=0.05,
                    maximum=0.95,
                    value=0.65,
                    step=0.05,
                    label="SAM Pred IoU Thresh",
                )
                stability_score_thresh = gr.Slider(
                    minimum=0.05,
                    maximum=0.95,
                    value=0.85,
                    step=0.05,
                    label="SAM Stability Score Thresh",
                )
            submit_button = gr.Button("Submit")

        with gr.Column():
            dinov2_output = gr.Gallery(type="numpy", label="DINOv2")
            sam_output = gr.Gallery(type="numpy", label="SAM")
            depth_anything_output = gr.Gallery(type="numpy", label="Depth-Anything")

    # The input order must match run_theia's positional parameters. The output
    # order matches the tuple it returns.
    submit_button.click(
        fn=run_theia,
        inputs=[model_size, input_image, pred_iou_thresh, stability_score_thresh],
        outputs=[dinov2_output, sam_output, depth_anything_output],
    )

# Enable request queuing, then start the web server.
demo.queue().launch()