# jmercat's picture
# Removed history to avoid any unverified information being released
# 5769ee4
import os
from typing import Optional, Union, List
import warnings
import argparse
from mmcv import Config
def _str_to_bool(token: str) -> bool:
    """Convert a command line token to a boolean.

    ``type=bool`` must not be handed to argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy.

    Args:
        token: raw string received from the command line
    Returns:
        The boolean the token stands for (an empty string counts as False)
    Raises:
        argparse.ArgumentTypeError: if the token is not a recognized boolean spelling
    """
    normalized = token.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {token!r}")


def _cli_type(sample):
    """Return the argparse ``type`` converter matching the type of ``sample``."""
    # bool is a subclass of int, so it passes the overwritable type check and
    # needs a dedicated parser instead of the bool constructor.
    return _str_to_bool if isinstance(sample, bool) else type(sample)


def config_argparse(config_path: Optional[Union[str, List[str]]] = None) -> Config:
    """Load the config file(s) as an MMCV Config object and overwrite its values with argparsed arguments.

    Every top-level config entry whose value is a str, float, int or list is exposed
    as a ``--<key>`` command line argument whose default is the config value.
    The flags ``--force_config`` and ``--load_last`` are always available and
    ``--load_from`` is added when the config does not define it.

    The code reads ``dt``, ``time_scene``, ``sample_times`` and ``dataset_parameters``
    from the config, so the loaded config must define them.

    Args:
        config_path : path of the mmcv config file, or list of paths that are loaded
            in order, each one updating the previous ones. If None, defaults to
            ``../../config/learning_config.py`` relative to this file.
    Returns:
        MMCV Config object with default values from the config_path and overwritten values from argparse
    """
    if config_path is None:
        working_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(
            working_dir, "..", "..", "config", "learning_config.py"
        )
    if isinstance(config_path, str):
        cfg = Config.fromfile(config_path)
    else:
        cfg = Config.fromfile(config_path[0])
        for path in config_path[1:]:
            cfg.update(Config.fromfile(path))

    parser = argparse.ArgumentParser()
    # These two are declared below as store_true flags.
    excluded_args = ("force_config", "load_last")
    overwritable_types = (str, float, int, list)
    for key, value in cfg.items():
        if key in excluded_args or not isinstance(value, overwritable_types):
            continue
        flag = "--" + key
        if not isinstance(value, list):
            parser.add_argument(flag, default=value, type=_cli_type(value))
        elif value:
            # The element type of a list argument is taken from its first default element.
            parser.add_argument(
                flag, default=value, nargs="+", type=_cli_type(value[0])
            )
        else:
            # Empty default list: no element type to infer, values are kept as strings.
            parser.add_argument(flag, default=value, nargs="+")

    if "load_from" not in cfg.keys():
        parser.add_argument(
            "--load_from",
            default="",
            type=str,
            help="""Use this to load the model weights from a wandb checkpoint,
            refer to the checkpoint with the wandb id (example:'1f1ho81a')""",
        )

    parser.add_argument(
        "--load_last",
        action="store_true",
        help="""Use this flag to force the use of the last checkpoint instead of the best one
        when loading a model.""",
    )
    parser.add_argument(
        "--force_config",
        action="store_true",
        help="""Use this flag to force the use of the local config file
        when loading a model from a checkpoint. Otherwise the checkpoint config file is used.
        In any case the parameters can be overwritten with an argparse argument.""",
    )
    # The config value, when present, is the default of each flag.
    parser.set_defaults(
        force_config=cfg.force_config if "force_config" in cfg.keys() else False,
        load_last=cfg.load_last if "load_last" in cfg.keys() else False,
    )

    args = parser.parse_args()

    # Print a warning in case the parameter 'dt' or 'time_scene' is changed because
    # 'sample_times' might need to be updated accordingly.
    changed = [
        name
        for name in ("dt", "time_scene")
        if getattr(args, name) != getattr(cfg, name)
    ]
    if changed and args.sample_times == cfg.sample_times:
        for name in changed:
            warnings.warn(
                f"Parameter '{name}' has been changed from {getattr(cfg, name)} "
                f"to {getattr(args, name)} by a command line argument, it might be "
                "used to set the parameter 'sample_times' that cannot be updated "
                f"accordingly. Consider setting '{name}' in {config_path} instead."
            )

    # Config has a dataset_parameters field that copies the parameters related to dataset to compare them,
    # they must be updated too if some of the dataset parameters were changed by argparse
    for key, value in cfg.dataset_parameters.items():
        if isinstance(value, overwritable_types):
            cfg.dataset_parameters[key] = getattr(args, key)

    cfg.update(vars(args))
    return cfg