Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from detectron2.structures import Instances | |
def segmentation_postprocess(
    results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5
):
    """
    Rescale predicted boxes and masks from the inference resolution to a target resolution.

    Args:
        results: predictions whose ``image_size`` is the resolution the boxes
            (and masks, if any) were produced at. Must contain ``pred_boxes``
            or ``proposal_boxes``; may contain ``pred_masks``.
        output_height: target height. May be a tensor during tracing.
        output_width: target width. May be a tensor during tracing.
        mask_threshold: after resizing, a mask pixel is foreground iff its
            value is strictly greater than this threshold.

    Returns:
        A new ``Instances`` whose ``image_size`` is
        ``(output_height, output_width)``. Its boxes are scaled to that size
        and clipped to it, and instances with empty boxes are dropped. If
        present, ``pred_masks`` is a uint8 0/1 tensor. For (N, H, W) or
        (N, 1, H, W) input masks, its shape is
        ``(N, output_height, output_width)``.

    Raises:
        AssertionError: if ``results`` has neither ``pred_boxes`` nor
            ``proposal_boxes``.

    NOTE(review): ``scale``/``clip`` are called without using a return value,
    so they presumably act in place. The new ``Instances`` is built from the
    same field objects, so the input's box object is likely modified too.
    Confirm before relying on the input afterwards.
    """
    if isinstance(output_width, torch.Tensor):
        # This shape might (but not necessarily) be tensors during tracing.
        # Converts integer tensors to float temporaries to ensure true
        # division is performed when computing scale_x and scale_y.
        output_width_tmp = output_width.float()
        output_height_tmp = output_height.float()
        new_size = torch.stack([output_height, output_width])
    else:
        new_size = (output_height, output_width)
        output_width_tmp = output_width
        output_height_tmp = output_height

    scale_x, scale_y = (
        output_width_tmp / results.image_size[1],
        output_height_tmp / results.image_size[0],
    )
    results = Instances(new_size, **results.get_fields())

    if results.has("pred_boxes"):
        output_boxes = results.pred_boxes
    elif results.has("proposal_boxes"):
        output_boxes = results.proposal_boxes
    else:
        output_boxes = None
    assert output_boxes is not None, "Predictions must contain boxes!"

    output_boxes.scale(scale_x, scale_y)
    # results.image_size is now the target size, so boxes are clipped to it.
    output_boxes.clip(results.image_size)
    results = results[output_boxes.nonempty()]

    if results.has("pred_masks"):
        masks = results.pred_masks
        # F.interpolate expects (N, C, H, W) for a 2-D `size`; accept (N, H, W) too.
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)
        masks = F.interpolate(masks.float(), size=(output_height, output_width), mode='nearest')
        # Binarize with the threshold rather than truncating with .byte(), which
        # would map every soft value in [0, 1) to 0 and ignore mask_threshold.
        # For already-binary 0/1 masks both approaches give the same result.
        results.pred_masks = (masks.squeeze(1) > mask_threshold).byte()
    return results