import os
import warnings
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import Repository
from hydra import compose, initialize
from PIL import Image
from torchvision import transforms as T
from models.builder import build_model
from segmentation.datasets import PascalVOCDataset
from visualization import mask2rgb
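# NOTE: models.builder, segmentation.datasets, and visualization are local
# modules assumed to ship with the CLIP-DINOiser codebase alongside this script.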
# Suppress warnings
warnings.filterwarnings("ignore")
# Constants
CHECKPOINT_PATH = "clip-dinoiser/checkpoints/last.pt"
CONFIG_PATH = "configs"
DINOCLIP_CONFIG = "clip_dinoiser.yaml"
COLORS = [
    (0, 255, 0),
    (255, 0, 0),
    (0, 255, 255),
    (255, 0, 255),
    (255, 255, 0),
    (250, 128, 114),
    (255, 165, 0),
    (0, 128, 0),
    (144, 238, 144),
    (175, 238, 238),
    (0, 191, 255),
    (0, 128, 0),
    (138, 43, 226),
    (255, 0, 255),
    (255, 215, 0),
    (0, 0, 255),
]
# Initialize Hydra
initialize(config_path=CONFIG_PATH, version_base=None)
# Configuration and Model Initialization
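# Clone the checkpoint repository from the Hub, restore the weights on the
# available device, and return the model in eval mode.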
def load_model():
    Repository(
        local_dir="clip-dinoiser",
        clone_from="ariG23498/clip-dinoiser",
        use_auth_token=os.environ.get("token"),
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    cfg = compose(config_name=DINOCLIP_CONFIG)
    model = build_model(cfg.model, class_names=PascalVOCDataset.CLASSES).to(device)
    model.clip_backbone.decode_head.use_templates = False
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    return model.eval()
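# Run open-vocabulary segmentation: update the model vocabulary with the
# comma-separated prompts, predict per-pixel classes, and return the
# image/mask blend, the color mask, and (label, id) pairs for the legend.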
def run_clip_dinoiser(input_image, text_prompts, model, device, colors):
    # Resize the input image
    image = input_image.resize((350, 350))
    image = image.convert("RGB")
    text_prompts = text_prompts.split(",")
    palette = colors[: len(text_prompts)]
    model.clip_backbone.decode_head.update_vocab(text_prompts)
    model.to(device)
    img_tens = T.PILToTensor()(image).unsqueeze(0).to(device) / 255.0
    h, w = img_tens.shape[-2:]
    output = model(img_tens).cpu()
    # Upsample the patch-level logits by the ViT patch size and crop back to
    # the input resolution before taking the per-pixel argmax.
    output = F.interpolate(
        output,
        scale_factor=model.clip_backbone.backbone.patch_size,
        mode="bilinear",
        align_corners=False,
    )[..., :h, :w]
    output = output[0].argmax(dim=0)
    mask = mask2rgb(output, palette)
    # Blend the RGB image and the color mask 50/50 for the overlay view.
    alpha = 0.5
    blend = (alpha * np.array(image) / 255.0) + ((1 - alpha) * mask / 255.0)
    h_text = [(text, f"{idx}") for idx, text in enumerate(text_prompts)]
    return blend, mask, h_text
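# Map each palette index to a hex color string (e.g. (0, 255, 0) -> "#00ff00")
# so gr.HighlightedText can color the labels consistently with the mask.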
def create_color_map(colors):
    return {
        f"{color_id}": f"#{color[0]:02x}{color[1]:02x}{color[2]:02x}"
        for color_id, color in enumerate(colors)
    }
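# Build the Gradio Blocks UI: an image input and prompt textbox on the left,
# the overlay/segmentation outputs and label legend on the right, plus a
# cached example wired to run_clip_dinoiser.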
def setup_gradio_interface(model, device, colors, color_map):
    block = gr.Blocks()
    with block:
        gr.Markdown("<h1><center>CLIP-DINOiser</center></h1>")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                text_prompts = gr.Textbox(label="Enter comma-separated prompts")
                run_button = gr.Button(value="Run")
            with gr.Column():
                with gr.Row():
                    overlay_mask = gr.Image(type="numpy", label="Overlay Mask")
                    only_mask = gr.Image(type="numpy", label="Segmentation Mask")
                h_text = gr.HighlightedText(
                    label="Labels",
                    combine_adjacent=False,
                    show_legend=False,
                    color_map=color_map,
                )
        run_button.click(
            fn=lambda img, prompts: run_clip_dinoiser(
                img, prompts, model, device, colors
            ),
            inputs=[input_image, text_prompts],
            outputs=[overlay_mask, only_mask, h_text],
        )
        gr.Examples(
            examples=[["vintage_bike.jpeg", "background, vintage bike, leather bag"]],
            inputs=[input_image, text_prompts],
            outputs=[overlay_mask, only_mask, h_text],
            fn=lambda img, prompts: run_clip_dinoiser(
                img, prompts, model, device, colors
            ),
            cache_examples=True,
            label="Try this example input!",
        )
    return block
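# Entry point: load the model once at startup, then launch the demo without
# a public share link.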
if __name__ == "__main__":
    model = load_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    color_map = create_color_map(COLORS)
    gradio_interface = setup_gradio_interface(model, device, COLORS, color_map)
    gradio_interface.launch(share=False, show_api=False, show_error=True)