"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import json
from typing import Dict

from omegaconf import OmegaConf
from medomni.common.registry import registry
import ipdb


class Config:
    """
    Build the merged runner / model / dataset configuration.

    The YAML file at ``args.cfg_path`` is combined with the default model and
    dataset configuration files and with the command-line overrides in
    ``args.options``. Later sources win: user options override everything.
    """

    def __init__(self, args):
        """
        Args:
            args: object exposing ``cfg_path`` (path of the YAML configuration)
                and ``options`` (command-line overrides, see
                ``_convert_to_dot_list`` for the accepted forms).
        """
        self.config = {}

        self.args = args

        # Register the config and configuration for setup
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        dataset_config = self.build_dataset_config(config)

        # Validate the user-provided runner configuration
        # model and dataset configuration are supposed to be validated by the respective classes
        # [TODO] validate the model/dataset configuration
        # self._validate_runner_config(runner_config)

        # Override the default configuration with user options.
        self.config = OmegaConf.merge(
            runner_config, model_config, dataset_config, user_config
        )

    def _validate_runner_config(self, runner_config):
        """
        This method validates the configuration, such that
            1) all the user specified options are valid;
            2) no type mismatches between the user specified options and the config.
        """
        runner_config_validator = create_runner_config_validator()
        runner_config_validator.validate(runner_config)

    def _build_opt_list(self, opts):
        """Turn the command-line overrides into an OmegaConf config."""
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    @staticmethod
    def build_model_config(config, **kwargs):
        """
        Merge the default configuration of the selected model with the ``model``
        section of ``config``.

        Args:
            config: loaded configuration; must contain a ``model`` section with
                an ``arch`` entry naming a registered model.
            **kwargs: user overrides. The model type is read from
                ``kwargs["model"]["model_type"]`` (the shape produced by
                unpacking a config built from ``model.model_type=...``) or from
                a flat ``"model.model_type"`` key.

        Returns:
            The merged model configuration (customized config > default config).

        Raises:
            AssertionError: if the model section, the registered model class or
                the model type is missing.
        """
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        model_type = kwargs.get("model.model_type", None)
        if not model_type:
            # Unpacking the user config with ** yields its top-level keys only,
            # so a "model.model_type=..." override arrives nested under "model".
            user_model = kwargs.get("model", None)
            if user_model is not None and hasattr(user_model, "get"):
                model_type = user_model.get("model_type", None)
        if not model_type:
            model_type = model.get("model_type", None)
        # else use the model type selected by user.

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        model_config = OmegaConf.create()
        # hierarchy override, customized config > default config
        model_config = OmegaConf.merge(
            model_config,
            OmegaConf.load(model_config_path),
            {"model": config["model"]},
        )

        return model_config

    @staticmethod
    def build_runner_config(config):
        """Return the ``run`` section of ``config`` wrapped under the ``run`` key."""
        return {"run": config.run}

    @staticmethod
    def build_dataset_config(config):
        """
        Merge the default configuration of every listed dataset with its entry
        in the ``datasets`` section of ``config``.

        Raises:
            KeyError: if ``config`` has no ``datasets`` section.
        """
        datasets = config.get("datasets", None)
        if datasets is None:
            raise KeyError(
                "Expecting 'datasets' as the root key for dataset configuration."
            )

        dataset_config = OmegaConf.create()

        for dataset_name in datasets:
            builder_cls = registry.get_builder_class(dataset_name)

            dataset_config_type = datasets[dataset_name].get("type", "default")
            dataset_config_path = builder_cls.default_config_path(
                type=dataset_config_type
            )

            # hierarchy override, customized config > default config
            dataset_config = OmegaConf.merge(
                dataset_config,
                OmegaConf.load(dataset_config_path),
                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
            )

        return dataset_config

    def _convert_to_dot_list(self, opts):
        """
        Normalize command-line overrides to ``key=value`` strings.

        Two input forms are handled, decided by looking at the first element
        only: ``["a=1", "b=2"]`` is returned unchanged, while
        ``["a", "1", "b", "2"]`` is paired up into ``["a=1", "b=2"]``.
        ``None`` and empty input give an empty list.
        """
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        has_equal = opts[0].find("=") != -1

        if has_equal:
            return opts

        # zip() pairs keys with values; a trailing key without a value is dropped.
        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    def get_config(self):
        """Return the merged configuration."""
        return self.config

    @property
    def run_cfg(self):
        """The ``run`` section of the merged configuration."""
        return self.config.run

    @property
    def datasets_cfg(self):
        """The ``datasets`` section of the merged configuration."""
        return self.config.datasets

    @property
    def model_cfg(self):
        """The ``model`` section of the merged configuration."""
        return self.config.model

    def pretty_print(self):
        """Log the run, dataset and model sections as indented JSON."""
        logging.info("\n=====  Running Parameters    =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n======  Dataset Attributes  ======")
        datasets = self.config.datasets

        for dataset in datasets:
            logging.info(f"\n======== {dataset} =======")
            dataset_config = datasets[dataset]
            logging.info(self._convert_node_to_json(dataset_config))

        logging.info("\n======  Model Attributes  ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        """Serialize a config node to a JSON string with interpolations resolved."""
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        """Return the merged configuration as plain Python containers."""
        return OmegaConf.to_container(self.config)


def node_to_dict(node):
    """Return a config node as plain Python containers."""
    return OmegaConf.to_container(node)


class ConfigValidator:
    """
    This is a preliminary implementation to centralize and validate the configuration.
    May be altered in the future.

    A helper class to validate configurations from yaml file.

    This serves the following purposes:
        1. Ensure all the options in the yaml are defined, raise error if not.
        2. when type mismatches are found, the validator will raise an error.
        3. a central place to store and display helpful messages for supported configurations.

    """

    class _Argument:
        """Description of one supported option: name, type, choices and help text."""

        def __init__(self, name, choices=None, type=None, help=None):
            self.name = name
            # Converted value, filled in by ConfigValidator.validate.
            self.val = None
            self.choices = choices
            self.type = type
            self.help = help

        def __str__(self):
            s = f"{self.name}={self.val}"
            if self.type is not None:
                s += f", ({self.type})"
            if self.choices is not None:
                s += f", choices: {self.choices}"
            if self.help is not None:
                s += f", ({self.help})"
            return s

    def __init__(self, description):
        self.description = description

        # Maps option name -> _Argument.
        self.arguments = dict()

        self.parsed_args = None

    def __getitem__(self, key):
        assert self.parsed_args is not None, "No arguments parsed yet."

        return self.parsed_args[key]

    def __str__(self) -> str:
        return self.format_help()

    def add_argument(self, *args, **kwargs):
        """
        Assume the first argument is the name of the argument.
        """
        self.arguments[args[0]] = self._Argument(*args, **kwargs)

    def validate(self, config=None):
        """
        Check a dict-like config against the registered arguments.

        Every key must have been registered with ``add_argument``. When the
        argument declares a type, the value is converted with it and stored on
        the argument; when it declares choices, the raw value must be one of
        them.

        Args:
            config: mapping of option name to value; ``None`` is accepted and
                returned as-is.

        Returns:
            ``config``, unchanged.

        Raises:
            AssertionError: on an unknown option or a value outside ``choices``.
            ValueError: when a value cannot be converted to the declared type.
        """
        if config is None:
            return config

        for k, v in config.items():
            assert (
                k in self.arguments
            ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""

            argument = self.arguments[k]

            if argument.type is not None:
                # typing generics such as Dict[str, float] cannot be called;
                # convert with their runtime origin (dict) instead.
                converter = getattr(argument.type, "__origin__", None) or argument.type
                try:
                    argument.val = converter(v)
                except ValueError as err:
                    raise ValueError(f"{k} is not a valid {argument.type}.") from err

            if argument.choices is not None:
                assert (
                    v in argument.choices
                ), f"""{k} must be one of {argument.choices}."""

        return config

    def format_arguments(self):
        """Return the sorted argument names rendered as a list literal."""
        return str([f"{k}" for k in sorted(self.arguments.keys())])

    def format_help(self):
        """Return the description followed by the available argument names."""
        # description + key-value pair string for each argument
        help_msg = str(self.description)
        return help_msg + ", available arguments: " + self.format_arguments()

    def print_help(self):
        """Print the help message to stdout."""
        # display help message
        print(self.format_help())


def create_runner_config_validator():
    """Return a ConfigValidator populated with the supported runner options."""
    validator = ConfigValidator(description="Runner configurations")

    validator.add_argument(
        "runner",
        type=str,
        choices=["runner_base", "runner_iter"],
        help="""Runner to use. The "runner_base" uses epoch-based training while iter-based runner runs based on iters. Default: runner_base""",
    )
    # add arguments for training dataset ratios
    validator.add_argument(
        "train_dataset_ratios",
        type=Dict[str, float],
        help="""Ratios of training dataset. This is used in iteration-based runner. Do not support for epoch-based runner because how to define an epoch becomes tricky. Default: None""",
    )
    validator.add_argument(
        "max_iters",
        type=float,
        help="Maximum number of iterations to run.",
    )
    validator.add_argument(
        "max_epoch",
        type=int,
        help="Maximum number of epochs to run.",
    )
    # add arguments for iters_per_inner_epoch
    validator.add_argument(
        "iters_per_inner_epoch",
        type=float,
        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
    )
    lr_scheds_choices = registry.list_lr_schedulers()
    validator.add_argument(
        "lr_sched",
        type=str,
        choices=lr_scheds_choices,
        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
    )
    task_choices = registry.list_tasks()
    validator.add_argument(
        "task",
        type=str,
        choices=task_choices,
        help="Task to use, from {}".format(task_choices),
    )
    # add arguments for init_lr
    validator.add_argument(
        "init_lr",
        type=float,
        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
    )
    # add arguments for min_lr
    validator.add_argument(
        "min_lr",
        type=float,
        help="Minimum learning rate (after decay).",
    )
    # add arguments for warmup_lr
    validator.add_argument(
        "warmup_lr",
        type=float,
        help="Starting learning rate for warmup.",
    )
    # add arguments for learning rate decay rate
    validator.add_argument(
        "lr_decay_rate",
        type=float,
        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
    )
    # add arguments for weight decay
    validator.add_argument(
        "weight_decay",
        type=float,
        help="Weight decay rate.",
    )
    # add arguments for training batch size
    validator.add_argument(
        "batch_size_train",
        type=int,
        help="Training batch size.",
    )
    # add arguments for evaluation batch size
    validator.add_argument(
        "batch_size_eval",
        type=int,
        help="Evaluation batch size, including validation and testing.",
    )
    # add arguments for number of workers for data loading
    validator.add_argument(
        "num_workers",
        help="Number of workers for data loading.",
    )
    # add arguments for warm up steps
    validator.add_argument(
        "warmup_steps",
        type=int,
        help="Number of warmup steps. Required if a warmup schedule is used.",
    )
    # add arguments for random seed
    validator.add_argument(
        "seed",
        type=int,
        help="Random seed.",
    )
    # add arguments for output directory
    validator.add_argument(
        "output_dir",
        type=str,
        help="Output directory to save checkpoints and logs.",
    )
    # add arguments for whether only use evaluation
    validator.add_argument(
        "evaluate",
        help="Whether to only evaluate the model. If true, training will not be performed.",
    )
    # add arguments for splits used for training, e.g. ["train", "val"]
    validator.add_argument(
        "train_splits",
        type=list,
        help="Splits to use for training.",
    )
    # add arguments for splits used for validation, e.g. ["val"]
    validator.add_argument(
        "valid_splits",
        type=list,
        help="Splits to use for validation. If not provided, will skip the validation.",
    )
    # add arguments for splits used for testing, e.g. ["test"]
    validator.add_argument(
        "test_splits",
        type=list,
        help="Splits to use for testing. If not provided, will skip the testing.",
    )
    # add arguments for accumulating gradient for iterations
    validator.add_argument(
        "accum_grad_iters",
        type=int,
        help="Number of iterations to accumulate gradient for.",
    )

    # ====== distributed training ======
    validator.add_argument(
        "device",
        type=str,
        choices=["cpu", "cuda"],
        help="Device to use. Support 'cuda' or 'cpu' as for now.",
    )
    validator.add_argument(
        "world_size",
        type=int,
        help="Number of processes participating in the job.",
    )
    validator.add_argument("dist_url", type=str)
    validator.add_argument("distributed", type=bool)
    # add arguments to opt using distributed sampler during evaluation or not
    validator.add_argument(
        "use_dist_eval_sampler",
        type=bool,
        help="Whether to use distributed sampler during evaluation or not.",
    )

    # ====== task specific ======
    # generation task specific arguments
    # add arguments for maximal length of text output
    validator.add_argument(
        "max_len",
        type=int,
        help="Maximal length of text output.",
    )
    # add arguments for minimal length of text output
    validator.add_argument(
        "min_len",
        type=int,
        help="Minimal length of text output.",
    )
    # add arguments number of beams
    validator.add_argument(
        "num_beams",
        type=int,
        help="Number of beams used for beam search.",
    )

    # vqa task specific arguments
    # add arguments for number of answer candidates
    validator.add_argument(
        "num_ans_candidates",
        type=int,
        help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
    )
    # add arguments for inference method
    validator.add_argument(
        "inference_method",
        type=str,
        # "generate" was previously misspelled, which rejected the valid value.
        choices=["generate", "rank"],
        help="""Inference method to use for question answering. If rank, requires a answer list.""",
    )

    # ====== model specific ======
    validator.add_argument(
        "k_test",
        type=int,
        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
    )

    return validator