Spaces:
Sleeping
Sleeping
File size: 5,304 Bytes
dffe47c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/utils/config.py
import os
from dataclasses import dataclass
from enum import Enum
from itertools import product
from typing import Any, Iterable, Mapping, Optional, Union
import torch
from ml_collections import ConfigDict
from numpy import ndarray
# Type alias: either a NumPy ``ndarray`` or a ``torch.Tensor``.
Array = Union[ndarray, torch.Tensor]
class TestType(Enum):
    """Enumeration of test kinds, each identified by a lowercase string value.

    NOTE(review): presumably a global test, a global conditional test and a
    local conditional test respectively — confirm against the callers.
    """
    GLOBAL = "global"
    GLOBAL_COND = "global_cond"
    LOCAL_COND = "local_cond"
class ConceptType(Enum):
    """Enumeration of concept scopes, each identified by a lowercase string value.

    NOTE(review): presumably the level (whole dataset, a single class, or a
    single image) at which a concept set is defined — confirm with callers.
    """
    DATASET = "dataset"
    CLASS = "class"
    IMAGE = "image"
@dataclass
class Constants:
    """Module-level constants computed once, when the class body executes.

    ``WORKDIR`` is the directory two levels above the resolved (symlinks
    followed) path of this file. ``DEVICE`` is the CUDA device when
    ``torch.cuda.is_available()`` reports one, otherwise the CPU.
    """

    # os.path.split(p)[0] is exactly os.path.dirname(p); applied twice.
    WORKDIR = os.path.split(os.path.split(os.path.realpath(__file__))[0])[0]
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class DataConfig(ConfigDict):
    """Settings read from the ``data`` section of a config mapping.

    Every field is looked up in ``config_dict`` under its own name and is
    ``None`` when the key is absent or ``config_dict`` is ``None``.
    Expected types: ``num_concepts`` is an int; the other fields are str.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        source = {} if config_dict is None else config_dict
        # Fields are assigned in this order.
        for field_name in (
            "dataset",
            "backbone",
            "bottleneck",
            "classifier",
            "sampler",
            "num_concepts",
        ):
            setattr(self, field_name, source.get(field_name))
class SpliceConfig(ConfigDict):
    """Settings read from the ``splice`` section of a config mapping.

    Each field is ``None`` when its key is absent or ``config_dict`` is
    ``None``.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        cfg = config_dict if config_dict is not None else {}
        vocab: str = cfg.get("vocab")
        vocab_size: int = cfg.get("vocab_size")
        l1_penalty: float = cfg.get("l1_penalty")
        self.vocab = vocab
        self.vocab_size = vocab_size
        self.l1_penalty = l1_penalty
class PCBMConfig(ConfigDict):
    """Settings read from the ``pcbm`` section of a config mapping.

    Expected types: ``alpha`` (float) and ``l1_ratio`` (float); each is
    ``None`` when its key is absent or ``config_dict`` is ``None``.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        source = {} if config_dict is None else config_dict
        for key, fallback in {"alpha": None, "l1_ratio": None}.items():
            setattr(self, key, source.get(key, fallback))
class cKDEConfig(ConfigDict):
    """Settings read from the ``ckde`` section of a config mapping.

    Expected types: ``metric`` (str), ``scale_method`` (str) and ``scale``
    (float); each is ``None`` when its key is absent or ``config_dict`` is
    ``None``.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        # An empty dict's .get always yields None, matching a missing mapping.
        lookup = {}.get if config_dict is None else config_dict.get
        self.metric = lookup("metric")
        self.scale_method = lookup("scale_method")
        self.scale = lookup("scale")
