ynhe's picture
[Init]
5423e78
raw
history blame
3.63 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Jialian Wu from https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
import torch
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
class BatchDefaultPredictor(DefaultPredictor):
def __call__(self, original_images):
"""
Args:
original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
Returns:
predictions (dict):
the output of the model for one image only.
See :doc:`/tutorials/models` for details about the format.
"""
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
# Apply pre-processing to image.
height, width = original_images.shape[1:3]
batch_inputs = []
for original_image in original_images:
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
inputs = {"image": image, "height": height, "width": width}
batch_inputs.append(inputs)
predictions = self.model(batch_inputs)[0]
return predictions
class Visualizer_GRiT(Visualizer):
def __init__(self, image, instance_mode=None):
super().__init__(image, instance_mode=instance_mode)
def draw_instance_predictions(self, predictions):
boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
scores = predictions.scores if predictions.has("scores") else None
classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
object_description = predictions.pred_object_descriptions.data
# uncomment to output scores in visualized images
# object_description = [c + '|' + str(round(s.item(), 1)) for c, s in zip(object_description, scores)]
if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
colors = [
self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
]
alpha = 0.8
else:
colors = None
alpha = 0.5
if self._instance_mode == ColorMode.IMAGE_BW:
self.output.reset_image(
self._create_grayscale_image(
(predictions.pred_masks.any(dim=0) > 0).numpy()
if predictions.has("pred_masks")
else None
)
)
alpha = 0.3
self.overlay_instances(
masks=None,
boxes=boxes,
labels=object_description,
keypoints=None,
assigned_colors=colors,
alpha=alpha,
)
return self.output
class VisualizationDemo(object):
def __init__(self, cfg, instance_mode=ColorMode.IMAGE):
self.cpu_device = torch.device("cpu")
self.instance_mode = instance_mode
self.predictor = DefaultPredictor(cfg)
def run_on_image(self, image):
predictions = self.predictor(image)
# Convert image from OpenCV BGR format to Matplotlib RGB format.
image = image[:, :, ::-1]
visualizer = Visualizer_GRiT(image, instance_mode=self.instance_mode)
instances = predictions["instances"].to(self.cpu_device)
vis_output = visualizer.draw_instance_predictions(predictions=instances)
return predictions, vis_output