Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch | |
from maskrcnn_benchmark.structures.image_list import to_image_list | |
import pdb | |
class BatchCollator(object):
    """
    Collate a list of dataset samples into a single batch.

    Intended to be passed to the DataLoader as its collate function.
    Each sample is read positionally: index 0 is the image, index 1 the
    target and index 2 the image id.

    Args:
        size_divisible: forwarded unchanged as the second argument of
            ``to_image_list`` when the images are batched.
    """

    def __init__(self, size_divisible=0):
        self.size_divisible = size_divisible

    def __call__(self, batch):
        """
        Batch the images and, when present, the per-target map fields.

        Args:
            batch: sequence of samples, each indexable as
                ``(image, target, img_id, ...)``.

        Returns:
            If the first target is a ``dict``, the 5-tuple
            ``(images, targets, img_ids, None, None)``.
            Otherwise the 6-tuple ``(images, targets, img_ids, positive_map,
            positive_map_eval, greenlight_map)``, where each of the last
            three entries is ``None`` unless the first target lists the
            corresponding field in ``fields()``.

        NOTE(review): the two return paths differ in length (5 vs 6); this is
        kept as-is because callers may unpack either shape -- confirm before
        unifying.
        """
        transposed_batch = list(zip(*batch))
        images = to_image_list(transposed_batch[0], self.size_divisible)
        targets = transposed_batch[1]
        img_ids = transposed_batch[2]
        positive_map = None
        positive_map_eval = None
        greenlight_map = None

        # Dict targets carry no fields()/get_field() API, so nothing is batched.
        if isinstance(targets[0], dict):
            return images, targets, img_ids, positive_map, positive_map_eval

        # Only the first target is inspected to decide which fields exist;
        # the remaining targets are assumed to carry the same fields.
        available_fields = targets[0].fields()

        if "greenlight_map" in available_fields:
            greenlight_map = torch.stack(
                [t.get_field("greenlight_map") for t in targets], dim=0
            )

        if "positive_map" in available_fields:
            positive_map = self._batch_positive_maps(targets, "positive_map")

        if "positive_map_eval" in available_fields:
            positive_map_eval = self._batch_positive_maps(targets, "positive_map_eval")

        return images, targets, img_ids, positive_map, positive_map_eval, greenlight_map

    @staticmethod
    def _batch_positive_maps(targets, field_name):
        """
        Concatenate the 2-D ``field_name`` map of every target along dim 0.

        In general each batch element has a different number of boxes, so the
        per-target maps are collapsed into a single leading dimension instead
        of being padded per sample. Only the second dimension is zero-padded,
        up to the widest map in the batch.

        Args:
            targets: non-empty sequence of objects exposing ``get_field``.
            field_name: name of the field to batch.

        Returns:
            A float tensor of shape ``(total_rows, max_width)``. Values pass
            through a boolean buffer first, so every entry is 0.0 or 1.0.
        """
        maps = [t.get_field(field_name) for t in targets]
        max_len = max(m.shape[1] for m in maps)
        nb_boxes = sum(m.shape[0] for m in maps)
        batched_pos_map = torch.zeros((nb_boxes, max_len), dtype=torch.bool)

        cur_count = 0
        for cur_pos in maps:
            # Write this target's rows right below the previous target's rows.
            batched_pos_map[cur_count: cur_count + len(cur_pos), : cur_pos.shape[1]] = cur_pos
            cur_count += len(cur_pos)

        # Internal sanity check: every allocated row must have been filled.
        assert cur_count == len(batched_pos_map)
        return batched_pos_map.float()
class BBoxAugCollator(object):
    """
    Collate a list of dataset samples without batching the images.

    From a list of samples from the dataset, returns the images and targets
    as plain tuples. Images should be converted to batched images in
    `im_detect_bbox_aug`.
    """

    def __call__(self, batch):
        """
        Transpose the samples into per-field tuples.

        Args:
            batch: sequence of samples, each indexable as
                ``(image, target, img_id, ...)``; only the first three
                positions are used.

        Returns:
            The 5-tuple ``(images, targets, img_ids, None, None)``. The two
            trailing ``None`` values occupy the positive_map and
            positive_map_eval positions, which this collator never fills.

        Raises:
            IndexError: if ``batch`` is empty or its samples have fewer than
                three elements.
        """
        transposed_batch = list(zip(*batch))
        images = transposed_batch[0]
        targets = transposed_batch[1]
        img_ids = transposed_batch[2]
        # The result is the same for dict and non-dict targets, so no type
        # check on the targets is needed here.
        return images, targets, img_ids, None, None