Spaces:
Sleeping
Sleeping
import itertools | |
import logging | |
import torch.utils.data | |
from detectron2.config import CfgNode, configurable | |
from detectron2.data.build import ( | |
build_batch_data_loader, | |
load_proposals_into_dataset, | |
trivial_batch_collator, | |
) | |
from detectron2.data.catalog import DatasetCatalog | |
from detectron2.data.common import DatasetFromList, MapDataset | |
from detectron2.data.dataset_mapper import DatasetMapper | |
from detectron2.data.samplers import InferenceSampler, TrainingSampler | |
from detectron2.utils.comm import get_world_size | |
from torch.utils.data.sampler import Sampler | |
from collections import defaultdict | |
from typing import Optional | |
from detectron2.utils import comm | |
def _compute_num_images_per_worker(cfg: CfgNode):
    """
    Split ``cfg.SOLVER.IMS_PER_BATCH`` evenly over the processes reported by
    ``get_world_size()``.

    Args:
        cfg (CfgNode): config providing ``SOLVER.IMS_PER_BATCH``.

    Returns:
        int: the number of images each process handles per batch.

    Raises:
        AssertionError: if the total batch size is not a multiple of the
            world size, or is smaller than the world size.
    """
    world_size = get_world_size()
    total_images = cfg.SOLVER.IMS_PER_BATCH

    divisibility_msg = (
        "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
            total_images, world_size
        )
    )
    assert total_images % world_size == 0, divisibility_msg

    # A batch size equal to the world size passes this check, even though the
    # message says "larger than".
    minimum_msg = (
        "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
            total_images, world_size
        )
    )
    assert total_images >= world_size, minimum_msg

    return total_images // world_size
def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):
    """
    Keep only the images that carry at least one non-crowd annotation.

    An image is dropped when its "annotations" list is empty or when every
    annotation in it has a non-zero "iscrowd" value. An entry of
    "annotations" may itself be a list of instance dicts; such a group counts
    as usable as soon as one of its instances is non-crowd. A missing
    "iscrowd" key is treated as 0 (not crowd).

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
            Every dict must have an "annotations" key.
        dataset_names: accepted for the caller's convenience; this function
            does not read it.

    Returns:
        list[dict]: the same format, but filtered. The number of removed and
        remaining images is logged at INFO level.
    """
    total_before = len(dataset_dicts)

    def _is_non_crowd(instance):
        return instance.get("iscrowd", 0) == 0

    def _has_usable_annotation(anns):
        for entry in anns:
            # Normalise a single annotation dict to a one-element group so
            # both layouts go through the same check.
            group = entry if isinstance(entry, list) else [entry]
            if any(_is_non_crowd(instance) for instance in group):
                return True
        return False

    kept = [
        record for record in dataset_dicts if _has_usable_annotation(record["annotations"])
    ]
    total_after = len(kept)

    logging.getLogger(__name__).info(
        "Removed {} images with no usable annotations. {} images left.".format(
            total_before - total_after, total_after
        )
    )
    return kept
def get_detection_dataset_dicts(
    dataset_names, filter_empty=True, proposal_files=None
):
    """
    Fetch one or more registered datasets and merge them into a single list.

    Args:
        dataset_names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): when true, and the first merged record has an
            "annotations" key, drop images that have no non-crowd annotation.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `dataset_names`.

    Returns:
        list[dict]: the records of all requested datasets, concatenated in
        the order of `dataset_names`.

    Raises:
        AssertionError: if no name is given, a dataset is empty, the number of
            proposal files differs from the number of datasets, or nothing is
            left after filtering.
    """
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    assert len(dataset_names)

    # Load every dataset first, then validate, so the catalog is queried for
    # all names before any emptiness check fires.
    per_dataset = list(map(DatasetCatalog.get, dataset_names))
    for name, records in zip(dataset_names, per_dataset):
        assert len(records), "Dataset '{}' is empty!".format(name)

    if proposal_files is not None:
        assert len(dataset_names) == len(proposal_files)
        # Attach the precomputed proposals, one proposal file per dataset.
        with_proposals = []
        for records, proposal_file in zip(per_dataset, proposal_files):
            with_proposals.append(load_proposals_into_dataset(records, proposal_file))
        per_dataset = with_proposals

    merged = list(itertools.chain.from_iterable(per_dataset))

    # Only the first record is inspected to decide whether instance
    # annotations are present.
    first_has_annotations = "annotations" in merged[0]
    if filter_empty and first_has_annotations:
        merged = filter_images_with_only_crowd_annotations(merged, dataset_names)

    assert len(merged), "No valid data found in {}.".format(",".join(dataset_names))
    return merged
