File size: 3,003 Bytes
7762f58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import random
import numpy as np
from PIL import Image
from collections import defaultdict
import os
# detectron2 cannot simply be listed in requirements.txt: pip tries to build it
# before torch is installed and the build fails, even when torch is listed first.
# It is therefore installed here at import time (needed on Gradio / Hugging Face
# Spaces), before the detectron2 imports below.
import importlib
import importlib.util
import subprocess
import sys

if importlib.util.find_spec("detectron2") is None:
    # Run the pip that belongs to the current interpreter (a bare `pip` on PATH
    # may target a different environment) and avoid going through a shell.
    # check=False keeps the original best-effort behaviour: a failed install
    # surfaces as an ImportError on the detectron2 imports that follow.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ],
        check=False,
    )
    # The package appeared after the interpreter started; make sure the import
    # system's finder caches notice it.
    importlib.invalidate_caches()

from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from color_palette import ade_palette
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation

def load_model_and_processor(model_ckpt: str):
    """Load a Mask2Former segmentation model and its matching image processor.

    The model is moved to the CUDA device when ``torch.cuda.is_available()``
    reports one, otherwise to the CPU, and is put in evaluation mode.

    Args:
        model_ckpt: Checkpoint name or path handed to ``from_pretrained``.

    Returns:
        A ``(model, image_processor)`` tuple.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt)
    seg_model = seg_model.to(target_device)
    seg_model.eval()
    processor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
    return seg_model, processor

def load_default_ckpt():
    """Return the identifier of the Mask2Former checkpoint predict_masks loads."""
    return "facebook/mask2former-swin-tiny-coco-panoptic"

def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    """Render a panoptic segmentation on top of ``image`` with detectron2.

    Each dict in ``seg_info`` is adapted in place to the format consumed by
    detectron2's ``Visualizer.draw_panoptic_seg_predictions``. Its ``label_id``
    key is renamed to ``category_id``, and an ``isthing`` flag is added based on
    the ``coco_2017_val_panoptic`` metadata.

    Args:
        predicted_segmentation_map: Tensor of per-pixel segment ids; it is
            moved to the CPU before drawing.
        seg_info: List of segment dicts. Each must carry ``label_id``, or
            ``category_id`` if a previous call already converted it.
        image: The input image (PIL image or array-like).

    Returns:
        ``(output_img, labels)``: the visualization as a PIL image and the
        ``category_id`` of every segment, in ``seg_info`` order.
    """
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    # Build the lookup set once instead of scanning dict values per segment.
    # NOTE(review): assumes the model's label ids use detectron2's contiguous
    # COCO panoptic ids (the original check made the same assumption).
    thing_ids = set(metadata.thing_dataset_id_to_contiguous_id.values())
    for res in seg_info:
        # Rename only when needed, so a second call on the same seg_info does
        # not raise KeyError on the already-renamed key.
        if "label_id" in res:
            res["category_id"] = res.pop("label_id")
        res["isthing"] = res["category_id"] in thing_ids

    # detectron2's Visualizer requires an RGB array. A PIL image converted to
    # RGB already is one, so the channels must NOT be reversed (that yields BGR
    # and swaps red/blue). convert() also normalizes grayscale/RGBA/palette
    # images to three channels.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    rgb = np.array(image)

    visualizer = Visualizer(rgb, metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), seg_info, alpha=0.5
    )
    output_img = Image.fromarray(out.get_image())
    labels = [res["category_id"] for res in seg_info]
    return output_img, labels



def predict_masks(input_img_path: str):
    """Run panoptic segmentation on an image file and visualize the result.

    Args:
        input_img_path: Path of the image to segment.

    Returns:
        ``(output_result, output_heading, labels)``: the rendered segmentation
        as a PIL image, a fixed heading string, and the category id of each
        predicted segment.
    """
    # Load model and image processor; load_model_and_processor requires the
    # checkpoint argument.
    default_ckpt = load_default_ckpt()
    model, image_processor = load_model_and_processor(default_ckpt)

    # Decode as 3-channel RGB so the processor and the visualizer always see the
    # channel layout they expect. The context manager closes the file handle.
    with Image.open(input_img_path) as img:
        image = img.convert("RGB")

    # Preprocess, then move the tensors to the model's device: the model is
    # placed on CUDA when available, and CPU inputs would not match it.
    inputs = image_processor(images=image, return_tensors="pt").to(model.device)

    # Run the model without building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    # target_sizes takes (height, width); PIL's size is (width, height).
    result = image_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    predicted_segmentation_map = result["segmentation"]
    seg_info = result["segments_info"]
    output_result, labels = draw_panoptic_segmentation(
        predicted_segmentation_map, seg_info, image
    )
    output_heading = "Panoptic Segmentation Output"

    return output_result, output_heading, labels





def get_mask_for_label(results, label):
    """Return a black/white PIL mask of pixels whose segmentation value is ``label``.

    Args:
        results: Mapping with a ``'segmentation'`` entry holding a per-pixel id
            map, either a torch tensor (on any device) or an array-like.
        label: Value to select in that map. NOTE(review): for transformers
            panoptic post-processing output the map holds *segment* ids, not
            category ids. Confirm which one callers pass.

    Returns:
        A PIL image (mode ``"L"`` for a 2-D map) that is 255 where the map
        equals ``label`` and 0 elsewhere.
    """
    # Local imports keep this helper self-contained.
    import numpy as np
    import torch
    from PIL import Image

    segmentation = results['segmentation']
    if torch.is_tensor(segmentation):
        # .numpy() only works on CPU tensors, and the segmentation map may live
        # on a CUDA device.
        segmentation = segmentation.detach().cpu().numpy()
    mask = np.asarray(segmentation) == label
    visual_mask = mask.astype(np.uint8) * 255
    return Image.fromarray(visual_mask)