# FurnishAI-Test / predict.py
import torch
import random
import numpy as np
from PIL import Image
from collections import defaultdict
import os
import spaces
# Listing detectron2 directly in requirements.txt makes pip try to build it before torch is installed,
# which fails even when torch is listed before detectron2. Hence, detectron2 is installed at runtime
# this way when using Gradio HF Spaces.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from color_palette import ade_palette
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation

@spaces.GPU
def load_model_and_processor(model_ckpt: str):
    # Load the Mask2Former checkpoint and its image processor, moving the model to GPU when available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
    model.eval()
    image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_preprocessor

@spaces.GPU
def load_default_ckpt():
    # Default panoptic segmentation checkpoint used by this Space.
    default_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
    return default_ckpt

@spaces.GPU
def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    # Overlay the predicted panoptic segmentation on the input image using detectron2's Visualizer.
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    for res in seg_info:
        # detectron2 expects a 'category_id' and an 'isthing' flag for each segment.
        res['category_id'] = res.pop('label_id')
        pred_class = res['category_id']
        isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
        res['isthing'] = bool(isthing)
    visualizer = Visualizer(np.array(image)[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), seg_info, alpha=0.5
    )
    output_img = Image.fromarray(out.get_image())
    labels = [res['category_id'] for res in seg_info]
    return output_img, labels

@spaces.GPU
def predict_masks(input_img_path: str):
    # Load model and image processor
    default_ckpt = load_default_ckpt()
    model, image_processor = load_model_and_processor(default_ckpt)
    # Pass the input image through the image processor
    image = Image.open(input_img_path)
    inputs = image_processor(images=image, return_tensors="pt")
    # Move inputs to the same device as the model
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    # Run the model and post-process the outputs into a panoptic segmentation map
    with torch.no_grad():
        outputs = model(**inputs)
    result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
    predicted_segmentation_map = result["segmentation"]
    seg_info = result['segments_info']
    output_result, labels = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
    output_heading = "Panoptic Segmentation Output"
    return output_result, output_heading, labels

@spaces.GPU
def get_mask_for_label(results, label):
    # Build a binary (0/255) mask image for a single segment id from the predicted segmentation map.
    # Move the segmentation tensor to CPU before converting to numpy, in case it lives on the GPU.
    mask = (results['segmentation'].cpu().numpy() == label)
    visual_mask = (mask * 255).astype(np.uint8)
    visual_mask = Image.fromarray(visual_mask)
    return visual_mask
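

# Minimal usage sketch: runs the panoptic pipeline end to end on a local image.
# The filename "example_room.jpg" is a hypothetical placeholder, not part of this Space;
# predict_masks returns the visualized overlay, a heading string, and the predicted labels.
if __name__ == "__main__":
    overlay, heading, labels = predict_masks("example_room.jpg")
    overlay.save("panoptic_overlay.png")
    print(heading, labels)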