import torch
import torch.nn as nn
import torch.nn.functional as F

from detectron2.structures import Instances


def segmentation_postprocess(
    results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5
):
    """Rescale a set of predictions to a requested output resolution.

    The boxes stored on ``results`` are scaled by the ratio between the
    requested output size and ``results.image_size``, clipped to the new image
    size, and the instances are then filtered with the mask returned by the
    boxes' ``nonempty()``. When a ``pred_masks`` field is present it is
    resized to ``(output_height, output_width)`` with nearest-neighbour
    interpolation and stored back as a byte tensor.

    Args:
        results: Predictions to resize. Must carry either ``pred_boxes`` or
            ``proposal_boxes``.
            NOTE(review): the new ``Instances`` is built from
            ``results.get_fields()``, so the box object appears to be shared
            with the input and the caller's boxes may be modified too --
            confirm against the detectron2 ``Instances``/``Boxes`` API.
        output_height: Target height. May be a tensor instead of an int
            (e.g. during tracing).
        output_width: Target width. May be a tensor instead of an int
            (e.g. during tracing).
        mask_threshold: Accepted for interface compatibility; it is not read
            anywhere in this function.

    Returns:
        A new ``Instances`` whose image size is the requested output size,
        containing only the instances kept by the ``nonempty()`` filter.

    Raises:
        AssertionError: If ``results`` has neither ``pred_boxes`` nor
            ``proposal_boxes``.
    """
    if isinstance(output_width, torch.Tensor):
        # Sizes can be integer tensors while tracing; divide float copies so
        # the scale factors come from true division.
        width_for_scale = output_width.float()
        height_for_scale = output_height.float()
        target_size = torch.stack([output_height, output_width])
    else:
        target_size = (output_height, output_width)
        width_for_scale, height_for_scale = output_width, output_height

    # image_size is indexed as (height, width).
    scale_x = width_for_scale / results.image_size[1]
    scale_y = height_for_scale / results.image_size[0]

    # Re-wrap the existing fields under the new image size.
    results = Instances(target_size, **results.get_fields())

    # Prefer final predictions; fall back to proposals.
    boxes = None
    for field_name in ("pred_boxes", "proposal_boxes"):
        if results.has(field_name):
            boxes = getattr(results, field_name)
            break
    assert boxes is not None, "Predictions must contain boxes!"

    boxes.scale(scale_x, scale_y)
    boxes.clip(results.image_size)

    keep = boxes.nonempty()
    results = results[keep]

    if not results.has("pred_masks"):
        return results

    resized_masks = F.interpolate(
        results.pred_masks.float(),
        size=(output_height, output_width),
        mode="nearest",
    )
    # Assumes pred_masks is (N, 1, H, W), so dropping dim 1 leaves
    # (N, output_height, output_width) -- TODO confirm with the producer.
    results.pred_masks = resized_masks.squeeze(1).byte()
    return results