import os
import random
import subprocess
import sys
from collections import defaultdict

import numpy as np
import torch
from PIL import Image

# Mentioning detectron2 as a dependency directly in requirements.txt tries to install
# detectron2 before torch and results in an error even if torch is listed as a
# dependency before detectron2. Hence, detectron2 is installed at runtime when using
# Gradio HF Spaces. `sys.executable -m pip` makes sure the package is installed into
# the environment of the running interpreter (a bare `pip` on PATH may belong to a
# different one). This stays best-effort like the original `os.system` call: failures
# are not raised here, but the detectron2 import below fails loudly if it didn't work.
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/facebookresearch/detectron2.git",
    ],
    check=False,
)

from detectron2.data import MetadataCatalog  # noqa: E402
from detectron2.utils.visualizer import ColorMode, Visualizer  # noqa: E402

from color_palette import ade_palette  # noqa: E402
from transformers import (  # noqa: E402
    Mask2FormerForUniversalSegmentation,
    Mask2FormerImageProcessor,
)


def load_model_and_processor(model_ckpt: str):
    """Load a Mask2Former model and its image processor from ``model_ckpt``.

    The model is moved to CUDA when available (CPU otherwise) and put in eval mode.

    Args:
        model_ckpt: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, image_processor)`` tuple.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(device)
    model.eval()
    image_processor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_processor


def load_default_ckpt():
    """Return the checkpoint id used by :func:`predict_masks`."""
    default_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
    return default_ckpt


def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    """Render a panoptic prediction over ``image`` with detectron2's Visualizer.

    Args:
        predicted_segmentation_map: Per-pixel segment-id tensor, as produced by
            ``post_process_panoptic_segmentation``; it is moved to CPU before drawing.
        seg_info: List of segment dicts from the post-processor. Mutated in place:
            ``label_id`` is renamed to ``category_id`` and an ``isthing`` flag is added,
            which is the schema detectron2's panoptic drawer reads.
        image: The PIL input image.

    Returns:
        ``(output_img, labels)``: the rendered PIL image and the ``category_id`` of
        every segment, in ``seg_info`` order.
    """
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    # Contiguous ids that correspond to "thing" classes; built once, not per segment.
    # NOTE(review): assumes the checkpoint's label ids use the same contiguous
    # numbering as detectron2's COCO panoptic metadata — confirm for other ckpts.
    thing_ids = set(metadata.thing_dataset_id_to_contiguous_id.values())
    for res in seg_info:
        res["category_id"] = res.pop("label_id")
        res["isthing"] = res["category_id"] in thing_ids

    # Visualizer requires an RGB array. PIL already yields RGB, so the channels must
    # NOT be reversed (that is only needed for BGR images from cv2). convert("RGB")
    # also normalises RGBA / grayscale / palette images to 3 channels.
    rgb_array = np.array(image.convert("RGB"))
    visualizer = Visualizer(rgb_array, metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), seg_info, alpha=0.5
    )
    output_img = Image.fromarray(out.get_image())
    labels = [res["category_id"] for res in seg_info]
    return output_img, labels


def predict_masks(input_img_path: str):
    """Run panoptic segmentation on the image at ``input_img_path``.

    Args:
        input_img_path: Path of the image file to segment.

    Returns:
        ``(output_result, output_heading, labels)``: the rendered PIL image, a fixed
        heading string, and the category id of every predicted segment.
    """
    # Load the model and image processor for the default checkpoint.
    default_ckpt = load_default_ckpt()
    model, image_processor = load_model_and_processor(default_ckpt)

    # Read the image and close the file handle; force 3-channel RGB so that
    # RGBA / grayscale inputs are handled uniformly by the processor.
    with Image.open(input_img_path) as img:
        image = img.convert("RGB")

    # Preprocess and move the tensors to the same device as the model
    # (the model lives on CUDA when available).
    inputs = image_processor(images=image, return_tensors="pt").to(model.device)

    # Forward pass without autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # target_sizes expects (height, width); PIL's size is (width, height).
    result = image_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    predicted_segmentation_map = result["segmentation"]
    seg_info = result["segments_info"]

    output_result, labels = draw_panoptic_segmentation(
        predicted_segmentation_map, seg_info, image
    )
    output_heading = "Panoptic Segmentation Output"
    return output_result, output_heading, labels


def get_mask_for_label(results, label):
    """Return a binary PIL mask (255 where ``results['segmentation'] == label``, else 0).

    NOTE(review): ``results['segmentation']`` holds per-pixel *segment* ids from the
    post-processor, so ``label`` is compared against segment ids, not category ids —
    confirm that is what callers pass.
    """
    # .cpu() so this also works when the segmentation tensor lives on the GPU.
    mask = results["segmentation"].cpu().numpy() == label
    visual_mask = (mask * 255).astype(np.uint8)
    return Image.fromarray(visual_mask)