# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/utils/config.py import os from dataclasses import dataclass from enum import Enum from itertools import product from typing import Any, Iterable, Mapping, Optional, Union import torch from ml_collections import ConfigDict from numpy import ndarray Array = Union[ndarray, torch.Tensor] class TestType(Enum): GLOBAL = "global" GLOBAL_COND = "global_cond" LOCAL_COND = "local_cond" class ConceptType(Enum): DATASET = "dataset" CLASS = "class" IMAGE = "image" @dataclass class Constants: WORKDIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") class DataConfig(ConfigDict): def __init__(self, config_dict: Optional[Mapping[str, Any]] = None): super().__init__() if config_dict is None: config_dict = {} self.dataset: str = config_dict.get("dataset", None) self.backbone: str = config_dict.get("backbone", None) self.bottleneck: str = config_dict.get("bottleneck", None) self.classifier: str = config_dict.get("classifier", None) self.sampler: str = config_dict.get("sampler", None) self.num_concepts: int = config_dict.get("num_concepts", None) class SpliceConfig(ConfigDict): def __init__(self, config_dict: Optional[Mapping[str, Any]] = None): super().__init__() if config_dict is None: config_dict = {} self.vocab: str = config_dict.get("vocab", None) self.vocab_size: int = config_dict.get("vocab_size", None) self.l1_penalty: float = config_dict.get("l1_penalty", None) class PCBMConfig(ConfigDict): def __init__(self, config_dict: Optional[Mapping[str, Any]] = None): super().__init__() if config_dict is None: config_dict = {} self.alpha: float = config_dict.get("alpha", None) self.l1_ratio: float = config_dict.get("l1_ratio", None) class cKDEConfig(ConfigDict): def __init__(self, config_dict: Optional[Mapping[str, Any]] = None): super().__init__() if config_dict is None: config_dict = {} self.metric: str = config_dict.get("metric", None) self.scale_method: str = config_dict.get("scale_method", None) self.scale: float = config_dict.get("scale", None) class TestingConfig(ConfigDict): def __init__(self, config_dict: Optional[Mapping[str, Any]] = None): super().__init__() if config_dict is None: config_dict = {} self.significance_level: float = config_dict.get("significance_level", None) self.wealth: str = config_dict.get("wealth", None) self.bet: str = config_dict.get("bet", None) self.kernel: str = config_dict.get("kernel", None) self.kernel_scale_method: str = config_dict.get("kernel_scale_method", None) self.kernel_scale: float = config_dict.get("kernel_scale", None) self.tau_max: int = config_dict.get("tau_max", None) self.images_per_class: int = config_dict.get("images_per_class", None) self.r: int = config_dict.get("r", None) self.cardinality: Iterable[int] = config_dict.get("cardinality", None) class Config(ConfigDict): def __init__(self, config_dict: Optional[Mapping[str, Any]] = None): super().__init__() if config_dict is None: config_dict = {} self.name: str = config_dict.get("name", None) self.data = DataConfig(config_dict.get("data", None)) self.splice = SpliceConfig(config_dict.get("splice", None)) self.pcbm = PCBMConfig(config_dict.get("pcbm", None)) self.ckde = cKDEConfig(config_dict.get("ckde", None)) self.testing = TestingConfig(config_dict.get("testing", None)) def backbone_name(self): backbone = self.data.backbone.strip().lower() return backbone.replace("/", "_").replace(":", "_") def sweep(self, keys: Iterable[str]): def _get(dict, key): keys = key.split(".") if len(keys) == 1: return dict[keys[0]] else: return _get(dict[keys[0]], ".".join(keys[1:])) def _set(dict, key, value): keys = key.split(".") if len(keys) == 1: dict[keys[0]] = value else: _set(dict[keys[0]], ".".join(keys[1:]), value) to_iterable = lambda v: v if isinstance(v, list) else [v] config_dict = self.to_dict() sweep_values = [_get(config_dict, key) for key in keys] sweep = list(product(*map(to_iterable, sweep_values))) configs: Iterable[Config] = [] for _sweep in sweep: _config_dict = config_dict.copy() for key, value in zip(keys, _sweep): _set(_config_dict, key, value) configs.append(Config(_config_dict)) return configs def register_config(name: str): def register(cls: Config): if name in configs: raise ValueError(f"Config {name} is already registered") configs[name] = cls return register def get_config(name: str) -> Config: return configs[name]() configs: Mapping[str, Config] = {}