|
|
|
import itertools |
|
import logging |
|
import numpy as np |
|
import math |
|
from collections import defaultdict |
|
import torch.utils.data |
|
|
|
from detectron2.config import configurable |
|
from detectron2.utils.logger import _log_api_usage |
|
|
|
from detectron2.data.catalog import DatasetCatalog |
|
from detectron2.data.common import DatasetFromList, MapDataset |
|
from detectron2.data.dataset_mapper import DatasetMapper |
|
from detectron2.data.samplers import ( |
|
InferenceSampler, |
|
RepeatFactorTrainingSampler, |
|
TrainingSampler |
|
) |
|
from detectron2.data.build import ( |
|
build_batch_data_loader, |
|
trivial_batch_collator |
|
) |
|
|
|
def filter_images_with_only_crowd_annotations(dataset_dicts):
    """
    Drop images that have no annotations or only crowd annotations.

    An image is kept when at least one of its annotations is not marked as
    crowd (``iscrowd`` missing or equal to 0). This is a common training-time
    preprocessing step on COCO.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format;
            every dict must carry an "annotations" list.

    Returns:
        list[dict]: the same format, containing only the kept images.
    """

    def has_non_crowd(annotations):
        # True as soon as one annotation is not a crowd region.
        return any(ann.get("iscrowd", 0) == 0 for ann in annotations)

    total = len(dataset_dicts)
    kept = [record for record in dataset_dicts if has_non_crowd(record["annotations"])]
    remaining = len(kept)

    logging.getLogger(__name__).info(
        "Removed {} images marked with crowd. {} images left.".format(
            total - remaining, remaining
        )
    )
    return kept
|
|
|
def get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, proposal_files=None, **kwargs):
    """
    Load, merge and optionally filter the registered dicts of one or more datasets.

    Args:
        names (str or list[str]): dataset name(s) registered in ``DatasetCatalog``.
        filter_empty (bool): when the dicts carry "annotations", drop images
            without any non-crowd annotation.
        min_keypoints (int): when > 0 and the dicts carry "annotations", drop
            images via detectron2's ``filter_images_with_few_keypoints`` using
            this threshold.
        proposal_files (list[str] or None): one precomputed-proposal file per
            entry of ``names``; when given, proposals are loaded into each
            dataset's dicts with detectron2's ``load_proposals_into_dataset``.
        **kwargs: accepted for backward compatibility and ignored.

    Returns:
        list[dict]: the concatenated (and possibly filtered) dataset dicts.
    """
    if isinstance(names, str):
        names = [names]

    assert len(names), names
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]
    for dataset_name, dicts in zip(names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        # Imported lazily: only needed when proposals are requested.
        from detectron2.data.build import load_proposals_into_dataset

        assert len(names) == len(proposal_files), (names, proposal_files)
        # Proposals are matched per source dataset, before the lists are merged.
        dataset_dicts = [
            load_proposals_into_dataset(dicts, proposal_file)
            for dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]

    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)

    if min_keypoints > 0 and has_instances:
        # Imported lazily: only needed when keypoint filtering is requested.
        from detectron2.data.build import filter_images_with_few_keypoints

        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
    return dataset_dicts
|
|
|
|
|
def _dataset_balance_weights(dataset, dataset_id_to_src):
    """
    Compute per-image sampling weights that balance images across sources.

    Each image is assigned to a source via ``dataset_id_to_src[img["dataset_id"]]``.
    For a source holding ``n`` of the ``N`` images, the raw weight is
    ``1 - n / N``. The weights are then rescaled so the smallest one is 1.
    If every image comes from a single source, all weights are 1.

    Args:
        dataset (list[dict]): dataset dicts; each must have a "dataset_id" key.
        dataset_id_to_src (dict): maps a "dataset_id" to a hashable source key.

    Returns:
        torch.Tensor: float32 tensor with one weight per image in ``dataset``.
    """
    # Dense integer id per source. dict.fromkeys keeps first-seen order, so the
    # numbering does not depend on set iteration order. The resulting weights
    # never depend on the numbering anyway.
    source_to_int = {
        src: i for i, src in enumerate(dict.fromkeys(dataset_id_to_src.values()))
    }
    ids_per_img = np.asarray(
        [source_to_int[dataset_id_to_src[img["dataset_id"]]] for img in dataset],
        dtype=np.int64,
    )
    # Sorted ids of the sources actually present, each image's index into that
    # list, and the number of images per present source.
    present_ids, img_to_present, counts = np.unique(
        ids_per_img, return_inverse=True, return_counts=True
    )

    if len(present_ids) == 1:
        return torch.ones(len(ids_per_img), dtype=torch.float32)

    weights = 1 - counts / counts.sum()
    weights = weights / weights.min()
    return torch.as_tensor(weights, dtype=torch.float32)[torch.as_tensor(img_to_present)]


