import os import warnings import gradio as gr import numpy as np import torch import torch.nn.functional as F from huggingface_hub import Repository from hydra import compose, initialize from PIL import Image from torchvision import transforms as T from models.builder import build_model from segmentation.datasets import PascalVOCDataset from visualization import mask2rgb # Suppress warnings warnings.filterwarnings("ignore") # Constants CHECKPOINT_PATH = "clip-dinoiser/checkpoints/last.pt" CONFIG_PATH = "configs" DINOCLIP_CONFIG = "clip_dinoiser.yaml" COLORS = [ (0, 255, 0), (255, 0, 0), (0, 255, 255), (255, 0, 255), (255, 255, 0), (250, 128, 114), (255, 165, 0), (0, 128, 0), (144, 238, 144), (175, 238, 238), (0, 191, 255), (0, 128, 0), (138, 43, 226), (255, 0, 255), (255, 215, 0), (0, 0, 255), ] # Initialize Hydra initialize(config_path=CONFIG_PATH, version_base=None) # Configuration and Model Initialization def load_model(): Repository( local_dir="clip-dinoiser", clone_from="ariG23498/clip-dinoiser", use_auth_token=os.environ.get("token"), ) device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(CHECKPOINT_PATH, map_location=device) cfg = compose(config_name=DINOCLIP_CONFIG) model = build_model(cfg.model, class_names=PascalVOCDataset.CLASSES).to(device) model.clip_backbone.decode_head.use_templates = False model.load_state_dict(checkpoint["model_state_dict"], strict=False) return model.eval() def run_clip_dinoiser(input_image, text_prompts, model, device, colors): # Resize the input image image = input_image.resize((350, 350)) image = image.convert("RGB") text_prompts = text_prompts.split(",") palette = colors[: len(text_prompts)] model.clip_backbone.decode_head.update_vocab(text_prompts) model.to(device) img_tens = T.PILToTensor()(image).unsqueeze(0).to(device) / 255.0 h, w = img_tens.shape[-2:] output = model(img_tens).cpu() output = F.interpolate( output, scale_factor=model.clip_backbone.backbone.patch_size, mode="bilinear", align_corners=False, )[..., :h, :w] output = output[0].argmax(dim=0) mask = mask2rgb(output, palette) alpha = 0.5 blend = (alpha * np.array(image) / 255.0) + ((1 - alpha) * mask / 255.0) h_text = [(text, f"{idx}") for idx, text in enumerate(text_prompts)] return blend, mask, h_text def create_color_map(colors): return { f"{color_id}": f"#{hex(color[0])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[2])[2:].zfill(2)}" for color_id, color in enumerate(colors) } def setup_gradio_interface(model, device, colors, color_map): block = gr.Blocks() with block: gr.Markdown("