# RegionSpot: regionspot/util/postprocessing.py
# Uploaded by bklg ("Upload 114 files"), commit a153c95, 1.81 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
from detectron2.structures import Instances
def segmentation_postprocess(
    results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5
):
    """Rescale predicted boxes (and masks, if any) to the requested output resolution.

    Args:
        results: Per-image predictions. Must contain ``pred_boxes`` or
            ``proposal_boxes``; may contain ``pred_masks``. ``results.image_size``
            is the (height, width) the predictions refer to.
        output_height: Target height. May be an integer tensor during tracing.
        output_width: Target width. May be an integer tensor during tracing.
        mask_threshold: After resizing, mask values ``>= mask_threshold`` become
            foreground (1) and all others background (0).

    Returns:
        A new ``Instances`` with image size ``(output_height, output_width)``.
        Boxes are scaled to that size and clipped, and instances whose boxes
        become empty are dropped. If ``pred_masks`` is present, it is resized
        with nearest-neighbour interpolation and returned as a uint8 tensor of
        shape ``(N, output_height, output_width)``.

    Raises:
        AssertionError: if ``results`` has neither ``pred_boxes`` nor
            ``proposal_boxes``.
    """
    if isinstance(output_width, torch.Tensor):
        # This shape might (but not necessarily) be tensors during tracing.
        # Converts integer tensors to float temporaries to ensure true
        # division is performed when computing scale_x and scale_y.
        output_width_tmp = output_width.float()
        output_height_tmp = output_height.float()
        new_size = torch.stack([output_height, output_width])
    else:
        new_size = (output_height, output_width)
        output_width_tmp = output_width
        output_height_tmp = output_height

    scale_x = output_width_tmp / results.image_size[1]
    scale_y = output_height_tmp / results.image_size[0]

    results = Instances(new_size, **results.get_fields())

    if results.has("pred_boxes"):
        output_boxes = results.pred_boxes
    elif results.has("proposal_boxes"):
        output_boxes = results.proposal_boxes
    else:
        output_boxes = None
    assert output_boxes is not None, "Predictions must contain boxes!"

    # NOTE(review): the Boxes object is shared with the caller's `results`
    # (get_fields() passes the same objects through). Because the return values
    # of scale/clip are unused, they presumably operate in place, which would
    # also modify the caller's boxes. Confirm whether callers rely on this.
    output_boxes.scale(scale_x, scale_y)
    output_boxes.clip(results.image_size)
    results = results[output_boxes.nonempty()]

    if results.has("pred_masks"):
        masks = results.pred_masks
        # F.interpolate with a 2-D `size` needs a (N, C, H, W) input; accept
        # channel-less (N, H, W) masks as well by adding a singleton channel.
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)
        masks = F.interpolate(
            masks.float(), size=(output_height, output_width), mode="nearest"
        )
        # Binarize with the threshold instead of a truncating cast, which would
        # turn any soft value below 1.0 into background. 0/1 masks are
        # unaffected. The output stays uint8 as before.
        results.pred_masks = (masks.squeeze(1) >= mask_threshold).byte()
    return results