import os
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import Repository
from hydra import compose, initialize
from PIL import Image
from torchvision import transforms as T

from models.builder import build_model
from segmentation.datasets import PascalVOCDataset
from visualization import mask2rgb

# Suppress warnings
warnings.filterwarnings("ignore")

# Constants
CHECKPOINT_PATH = "clip-dinoiser/checkpoints/last.pt"
CONFIG_PATH = "configs"
DINOCLIP_CONFIG = "clip_dinoiser.yaml"
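# Palette for the segmentation masks; only the first len(text_prompts) entries
# are used for a given query.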
COLORS = [
    (0, 255, 0),
    (255, 0, 0),
    (0, 255, 255),
    (255, 0, 255),
    (255, 255, 0),
    (250, 128, 114),
    (255, 165, 0),
    (0, 128, 0),
    (144, 238, 144),
    (175, 238, 238),
    (0, 191, 255),
    (0, 128, 0),
    (138, 43, 226),
    (255, 0, 255),
    (255, 215, 0),
    (0, 0, 255),
]

# Initialize Hydra
initialize(config_path=CONFIG_PATH, version_base=None)


# Configuration and Model Initialization
def load_model():
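    """Clone the CLIP-DINOiser checkpoint repo, load the weights, and build
    the model from the Hydra config, returning it in eval mode."""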
    Repository(
        local_dir="clip-dinoiser",
        clone_from="ariG23498/clip-dinoiser",
        use_auth_token=os.environ.get("token"),
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    cfg = compose(config_name=DINOCLIP_CONFIG)

    model = build_model(cfg.model, class_names=PascalVOCDataset.CLASSES).to(device)
    model.clip_backbone.decode_head.use_templates = False
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    return model.eval()


def run_clip_dinoiser(input_image, text_prompts, model, device, colors):
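    """Segment `input_image` using the comma-separated `text_prompts`.

    Returns the image/mask blend, the color segmentation mask, and
    (label, index) pairs for the HighlightedText legend."""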
    # Resize the input image
    image = input_image.resize((350, 350))

    image = image.convert("RGB")
    text_prompts = [prompt.strip() for prompt in text_prompts.split(",")]
    palette = colors[: len(text_prompts)]

    model.clip_backbone.decode_head.update_vocab(text_prompts)
    model.to(device)

    img_tens = T.PILToTensor()(image).unsqueeze(0).to(device) / 255.0
    h, w = img_tens.shape[-2:]
    output = model(img_tens).cpu()
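    # The model outputs patch-level logits; upsample by the backbone's patch
    # size and crop to the original (h, w) before taking the per-pixel argmax.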
    output = F.interpolate(
        output,
        scale_factor=model.clip_backbone.backbone.patch_size,
        mode="bilinear",
        align_corners=False,
    )[..., :h, :w]
    output = output[0].argmax(dim=0)

    mask = mask2rgb(output, palette)
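    # Alpha-blend the colorized mask over the resized input image.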
    alpha = 0.5
    blend = (alpha * np.array(image) / 255.0) + ((1 - alpha) * mask / 255.0)

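    # Pair each prompt with its palette index so the legend colors match the mask.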
    h_text = [(text, f"{idx}") for idx, text in enumerate(text_prompts)]
    return blend, mask, h_text


def create_color_map(colors):
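    """Map each palette index to a hex color string (e.g. "#00ff00") for the
    Gradio HighlightedText color_map."""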
    return {
        f"{color_id}": f"#{color[0]:02x}{color[1]:02x}{color[2]:02x}"
        for color_id, color in enumerate(colors)
    }


def setup_gradio_interface(model, device, colors, color_map):
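    """Build the Gradio Blocks UI: image and prompt inputs, overlay and mask
    outputs, a color-coded label legend, and a cached example."""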
    block = gr.Blocks()

    with block:
        gr.Markdown("<h1><center>CLIP-DINOiser</center></h1>")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                text_prompts = gr.Textbox(label="Enter comma-separated prompts")
                run_button = gr.Button(value="Run")

            with gr.Column():
                with gr.Row():
                    overlay_mask = gr.Image(type="numpy", label="Overlay Mask")
                    only_mask = gr.Image(type="numpy", label="Segmentation Mask")
                h_text = gr.HighlightedText(
                    label="Labels",
                    combine_adjacent=False,
                    show_legend=False,
                    color_map=color_map,
                )

        run_button.click(
            fn=lambda img, prompts: run_clip_dinoiser(
                img, prompts, model, device, colors
            ),
            inputs=[input_image, text_prompts],
            outputs=[overlay_mask, only_mask, h_text],
        )

        gr.Examples(
            examples=[["vintage_bike.jpeg", "background, vintage bike, leather bag"]],
            inputs=[input_image, text_prompts],
            outputs=[overlay_mask, only_mask, h_text],
            fn=lambda img, prompts: run_clip_dinoiser(
                img, prompts, model, device, colors
            ),
            cache_examples=True,
            label="Try this example input!",
        )

    return block


if __name__ == "__main__":
    model = load_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    color_map = create_color_map(COLORS)
    gradio_interface = setup_gradio_interface(model, device, COLORS, color_map)
    gradio_interface.launch(share=False, show_api=False, show_error=True)