# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/sukjunhwang/IFC
import itertools
import logging
import torch.utils.data

from detectron2.config import CfgNode, configurable
from detectron2.data.build import (
    build_batch_data_loader,
    load_proposals_into_dataset,
    trivial_batch_collator,
)
from detectron2.data.catalog import DatasetCatalog
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.samplers import InferenceSampler, TrainingSampler
from detectron2.utils.comm import get_world_size


def _compute_num_images_per_worker(cfg: CfgNode):
    """
    Split ``cfg.SOLVER.IMS_PER_BATCH`` evenly across the distributed processes.

    Here "workers" means the distributed world size reported by
    ``get_world_size()``, not DataLoader worker processes.

    Args:
        cfg (CfgNode): config providing ``SOLVER.IMS_PER_BATCH``.

    Returns:
        int: number of images each process handles per batch.

    Raises:
        AssertionError: if the batch size is not divisible by, or is smaller
            than, the world size.
    """
    num_workers = get_world_size()
    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
    assert (
        images_per_batch % num_workers == 0
    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
        images_per_batch, num_workers
    )
    assert (
        images_per_batch >= num_workers
    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
        images_per_batch, num_workers
    )
    images_per_worker = images_per_batch // num_workers
    return images_per_worker


def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):
    """
    Filter out images with no annotations or only crowd annotations
    (i.e., images without non-crowd annotations).
    A common training-time preprocessing on COCO dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
            Each element of a dict's ``"annotations"`` may be either a single
            instance dict or a list of instance dicts; both layouts are checked.
        dataset_names (list[str]): names of the source datasets. Not used by
            the filtering itself; kept for interface compatibility.

    Returns:
        list[dict]: the same format, but filtered.
    """
    num_before = len(dataset_dicts)

    def _instances(ann):
        # Normalize both layouts (single dict / list of dicts) to an iterable
        # of instance dicts so they are inspected the same way.
        return ann if isinstance(ann, list) else (ann,)

    def valid(anns):
        # An image is kept if at least one instance is not marked as crowd;
        # a missing "iscrowd" key counts as non-crowd.
        return any(
            instance.get("iscrowd", 0) == 0
            for ann in anns
            for instance in _instances(ann)
        )

    dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
    num_after = len(dataset_dicts)
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} images with no usable annotations. {} images left.".format(
            num_before - num_after, num_after
        )
    )
    return dataset_dicts


def get_detection_dataset_dicts(
    dataset_names, filter_empty=True, proposal_files=None
):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        dataset_names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without non-crowd
            instance annotations (only applied when the dicts have an
            ``"annotations"`` key)
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `dataset_names`.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.

    Raises:
        AssertionError: if no names are given, any dataset is empty, the number
            of proposal files does not match, or nothing is left after filtering.
    """
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    assert len(dataset_names)
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(dataset_names) == len(proposal_files)
        # load precomputed proposals from proposal files
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    # Every dataset was asserted non-empty above, so index 0 exists.
    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names)

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names))
    return dataset_dicts


def _train_loader_from_config(cfg, mapper, *, dataset=None, sampler=None):
    """
    Translate a config into keyword arguments for `build_detection_train_loader`.

    Any of `dataset`, `mapper`, `sampler` that is ``None`` is built from `cfg`.
    Only ``TrainingSampler`` is constructed here; a different
    ``cfg.DATALOADER.SAMPLER_TRAIN`` triggers a warning and the fallback.
    """
    if dataset is None:
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name != "TrainingSampler":
            # This loader only implements TrainingSampler; surface the fallback
            # instead of silently ignoring the configured sampler.
            logger.warning(
                "Sampler %s is not supported by this loader; falling back to TrainingSampler.",
                sampler_name,
            )
        sampler = TrainingSampler(len(dataset))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }


# TODO can allow dataset as an iterable or IterableDataset to make this function more general
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):
    """
    Build a dataloader for object detection with some default features.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that
            produces indices to be applied on ``dataset``.
            Default to :class:`TrainingSampler`, which coordinates a random shuffle
            sequence across all workers.
        total_batch_size (int): total batch size across all distributed
            processes. Batching simply puts data into a list.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers

    Returns:
        torch.utils.data.DataLoader: a dataloader. Each output from it is a
            ``list[mapped_element]`` of length ``total_batch_size / world_size``
            (world_size = number of distributed processes, not ``num_workers``),
            where ``mapped_element`` is produced by the ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )


def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).
    """
    dataset = get_detection_dataset_dicts(
        [dataset_name],
        filter_empty=False,
        # Pick the proposal file at the same position as dataset_name in cfg.DATASETS.TEST.
        proposal_files=[
            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]
        ]
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return {"dataset": dataset, "mapper": mapper, "num_workers": cfg.DATALOADER.NUM_WORKERS}


@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(dataset, *, mapper, num_workers=0):
    """
    Similar to `build_detection_train_loader`, but uses a batch size of 1.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset
           and returns the format to be consumed by the model.
           When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
        num_workers (int): number of parallel data loading workers

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetCatalog.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    sampler = InferenceSampler(len(dataset))
    # Always use 1 image per worker during inference since this is the
    # standard when reporting inference time in papers.
    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
    )
    return data_loader