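# Gradio demo for CLIP-DINOiser: open-vocabulary semantic segmentation driven by
# comma-separated text prompts.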
from models.builder import build_model
from visualization import mask2rgb
from segmentation.datasets import PascalVOCDataset
import os
from hydra import compose, initialize
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms as T
import torch.nn.functional as F
import numpy as np
from operator import itemgetter
import torch
import random
import warnings
warnings.filterwarnings("ignore")
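# Initialise Hydra once at import time so the model config can be composed below.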
initialize(config_path="configs", version_base=None)
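# Clone the checkpoint repository locally; the optional `token` environment
# variable is passed as the Hugging Face access token.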
from huggingface_hub import Repository
repo = Repository(
    local_dir="clip-dinoiser",
    clone_from="ariG23498/clip-dinoiser",
    use_auth_token=os.environ.get("token"),
)
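# Load the CLIP-DINOiser checkpoint and build the model with the Pascal VOC
# class names as the initial vocabulary; the vocabulary is replaced per request.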
check_path = 'clip-dinoiser/checkpoints/last.pt'
device = "cuda" if torch.cuda.is_available() else "cpu"
check = torch.load(check_path, map_location=device)
dinoclip_cfg = "clip_dinoiser.yaml"
cfg = compose(config_name=dinoclip_cfg)
model = build_model(cfg.model, class_names=PascalVOCDataset.CLASSES).to(device)
model.clip_backbone.decode_head.use_templates = False  # switch off the ImageNet templates for faster inference
model.load_state_dict(check['model_state_dict'], strict=False)
model = model.eval()
import gradio as gr
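# Palette of per-prompt colours; color_map reverses the channel order to build
# hex RGB strings (keyed by prompt index) for the HighlightedText legend.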
colors = [
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (114, 128, 250),
    (0, 165, 255),
    (0, 128, 0),
    (144, 238, 144),
    (238, 238, 175),
    (255, 191, 0),
    (0, 128, 0),
    (226, 43, 138),
    (255, 0, 255),
    (0, 215, 255),
    (255, 0, 0),
]
color_map = {
    f"{color_id}": f"#{color[2]:02x}{color[1]:02x}{color[0]:02x}"
    for color_id, color in enumerate(colors)
}
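# Inference entry point: segments the image with the user's prompts and returns
# (blended overlay, raw colour mask, per-character colour spans for the legend).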
def run_clip_dinoiser(input_image, text_prompts):
    image = input_image.convert("RGB")
    # Split the comma-separated prompts and strip stray whitespace.
    text_prompts = [prompt.strip() for prompt in text_prompts.split(",")]
    palette = colors[:len(text_prompts)]

    # Swap the open-vocabulary head's class names for the user-supplied prompts.
    model.clip_backbone.decode_head.update_vocab(text_prompts)
    model.to(device)
    model.apply_found = True

    img_tens = T.PILToTensor()(image).unsqueeze(0).to(device) / 255.
    h, w = img_tens.shape[-2:]

    # Run inference without tracking gradients, then upsample the patch-level
    # predictions back to the input resolution.
    with torch.no_grad():
        output = model(img_tens).cpu()
    output = F.interpolate(
        output,
        scale_factor=model.clip_backbone.backbone.patch_size,
        mode="bilinear",
        align_corners=False,
    )[..., :h, :w]
    output = output[0].argmax(dim=0)
    mask = mask2rgb(output, palette)

    # Blend the colour mask with the input image for the overlay view.
    alpha = 0.5
    blend = alpha * np.array(image) / 255. + (1 - alpha) * mask / 255.

    # Colour every character of each prompt with that prompt's palette colour
    # so the HighlightedText component doubles as a legend.
    h_text = []
    for idx, text in enumerate(text_prompts):
        for alphabet in text:
            h_text.append((alphabet, color_map[str(idx)]))
    return blend, mask, h_text
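# Gradio UI: image and prompt inputs on the left; overlay, mask, and the
# coloured prompt legend on the right.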
if __name__ == "__main__":
    block = gr.Blocks().queue()
    with block:
        gr.Markdown("<h1><center>CLIP-DINOiser</center></h1>")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil")
                text_prompts = gr.Textbox(label="Enter comma-separated prompts")
                run_button = gr.Button(value="Run")
            with gr.Column():
                with gr.Row():
                    overlay_mask = gr.Image(type="numpy")
                    only_mask = gr.Image(type="numpy")
                h_text = gr.HighlightedText(
                    label="text",
                    combine_adjacent=False,
                    show_legend=False,
                    color_map=color_map,
                )
        run_button.click(
            fn=run_clip_dinoiser,
            inputs=[input_image, text_prompts],
            outputs=[overlay_mask, only_mask, h_text],
        )
        gr.Examples(
            [["vintage_bike.jpeg", "background, vintage bike, leather bag"]],
            inputs=[input_image, text_prompts],
            outputs=[overlay_mask, only_mask, h_text],
            fn=run_clip_dinoiser,
            cache_examples=True,
            label="Try this example input!",
        )
    block.launch(share=False, show_api=False, show_error=True)