def _build_train_sampler(cfg, dataset, dataset_id_to_src):
    """
    Build the training sampler named by ``cfg.DATALOADER.SAMPLER_TRAIN``.

    When ``cfg.DATALOADER.BALANCE_DATASETS`` is set, per-image source-balance
    weights are used as the repeat factors (TrainingSampler) or multiplied
    into the category-frequency repeat factors (RepeatFactorTrainingSampler)
    of a ``RepeatFactorTrainingSampler``.

    Raises:
        ValueError: if the sampler name is not recognized.
    """
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    balance_datasets = cfg.DATALOADER.BALANCE_DATASETS
    logger = logging.getLogger(__name__)
    logger.info("Using training sampler {}".format(sampler_name))

    weights_per_img = None
    if balance_datasets:
        assert dataset_id_to_src is not None, 'Need dataset sources.'
        weights_per_img = _dataset_balance_weights(dataset, dataset_id_to_src)

    if sampler_name == "TrainingSampler":
        if not balance_datasets:
            return TrainingSampler(len(dataset))
        # Balancing turns plain shuffling into weighted repetition.
        return RepeatFactorTrainingSampler(weights_per_img)

    if sampler_name == "RepeatFactorTrainingSampler":
        repeat_factors = repeat_factors_from_category_frequency(
            dataset, cfg.DATALOADER.REPEAT_THRESHOLD
        )
        if balance_datasets:
            # Combine category-frequency and source-balance factors, then
            # rescale so the smallest repeat factor is 1.
            repeat_factors *= weights_per_img
            repeat_factors /= repeat_factors.min().item()
        return RepeatFactorTrainingSampler(repeat_factors)

    raise ValueError("Unknown training sampler: {}".format(sampler_name))


def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None, dataset_id_to_src=None):
    """
    Translate a config into keyword arguments for ``build_detection_train_loader``.

    Args:
        cfg: detectron2 config node.
        mapper (callable or None): defaults to ``DatasetMapper(cfg, True)``.
        dataset (list[dict] or None): defaults to the dicts of ``cfg.DATASETS.TRAIN``.
        sampler (torch.utils.data.Sampler or None): defaults to the sampler
            selected by ``cfg.DATALOADER.SAMPLER_TRAIN`` (optionally source-balanced).
        dataset_id_to_src (dict or None): maps each image's "dataset_id" to a
            source key. Required when ``cfg.DATALOADER.BALANCE_DATASETS`` is set
            and no sampler is given.

    Returns:
        dict: keys "dataset", "sampler", "mapper", "total_batch_size",
        "aspect_ratio_grouping" and "num_workers".
    """
    if dataset is None:
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )
        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler = _build_train_sampler(cfg, dataset, dataset_id_to_src)

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
|
|
|
|
|
def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh):
    """
    Compute a fractional repeat factor for every image from category frequency.

    The frequency of category c is the fraction of images in the training set
    (counted once each, before any repetition) whose annotations contain c.
    Each category gets the factor ``max(1, sqrt(repeat_thresh / frequency))``.
    An image takes the largest factor among its categories. Negative category
    ids are ignored, and an image without any non-negative category gets 1.0.
    See :paper:`lvis` (>= v2) Appendix B.2.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 dataset format;
            each dict has an "annotations" list whose items carry "category_id".
        repeat_thresh (float): frequency threshold below which data is repeated.
            Because of the square root, a category whose frequency is a quarter
            of `repeat_thresh` yields a factor of 2.

    Returns:
        torch.Tensor: float32 tensor whose i-th element is the repeat factor
        for the dataset image at index i.
    """

    def image_categories(record):
        # Distinct non-negative category ids annotated in one image.
        return {
            ann["category_id"]
            for ann in record["annotations"]
            if ann["category_id"] >= 0
        }

    num_images = len(dataset_dicts)

    # Pass 1: in how many images does each category appear?
    images_with_cat = defaultdict(int)
    for record in dataset_dicts:
        for cat_id in image_categories(record):
            images_with_cat[cat_id] += 1

    # Category-level repeat factors, from the per-category image frequency.
    cat_repeat = {}
    for cat_id, n_imgs in images_with_cat.items():
        freq = n_imgs / num_images
        cat_repeat[cat_id] = max(1.0, math.sqrt(repeat_thresh / freq))

    # Pass 2: each image inherits the factor of its rarest category.
    factors = [
        max((cat_repeat[cat_id] for cat_id in image_categories(record)), default=1.0)
        for record in dataset_dicts
    ]
    return torch.tensor(factors, dtype=torch.float32)
