# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from mmcv.transforms import LoadImageFromFile

from mmdet.datasets.transforms import LoadAnnotations, LoadPanopticAnnotations
from mmdet.registry import TRANSFORMS


def get_loading_pipeline(pipeline):
    """Keep only the image- and annotation-loading steps of a pipeline.

    Each config's ``type`` is looked up in the ``TRANSFORMS`` registry and the
    config is kept when the lookup result is one of ``LoadImageFromFile``,
    ``LoadAnnotations`` or ``LoadPanopticAnnotations``. Configs are returned
    as-is (not copied) and in their original order.

    Args:
        pipeline (list[dict]): Data pipeline configs. Every entry must have a
            ``'type'`` key.

    Returns:
        list[dict]: The pipeline configs that correspond to loading the image
        and loading the annotations.

    Raises:
        AssertionError: If the number of loading-related configs found in
            ``pipeline`` is not exactly two.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True),
        ...    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
        ...    dict(type='RandomFlip', flip_ratio=0.5),
        ...    dict(type='Normalize', **img_norm_cfg),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='DefaultFormatBundle'),
        ...    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True)
        ...    ]
        >>> assert expected_pipelines ==\
        ...        get_loading_pipeline(pipelines)
    """
    # TODO: use a more elegant way to distinguish loading modules
    loading_types = (LoadImageFromFile, LoadAnnotations,
                     LoadPanopticAnnotations)

    # A failed registry lookup can never match a member of ``loading_types``,
    # so unresolved transform types are simply filtered out here.
    loading_pipeline_cfg = [
        cfg for cfg in pipeline
        if TRANSFORMS.get(cfg['type']) in loading_types
    ]

    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # removed when Python runs with ``-O``, which would let an invalid
    # pipeline through. ``AssertionError`` is kept so callers see the same
    # exception type as before.
    if len(loading_pipeline_cfg) != 2:
        raise AssertionError(
            'The data pipeline in your config file must include '
            'loading image and annotations related pipeline.')
    return loading_pipeline_cfg