# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import copy import warnings from mmengine.config import ConfigDict def compat_cfg(cfg): """This function would modify some filed to keep the compatibility of config. For example, it will move some args which will be deprecated to the correct fields. """ cfg = copy.deepcopy(cfg) cfg = compat_imgs_per_gpu(cfg) cfg = compat_loader_args(cfg) cfg = compat_runner_args(cfg) return cfg def compat_runner_args(cfg): if 'runner' not in cfg: cfg.runner = ConfigDict({ 'type': 'EpochBasedRunner', 'max_epochs': cfg.total_epochs }) warnings.warn( 'config is now expected to have a `runner` section, ' 'please set `runner` in your config.', UserWarning) else: if 'total_epochs' in cfg: assert cfg.total_epochs == cfg.runner.max_epochs return cfg def compat_imgs_per_gpu(cfg): cfg = copy.deepcopy(cfg) if 'imgs_per_gpu' in cfg.data: warnings.warn('"imgs_per_gpu" is deprecated in MMDet V2.0. ' 'Please use "samples_per_gpu" instead') if 'samples_per_gpu' in cfg.data: warnings.warn( f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' f'={cfg.data.imgs_per_gpu} is used in this experiments') else: warnings.warn('Automatically set "samples_per_gpu"="imgs_per_gpu"=' f'{cfg.data.imgs_per_gpu} in this experiments') cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu return cfg def compat_loader_args(cfg): """Deprecated sample_per_gpu in cfg.data.""" cfg = copy.deepcopy(cfg) if 'train_dataloader' not in cfg.data: cfg.data['train_dataloader'] = ConfigDict() if 'val_dataloader' not in cfg.data: cfg.data['val_dataloader'] = ConfigDict() if 'test_dataloader' not in cfg.data: cfg.data['test_dataloader'] = ConfigDict() # special process for train_dataloader if 'samples_per_gpu' in cfg.data: samples_per_gpu = cfg.data.pop('samples_per_gpu') assert 'samples_per_gpu' not in \ cfg.data.train_dataloader, ('`samples_per_gpu` are set ' 'in `data` field and ` ' 'data.train_dataloader` ' 'at the same time. ' 'Please only set it in ' '`data.train_dataloader`. ') cfg.data.train_dataloader['samples_per_gpu'] = samples_per_gpu if 'persistent_workers' in cfg.data: persistent_workers = cfg.data.pop('persistent_workers') assert 'persistent_workers' not in \ cfg.data.train_dataloader, ('`persistent_workers` are set ' 'in `data` field and ` ' 'data.train_dataloader` ' 'at the same time. ' 'Please only set it in ' '`data.train_dataloader`. ') cfg.data.train_dataloader['persistent_workers'] = persistent_workers if 'workers_per_gpu' in cfg.data: workers_per_gpu = cfg.data.pop('workers_per_gpu') cfg.data.train_dataloader['workers_per_gpu'] = workers_per_gpu cfg.data.val_dataloader['workers_per_gpu'] = workers_per_gpu cfg.data.test_dataloader['workers_per_gpu'] = workers_per_gpu # special process for val_dataloader if 'samples_per_gpu' in cfg.data.val: # keep default value of `sample_per_gpu` is 1 assert 'samples_per_gpu' not in \ cfg.data.val_dataloader, ('`samples_per_gpu` are set ' 'in `data.val` field and ` ' 'data.val_dataloader` at ' 'the same time. ' 'Please only set it in ' '`data.val_dataloader`. ') cfg.data.val_dataloader['samples_per_gpu'] = \ cfg.data.val.pop('samples_per_gpu') # special process for val_dataloader # in case the test dataset is concatenated if isinstance(cfg.data.test, dict): if 'samples_per_gpu' in cfg.data.test: assert 'samples_per_gpu' not in \ cfg.data.test_dataloader, ('`samples_per_gpu` are set ' 'in `data.test` field and ` ' 'data.test_dataloader` ' 'at the same time. ' 'Please only set it in ' '`data.test_dataloader`. ') cfg.data.test_dataloader['samples_per_gpu'] = \ cfg.data.test.pop('samples_per_gpu') elif isinstance(cfg.data.test, list): for ds_cfg in cfg.data.test: if 'samples_per_gpu' in ds_cfg: assert 'samples_per_gpu' not in \ cfg.data.test_dataloader, ('`samples_per_gpu` are set ' 'in `data.test` field and ` ' 'data.test_dataloader` at' ' the same time. ' 'Please only set it in ' '`data.test_dataloader`. ') samples_per_gpu = max( [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test]) cfg.data.test_dataloader['samples_per_gpu'] = samples_per_gpu return cfg