def _train_loader_from_config(cfg, mapper, *, dataset=None, sampler=None): | |
if dataset is None: | |
dataset = get_detection_dataset_dicts( | |
cfg.DATASETS.TRAIN, | |
filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, | |
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, | |
) | |
if mapper is None: | |
mapper = DatasetMapper(cfg, True) | |
if sampler is None: | |
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN | |
logger = logging.getLogger(__name__) | |
logger.info("Using training sampler {}".format(sampler_name)) | |
if sampler_name == "TrainingSampler": | |
sampler = TrainingSampler(len(dataset)) | |
elif sampler_name == "ClassAwareSampler": | |
sampler = ClassAwareSampler(dataset) | |
return { | |
"dataset": dataset, | |
"sampler": sampler, | |
"mapper": mapper, | |
"total_batch_size": cfg.SOLVER.IMS_PER_BATCH, | |
"aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING, | |
"num_workers": cfg.DATALOADER.NUM_WORKERS, | |
"use_mixup": True | |
} | |
# TODO can allow dataset as an iterable or IterableDataset to make this function more general
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0,
    use_mixup=False
):
    """
    Assemble a training dataloader from a dataset, a mapper and a sampler.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. A plain list is wrapped in
            ``DatasetFromList(dataset, copy=False)``.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model. Pass None to use
            the dataset elements unchanged.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that
            produces indices to be applied on ``dataset``.
            Defaults to ``TrainingSampler(len(dataset))``.
        total_batch_size (int): forwarded to ``build_batch_data_loader``.
        aspect_ratio_grouping (bool): forwarded to ``build_batch_data_loader``.
        num_workers (int): forwarded to ``build_batch_data_loader``.
        use_mixup (bool): when true (and ``mapper`` is not None) the dataset is
            wrapped in ``MapDatasetMixup`` instead of ``MapDataset``.
            NOTE(review): ``MapDatasetMixup`` is neither defined nor imported
            in the visible part of this file - confirm where it comes from.

    Returns:
        the loader produced by ``build_batch_data_loader``.

    Raises:
        AssertionError: if the sampler is not a
            ``torch.utils.data.sampler.Sampler``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)

    if mapper is not None:
        # Only the selected wrapper name is looked up.
        wrapper_cls = MapDatasetMixup if use_mixup else MapDataset
        dataset = wrapper_cls(dataset, mapper)

    sampler = TrainingSampler(len(dataset)) if sampler is None else sampler
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)

    loader_options = {
        "aspect_ratio_grouping": aspect_ratio_grouping,
        "num_workers": num_workers,
    }
    return build_batch_data_loader(dataset, sampler, total_batch_size, **loader_options)
def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Build the keyword arguments of ``build_detection_test_loader`` for one
    test dataset.

    The dataset is selected by the ``dataset_name`` argument rather than by
    the names listed in cfg, so each test set can be evaluated on its own.

    Args:
        cfg: config providing ``MODEL.LOAD_PROPOSALS``, ``DATASETS.TEST``,
            ``DATASETS.PROPOSAL_FILES_TEST`` and ``DATALOADER.NUM_WORKERS``.
        dataset_name (str): name of the dataset to load. When proposals are
            enabled it must appear in ``cfg.DATASETS.TEST`` so the matching
            proposal file can be located by position.
        mapper (callable or None): sample mapper; when None,
            ``DatasetMapper(cfg, False)`` is used.

    Returns:
        dict: with keys "dataset", "mapper" and "num_workers".
    """
    if cfg.MODEL.LOAD_PROPOSALS:
        # The proposal file is the one at the same position as the dataset
        # name in cfg.DATASETS.TEST.
        all_proposal_files = cfg.DATASETS.PROPOSAL_FILES_TEST
        position = list(cfg.DATASETS.TEST).index(dataset_name)
        proposal_files = [all_proposal_files[position]]
    else:
        proposal_files = None

    dataset = get_detection_dataset_dicts(
        [dataset_name],
        filter_empty=False,
        proposal_files=proposal_files,
    )

    if mapper is None:
        mapper = DatasetMapper(cfg, False)

    return {
        "dataset": dataset,
        "mapper": mapper,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
def build_detection_test_loader(dataset, *, mapper, num_workers=0):
    """
    Counterpart of `build_detection_train_loader` for inference: one image
    per batch, indices drawn by ``InferenceSampler``.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. A plain list is wrapped in
            ``DatasetFromList(dataset, copy=False)``.
        mapper (callable): a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model. Pass None to
            use the dataset elements unchanged.
        num_workers (int): number of parallel data loading workers

    Returns:
        DataLoader: a torch DataLoader over the (mapped) dataset with
        ``batch_size=1``, ``drop_last=False`` and ``trivial_batch_collator``
        as the collate function.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetRegistry.get("my_test"),
            mapper=DatasetMapper(...))

    NOTE(review): calling this as ``build_detection_test_loader(cfg, "my_test")``
    would need the function to be wrapped with ``configurable``; no such
    decorator is visible on this definition - confirm.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)

    # Always use 1 image per worker during inference since this is the
    # standard when reporting inference time in papers.
    loader_options = dict(
        batch_size=1,
        sampler=InferenceSampler(len(dataset)),
        drop_last=False,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)
class ClassAwareSampler(Sampler):
    """
    Endless index sampler that draws images with replacement, using one
    weight per image derived from how many images contain each category.

    Every process generates the same seeded index stream and keeps only the
    entries at positions ``rank, rank + world_size, rank + 2 * world_size, ...``.
    """

    def __init__(self, dataset_dicts, seed: Optional[int] = None):
        """
        Args:
            dataset_dicts (list[dict]): non-empty list of image records. Each
                record needs an "annotations" list whose entries have a
                "category_id" key.
            seed (int or None): seed of the index stream. When None, the
                value of ``comm.shared_random_seed()`` is used.
        """
        self._size = len(dataset_dicts)
        assert self._size > 0
        self._seed = int(comm.shared_random_seed() if seed is None else seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        self.weights = self._get_class_balance_factor(dataset_dicts)

    def __iter__(self):
        # Take every `world_size`-th index of the shared stream, starting at
        # this process's rank.
        stream = self._infinite_indices()
        yield from itertools.islice(stream, self._rank, None, self._world_size)

    def _infinite_indices(self):
        """Yield indices forever, `self._size` weighted draws at a time."""
        rng = torch.Generator()
        rng.manual_seed(self._seed)
        while True:
            drawn = torch.multinomial(
                self.weights, self._size, replacement=True, generator=rng
            )
            yield from drawn

    def _get_class_balance_factor(self, dataset_dicts, l=1.):
        """
        Compute one sampling weight per image.

        The weight of an image is the sum, over the distinct categories it
        contains, of ``1 / (number of images containing that category) ** l``.
        An image without annotations gets weight 0.

        Args:
            dataset_dicts (list[dict]): image records with "annotations".
            l (float): exponent applied to the per-category image count.

        Returns:
            torch.Tensor: float tensor with one weight per image, in input order.
        """
        # Distinct categories per image; computed once and reused for both
        # the counting pass and the weighting pass.
        cats_per_image = [
            {ann["category_id"] for ann in record["annotations"]}
            for record in dataset_dicts
        ]

        images_with_category = defaultdict(int)
        for cats in cats_per_image:
            for cat_id in cats:
                images_with_category[cat_id] += 1

        factors = [
            sum([1. / (images_with_category[cat_id] ** l) for cat_id in cats])
            for cats in cats_per_image
        ]
        return torch.tensor(factors).float()