Spaces:
Running
Running
File size: 1,811 Bytes
a153c95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from detectron2.structures import Instances
def segmentation_postprocess(
    results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5
) -> Instances:
    """Resize instance predictions (boxes and optional masks) to an output resolution.

    A new ``Instances`` of size ``(output_height, output_width)`` is built from the
    fields of ``results``. Its boxes are scaled from ``results.image_size`` to the
    output size and clipped to it. Instances whose boxes are not ``nonempty()``
    are dropped. If ``pred_masks`` is present, it is resized with nearest-neighbour
    interpolation and binarized with ``mask_threshold``.

    Args:
        results: predictions at resolution ``results.image_size`` (height, width).
            Must contain ``pred_boxes`` or ``proposal_boxes``; may contain
            ``pred_masks`` shaped ``(N, 1, H, W)`` or ``(N, H, W)``.
        output_height: target height; may be a tensor during tracing.
        output_width: target width; may be a tensor during tracing.
        mask_threshold: mask values strictly greater than this become 1, all
            others 0. Masks that are already 0/1 (or bool) are unaffected by the
            default of 0.5.

    Returns:
        The resized ``Instances``. ``pred_masks``, if present, is a uint8 tensor;
        for ``(N, 1, H, W)`` or ``(N, H, W)`` input it has shape
        ``(N, output_height, output_width)``.

    Raises:
        AssertionError: if ``results`` has neither ``pred_boxes`` nor
            ``proposal_boxes``.

    NOTE(review): the new Instances is built from ``results.get_fields()``, so it
    likely shares the box object with the input; detectron2's ``Boxes.scale`` /
    ``Boxes.clip`` work in place, so the caller's boxes are presumably modified
    too. Confirm before relying on ``results`` after this call.
    """
    if isinstance(output_width, torch.Tensor):
        # The output size may (but need not) be a tensor during tracing.
        # Use float temporaries so that computing scale_x / scale_y is a true
        # division rather than an integer one.
        output_width_tmp = output_width.float()
        output_height_tmp = output_height.float()
        new_size = torch.stack([output_height, output_width])
    else:
        new_size = (output_height, output_width)
        output_width_tmp = output_width
        output_height_tmp = output_height

    scale_x = output_width_tmp / results.image_size[1]
    scale_y = output_height_tmp / results.image_size[0]

    results = Instances(new_size, **results.get_fields())

    if results.has("pred_boxes"):
        output_boxes = results.pred_boxes
    elif results.has("proposal_boxes"):
        output_boxes = results.proposal_boxes
    else:
        output_boxes = None
    assert output_boxes is not None, "Predictions must contain boxes!"

    output_boxes.scale(scale_x, scale_y)
    # results.image_size is now new_size, so boxes are clipped to the output frame.
    output_boxes.clip(results.image_size)
    results = results[output_boxes.nonempty()]

    if results.has("pred_masks"):
        masks = results.pred_masks
        # F.interpolate with a 2-D target size needs (N, C, H, W) input;
        # accept (N, H, W) masks by adding the channel dimension.
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)
        # Nearest-neighbour resizing keeps binary masks binary; interpolate
        # requires a floating-point input.
        masks = F.interpolate(masks.float(), size=(output_height, output_width), mode="nearest")
        # Binarize with mask_threshold (previously accepted but ignored, so soft
        # masks were truncated by .byte() instead of thresholded).
        results.pred_masks = (masks.squeeze(1) > mask_threshold).byte()

    return results