import gradio as gr
import spaces
import torch
import torchvision.transforms
import numpy as np
from transformers import AutoModel
from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything


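# Read the demo title block (gradio_title.md) from disk.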
def load_description(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        content = f.read()
    return content


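# On ZeroGPU Spaces, @spaces.GPU allocates a GPU for up to `duration`
# seconds per call.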
@spaces.GPU(duration=90)
def run_theia(model_size, image, pred_iou_thresh, stability_score_thresh):
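    """Decode Theia-predicted features for one image into teacher outputs.

    Returns three lists of (image, caption) pairs: Theia's decoded DINOv2,
    SAM, and Depth-Anything outputs, each next to the ground-truth output.
    """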
    theia_model = AutoModel.from_pretrained(f"theaiinstitute/theia-{model_size}-patch16-224-cddsv", trust_remote_code=True)
    theia_model = theia_model.to('cuda')
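    # Teacher models whose representations Theia was distilled from.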
    target_model_names = [
        "google/vit-huge-patch14-224-in21k",
        "facebook/dinov2-large",
        "openai/clip-vit-large-patch14",
        "facebook/sam-vit-huge",
        "LiheYoung/depth-anything-large-hf",
    ]
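    # Feature means/variances for each teacher, used during decoding.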
    feature_means, feature_vars = load_feature_stats(target_model_names, stat_file_root="feature_stats")

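    # Decoders that turn features back into pixels: SAM's mask generator
    # and Depth-Anything's depth decoder head, both on the GPU.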
    mask_generator, sam_model = prepare_mask_generator('cuda')
    depth_anything_model_name = "LiheYoung/depth-anything-large-hf"
    depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, 'cuda')

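    # The patch16-224 checkpoints expect 224x224 inputs.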
    image = torchvision.transforms.Resize(size=(224, 224))(image)
    images = [image]

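    # Decode Theia's predicted features; with gt=True, also compute the
    # ground-truth teacher outputs for side-by-side comparison.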
    theia_decode_results, gt_results = decode_everything(
        theia_model=theia_model,
        feature_means=feature_means,
        feature_vars=feature_vars,
        images=images,
        mask_generator=mask_generator,
        sam_model=sam_model,
        depth_anything_decoder=depth_anything_decoder,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        gt=True,
        device='cuda',
    )

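    # Each decoded result concatenates four equal-width panels horizontally:
    # the input image, then DINOv2, SAM, and Depth-Anything reconstructions.
    # Skip the first panel and slice out the remaining three.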
    _, width, _ = theia_decode_results[0].shape
    theia_decode_results = (255.0 * theia_decode_results[0]).astype(np.uint8)
    theia_decode_dino = theia_decode_results[:, width // 4 : 2 * width // 4, :]
    theia_decode_sam = theia_decode_results[:, 2 * width // 4 : 3 * width // 4, :]
    theia_decode_depth = theia_decode_results[:, 3 * width // 4 :, :]

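    # Ground-truth outputs share the same four-panel layout.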
    gt_results = (255.0 * gt_results[0]).astype(np.uint8)
    gt_dino = gt_results[:, width // 4 : 2 * width // 4, :]
    gt_sam = gt_results[:, 2 * width // 4 : 3 * width // 4, :]
    gt_depth = gt_results[:, 3 * width // 4 :, :]

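    # Pair each Theia reconstruction with its ground-truth counterpart.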
    dinov2_output = [(theia_decode_dino, "Theia"), (gt_dino, "Ground Truth")]
    sam_output = [(theia_decode_sam, "Theia"), (gt_sam, "Ground Truth")]
    depth_anything_output = [(theia_decode_depth, "Theia"), (gt_depth, "Ground Truth")]
    return dinov2_output, sam_output, depth_anything_output

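# Gradio UI: image input and SAM thresholds on the left, one gallery per
# decoded teacher on the right.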
with gr.Blocks() as demo:
    gr.HTML(load_description("gradio_title.md"))
    gr.Markdown("This space demonstrates decoding Theia-predicted VFM representations to their original teacher model outputs. For DINOv2 we apply the PCA visualization, for SAM we use its decoder to generate segmentation masks (but with SAM's pipeline of prompting), and for Depth-Anything we use its decoder head to do depth prediction.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")

            with gr.Accordion("Advanced Settings", open=False):
                model_size = gr.Radio(["tiny", "small", "base"], value="base", label="Theia Model Size")
                pred_iou_thresh = gr.Slider(0.05, 0.95, step=0.05, value=0.65, label="SAM Pred IoU Thresh")
                stability_score_thresh = gr.Slider(0.05, 0.95, step=0.05, value=0.85, label="SAM Stability Score Thresh")

            submit_button = gr.Button("Submit")

        with gr.Column():
            dinov2_output = gr.Gallery(label="DINOv2", type="numpy")
            sam_output = gr.Gallery(label="SAM", type="numpy")
            depth_anything_output = gr.Gallery(label="Depth-Anything", type="numpy")

    submit_button.click(
        run_theia,
        inputs=[model_size, input_image, pred_iou_thresh, stability_score_thresh],
        outputs=[dinov2_output, sam_output, depth_anything_output]
    )

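# Queue incoming requests and launch the app.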
demo.queue()
demo.launch()
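
# A minimal sketch of calling run_theia directly, outside the UI (assumes a
# CUDA GPU, the `theia` package, and a local image; the file name below is
# hypothetical):
#
#   from PIL import Image
#   img = Image.open("example.jpg").convert("RGB")
#   dino, sam, depth = run_theia("base", img, 0.65, 0.85)
#   # each output: [(theia_image, "Theia"), (gt_image, "Ground Truth")]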