# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import re

from mmengine.config import Config


def replace_cfg_vals(ori_cfg):
    """Replace the string "${key}" with the corresponding value.

    Replace the "${key}" with the value of ori_cfg.key in the config. And
    support replacing the chained ${key}. Such as, replace "${key0.key1}"
    with the value of cfg.key0.key1. Code is modified from `vars.py
    < https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_  # noqa: E501

    Placeholders are looked up in the *original* config, so a referenced
    value that itself contains "${...}" is inserted as-is.

    After substitution, if the resulting config has a non-None
    ``model_wrapper`` entry, it is moved to ``model`` and the
    ``model_wrapper`` key is removed.

    Args:
        ori_cfg (mmengine.config.Config):
            The origin config with "${key}" generated from a file.

    Returns:
        updated_cfg (mmengine.config.Config):
            The config with "${key}" replaced by the corresponding value.

    Raises:
        AssertionError: If a placeholder embedded inside a longer string
            (e.g. "xxx${key}xxx") refers to a dict, list or tuple value.
        KeyError: Propagated from the config lookup when a referenced key
            does not exist (exact type depends on the config container).
    """
    # the pattern of string "${key}"
    pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}')

    def get_value(cfg, key):
        # Walk a dotted key ("key0.key1") one level at a time.
        for k in key.split('.'):
            cfg = cfg[k]
        return cfg

    def replace_value(cfg):
        # Recurse through containers; only strings can hold placeholders.
        if isinstance(cfg, dict):
            return {key: replace_value(value) for key, value in cfg.items()}
        elif isinstance(cfg, list):
            return [replace_value(item) for item in cfg]
        elif isinstance(cfg, tuple):
            return tuple([replace_value(item) for item in cfg])
        elif isinstance(cfg, str):
            # the format of string cfg may be:
            # 1) "${key}", which will be replaced with cfg.key directly
            # 2) "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx",
            # which will be replaced with the string of the cfg.key
            keys = pattern_key.findall(cfg)
            # key[2:-1] strips the leading "${" and the trailing "}"
            values = [get_value(ori_cfg, key[2:-1]) for key in keys]
            if len(keys) == 1 and keys[0] == cfg:
                # the format of string cfg is "${key}":
                # keep the referenced value's own type
                cfg = values[0]
            else:
                for key, value in zip(keys, values):
                    # the format of string cfg is
                    # "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx"
                    if isinstance(value, (dict, list, tuple)):
                        # Raised explicitly (not via ``assert``) so the
                        # check is not stripped under ``python -O``; the
                        # exception type stays AssertionError for
                        # backward compatibility with existing callers.
                        raise AssertionError(
                            "for the format of string cfg is "
                            "'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', "
                            f"the type of the value of '{key}' "
                            "can not be dict, list, or tuple, "
                            f"but you input {type(value)} in {cfg}")
                    cfg = cfg.replace(key, str(value))
            return cfg
        else:
            return cfg

    # the type of ori_cfg._cfg_dict is mmengine.config.ConfigDict
    updated_cfg = Config(
        replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename)
    # replace the model with model_wrapper
    if updated_cfg.get('model_wrapper', None) is not None:
        updated_cfg.model = updated_cfg.model_wrapper
        updated_cfg.pop('model_wrapper')
    return updated_cfg