|
|
|
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0):
    """
    Build a batched training data loader.

    Args:
        dataset (list or dataset object): a list of dataset dicts (wrapped in
            ``DatasetFromList`` without copying) or an already-built dataset.
        mapper (callable or None): applied to every item via ``MapDataset``;
            ``None`` leaves items unchanged.
        sampler (torch.utils.data.Sampler or None): index sampler; defaults
            to ``TrainingSampler(len(dataset))``.
        total_batch_size (int): forwarded to ``build_batch_data_loader``.
        aspect_ratio_grouping (bool): forwarded to ``build_batch_data_loader``.
        num_workers (int): forwarded to ``build_batch_data_loader``.

    Returns:
        The loader produced by ``build_batch_data_loader``.
    """
    source = DatasetFromList(dataset, copy=False) if isinstance(dataset, list) else dataset
    if mapper is not None:
        source = MapDataset(source, mapper)

    index_sampler = sampler if sampler is not None else TrainingSampler(len(source))
    assert isinstance(index_sampler, torch.utils.data.Sampler)

    return build_batch_data_loader(
        source,
        index_sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )
|
|
|
def _test_loader_from_config(cfg, dataset_name, batch_size=1, mapper=None, filter_empty=False):
    """
    Translate a config into keyword arguments for ``build_detection_test_loader``.

    Args:
        cfg: detectron2 config node.
        dataset_name (str or list[str]): registered dataset name(s) to evaluate on.
        batch_size (int): number of images per inference batch.
        mapper (callable or None): defaults to ``DatasetMapper(cfg, False)``.
        filter_empty (bool): forwarded to ``get_detection_dataset_dicts``.

    Returns:
        dict: keys "dataset", "mapper", "batch_size" and "num_workers".
    """
    names = [dataset_name] if isinstance(dataset_name, str) else dataset_name

    proposal_files = None
    if cfg.MODEL.LOAD_PROPOSALS:
        # PROPOSAL_FILES_TEST is parallel to DATASETS.TEST; look each name up by position.
        test_names = list(cfg.DATASETS.TEST)
        proposal_files = [
            cfg.DATASETS.PROPOSAL_FILES_TEST[test_names.index(name)] for name in names
        ]

    dataset = get_detection_dataset_dicts(
        names,
        filter_empty=filter_empty,
        proposal_files=proposal_files,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)

    return {
        "dataset": dataset,
        "mapper": mapper,
        'batch_size': batch_size,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
|
|
|
@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(dataset, *, mapper, batch_size=1, sampler=None, num_workers=0):
    """
    Build a data loader for inference.

    Args:
        dataset (list or dataset object): a list of dataset dicts (wrapped in
            ``DatasetFromList`` without copying) or an already-built dataset.
        mapper (callable or None): applied to every item via ``MapDataset``;
            ``None`` leaves items unchanged.
        batch_size (int): number of items per batch; the last batch may be smaller.
        sampler (torch.utils.data.Sampler or None): index sampler; defaults
            to ``InferenceSampler(len(dataset))``.
        num_workers (int): number of ``DataLoader`` worker processes.

    Returns:
        torch.utils.data.DataLoader: batches are collated with
        ``trivial_batch_collator``.
    """
    records = DatasetFromList(dataset, copy=False) if isinstance(dataset, list) else dataset
    if mapper is not None:
        records = MapDataset(records, mapper)

    index_sampler = InferenceSampler(len(records)) if sampler is None else sampler
    batches = torch.utils.data.BatchSampler(index_sampler, batch_size=batch_size, drop_last=False)

    return torch.utils.data.DataLoader(
        records,
        num_workers=num_workers,
        batch_sampler=batches,
        collate_fn=trivial_batch_collator,
    )
|
|
|
|