import torch


def preprocess_video_panoptic_gt(
        gt_labels,
        gt_masks,
        gt_semantic_seg,
        gt_instance_ids,
        num_things,
        num_stuff,
):
    """Build per-clip panoptic targets (class labels and mask tubes).

    Thing instances are grouped by instance id across frames: every unique
    id yields one stack of ``num_frames`` masks, with an all-zero mask for
    each frame in which the instance does not occur. Stuff classes found in
    ``gt_semantic_seg`` each contribute one extra entry after the things.

    Args:
        gt_labels (Tensor): 2-D tensor; column 1 is read as the class id of
            each annotation row. Rows are assumed to be aligned with the
            rows of ``gt_instance_ids`` -- TODO confirm against the caller.
        gt_masks (Sequence): one mask container per frame. Each container
            must expose ``.masks`` (array-like whose last two dims are
            H, W), ``.pad(size, pad_val=...)`` and
            ``.to_tensor(dtype=..., device=...)`` (presumably mmdet
            ``BitmapMasks`` -- verify). The masks inside a frame are
            indexed by the position of the annotation among that frame's
            rows in ``gt_instance_ids``, so both must share the same order.
        gt_semantic_seg (Tensor | None): semantic map whose dim 1 is
            squeezed away (assumed ``(num_frames, 1, H, W)`` -- TODO
            confirm). ``None`` disables the stuff branch.
        gt_instance_ids (Tensor): 2-D tensor; column 0 is compared against
            the frame index ``0..num_frames-1`` and column 1 is the
            instance id.
        num_things (int): number of thing classes; thing labels must be
            ``< num_things``.
        num_stuff (int): number of stuff classes; stuff labels are those in
            ``[num_things, num_things + num_stuff)``.

    Returns:
        tuple[Tensor, Tensor]: ``labels`` and ``masks``, both cast to
        ``torch.long``. ``masks[i]`` is the mask tube belonging to
        ``labels[i]``; thing entries come first, then stuff entries.

    Raises:
        AssertionError: if one instance id carries different class ids, a
            thing label is ``>= num_things``, or labels and masks end up
            with different lengths.
    """
    num_classes = num_things + num_stuff
    num_frames = len(gt_masks)
    # All frames are padded to the spatial size of the first frame.
    mask_size = gt_masks[0].masks.shape[-2:]

    thing_masks_list = [
        gt_masks[frame_id].pad(mask_size, pad_val=0).to_tensor(
            dtype=torch.bool, device=gt_labels.device)
        for frame_id in range(num_frames)
    ]

    frame_column = gt_instance_ids[:, 0]
    instance_column = gt_instance_ids[:, 1]

    # Instance ids present in each frame, in annotation order. This depends
    # only on the frame, so compute it once instead of once per
    # (instance, frame) pair.
    ids_per_frame = [
        instance_column[frame_column == frame_id]
        for frame_id in range(num_frames)
    ]

    instances = torch.unique(instance_column)
    things_masks = []
    labels = []
    for instance in instances:
        # nonzero(as_tuple=True) returns a 1-tuple for 1-D input.
        rows = torch.nonzero(
            torch.eq(instance_column, instance), as_tuple=True)[0]

        instance_classes = gt_labels[:, 1][rows]
        assert torch.allclose(instance_classes, instance_classes[0]), (
            f'instance {instance} has inconsistent class labels: '
            f'{instance_classes}')
        labels.append(instance_classes[0])

        # Set membership keeps the per-frame presence test O(1).
        frames_with_instance = set(
            frame_column[rows].to(dtype=torch.int32).tolist())

        tube = []
        for frame_id in range(num_frames):
            frame_masks = thing_masks_list[frame_id]
            if frame_id in frames_with_instance:
                # Position of this instance among the frame's annotations;
                # .item() requires the id to occur exactly once per frame.
                slot = torch.nonzero(
                    torch.eq(ids_per_frame[frame_id], instance),
                    as_tuple=True)[0].item()
                tube.append(frame_masks[slot])
            else:
                # Instance absent from this frame: fill with an empty mask.
                tube.append(torch.zeros(
                    mask_size,
                    dtype=frame_masks.dtype,
                    device=frame_masks.device))
        things_masks.append(torch.stack(tube))

    if instances.numel() == 0:
        # No thing instance at all: stacking the (empty) per-frame tensors
        # along dim 1 keeps the (instance, frame, H, W) layout.
        things_masks = torch.stack(thing_masks_list, dim=1)
        labels = torch.empty_like(instances)
    else:
        things_masks = torch.stack(things_masks)
        labels = torch.stack(labels)
    assert torch.all(torch.less(labels, num_things)), (
        f'thing labels must be < num_things ({num_things}), got {labels}')

    if gt_semantic_seg is None:
        masks = things_masks
    else:
        things_labels = labels
        gt_semantic_seg = gt_semantic_seg.squeeze(1)

        semantic_labels = torch.unique(
            gt_semantic_seg,
            sorted=False,
            return_inverse=False,
            return_counts=False)

        stuff_masks_list = []
        stuff_labels_list = []
        for label in semantic_labels:
            # Keep stuff classes only; thing ids and out-of-range values
            # (e.g. an ignore index) are skipped.
            if label < num_things or label >= num_classes:
                continue
            stuff_masks_list.append(gt_semantic_seg == label)
            stuff_labels_list.append(label)

        if stuff_masks_list:
            stuff_masks = torch.stack(stuff_masks_list, dim=0)
            stuff_labels = torch.stack(stuff_labels_list, dim=0)
            assert (torch.all(torch.ge(stuff_labels, num_things))
                    and torch.all(torch.less(stuff_labels, num_classes))), (
                f'stuff labels out of range: {stuff_labels}')
            labels = torch.cat([things_labels, stuff_labels], dim=0)
            masks = torch.cat([things_masks, stuff_masks], dim=0)
        else:
            labels = things_labels
            masks = things_masks
        assert len(labels) == len(masks), (
            f'got {len(labels)} labels for {len(masks)} masks')

    labels = labels.to(dtype=torch.long)
    masks = masks.to(dtype=torch.long)
    return labels, masks