class TestingConfig(ConfigDict):
    """Settings read from the ``testing`` section of a config mapping.

    Each field is looked up under its own name and is ``None`` when the key
    is absent or ``config_dict`` is ``None``. Expected types:
    ``significance_level`` (float), ``wealth`` / ``bet`` / ``kernel`` /
    ``kernel_scale_method`` (str), ``kernel_scale`` (float), ``tau_max`` /
    ``images_per_class`` / ``r`` (int), ``cardinality`` (iterable of int).
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        source = {} if config_dict is None else config_dict
        field_names = (
            "significance_level",
            "wealth",
            "bet",
            "kernel",
            "kernel_scale_method",
            "kernel_scale",
            "tau_max",
            "images_per_class",
            "r",
            "cardinality",
        )
        for field_name in field_names:
            setattr(self, field_name, source.get(field_name))
class Config(ConfigDict):
    """Top-level configuration built from an optional nested mapping.

    ``name`` is read directly. The ``data``, ``splice``, ``pcbm``, ``ckde``
    and ``testing`` entries are each passed to their sub-config constructor.
    Missing entries default to ``None``.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        if config_dict is None:
            config_dict = {}
        self.name: str = config_dict.get("name", None)
        self.data = DataConfig(config_dict.get("data", None))
        self.splice = SpliceConfig(config_dict.get("splice", None))
        self.pcbm = PCBMConfig(config_dict.get("pcbm", None))
        self.ckde = cKDEConfig(config_dict.get("ckde", None))
        self.testing = TestingConfig(config_dict.get("testing", None))

    def backbone_name(self):
        """Return ``data.backbone`` normalized for use in identifiers.

        The name is stripped and lower-cased, and ``"/"`` and ``":"`` are
        replaced with ``"_"``.

        Raises:
            AttributeError: if ``data.backbone`` is ``None``.
        """
        backbone = self.data.backbone.strip().lower()
        return backbone.replace("/", "_").replace(":", "_")

    def sweep(self, keys: Iterable[str]):
        """Expand the given entries into one ``Config`` per value combination.

        Args:
            keys: dotted paths (e.g. ``"testing.kernel"``) into
                ``self.to_dict()``. A list value at a path is swept over;
                any other value is treated as a one-element sweep.

        Returns:
            A list of ``Config`` objects, one per element of the Cartesian
            product of the swept values, in ``itertools.product`` order.

        Raises:
            KeyError: if a path is missing from ``self.to_dict()``.
        """
        import copy

        def _get(mapping, dotted_key):
            # Walk one nesting level per "."-separated component.
            for part in dotted_key.split("."):
                mapping = mapping[part]
            return mapping

        def _set(mapping, dotted_key, value):
            *parents, leaf = dotted_key.split(".")
            for part in parents:
                mapping = mapping[part]
            mapping[leaf] = value

        def _to_iterable(value):
            return value if isinstance(value, list) else [value]

        # ``keys`` is read twice below (value lookup and zip). Materialize it so
        # a one-shot iterator is not silently exhausted on the second pass.
        keys = list(keys)
        config_dict = self.to_dict()
        sweep_values = [_to_iterable(_get(config_dict, key)) for key in keys]

        configs = []
        for point in product(*sweep_values):
            # Deep copy so nested dicts (and list values) are never shared
            # between the base dict and the generated configs.
            point_dict = copy.deepcopy(config_dict)
            for key, value in zip(keys, point):
                _set(point_dict, key, value)
            configs.append(Config(point_dict))
        return configs
def register_config(name: str):
    """Return a class decorator that registers a config class under *name*.

    The decorated class is stored in the module-level ``configs`` registry,
    from which ``get_config(name)`` instantiates it. The class is returned
    unchanged.

    Raises:
        ValueError: (from the decorator) if *name* is already registered.
    """

    def register(cls: Config):
        if name in configs:
            raise ValueError(f"Config {name} is already registered")
        configs[name] = cls
        # A class decorator must return the class; otherwise the decorated
        # name is rebound to None at the definition site.
        return cls

    return register
def get_config(name: str) -> Config:
    """Instantiate, with no arguments, the config class registered as *name*.

    Raises:
        KeyError: if nothing is registered under *name*.
    """
    config_cls = configs[name]
    instance = config_cls()
    return instance
# Registry filled by ``register_config`` and read by ``get_config``.
# NOTE: the stored values are config classes (``get_config`` calls them),
# not ``Config`` instances as the annotation might suggest.
configs: Mapping[str, Config] = {}
|