|
import dataclasses |
|
import logging |
|
from dataclasses import dataclass, fields |
|
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple |
|
|
|
from llm_studio.src import possible_values |
|
from llm_studio.src.nesting import Dependency, Nesting |
|
from llm_studio.src.order import Order |
|
from llm_studio.src.tooltips import tooltips |
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
def _get_bases_below_parent(cls: type, parent: type, bases=None) -> Set[type]:
    """
    Collect the classes in the ancestry of ``cls`` that inherit directly from ``parent``.

    The search follows ``__bases__`` recursively. When a class lists ``parent``
    among its direct bases, that class is recorded and the search does not
    descend further along that branch.

    Args:
        cls: class whose ancestry is searched (``cls`` itself is recorded if it
            inherits directly from ``parent``).
        parent: the class to search below.
        bases: accumulator set, mutated in place and returned. A new set is
            created when ``None``.

    Returns:
        The accumulator holding every class found directly below ``parent``
        (empty if ``parent`` is not an ancestor of ``cls``).

    Raises:
        AssertionError: if a class inheriting directly from ``parent`` has more
            than one direct base (only raised when assertions are enabled).
    """
    found = set() if bases is None else bases

    direct_bases = cls.__bases__
    if parent in direct_bases:
        # Multiple inheritance is not supported for classes that inherit
        # directly from `parent`.
        assert len(direct_bases) == 1
        found.add(cls)
        return found

    for ancestor in direct_bases:
        # The recursive call fills `found` in place; its return value is `found`.
        _get_bases_below_parent(ancestor, parent, found)
    return found
|
|
|
|
|
@dataclass
class DefaultConfig:
    """
    Template for any configuration file
    """

    def __post_init__(self):
        """
        Initialise per-field UI metadata after the dataclass ``__init__`` ran.

        Sets ``_possible_values`` (every attribute -> ``None``), ``_visibility``
        (every attribute -> ``0``), ``_order`` (seeded with the fields of the
        class directly below ``DefaultConfig``) and ``_nesting``.
        """
        self._possible_values: Dict[str, Any] = {k: None for k in self.__dict__}
        # Built after `_possible_values` was assigned, so it also carries a
        # `_possible_values` entry.
        self._visibility = {k: 0 for k in self.__dict__}

        # Go up the class hierarchy to the class(es) inheriting directly from
        # `DefaultConfig`.
        bases = _get_bases_below_parent(self.__class__, DefaultConfig)

        # Exactly one such class is supported; its fields determine the
        # default order.
        assert len(bases) == 1
        base = next(iter(bases))

        # Initialise the order with the field names of that class.
        self._order = Order([field.name for field in fields(base)])

        # Initialise (empty) nesting dependencies.
        self._nesting = Nesting()

    def _get_possible_values(
        self, field: str, value: Any, type_annotation: type, mode: str, dataset_fn=None
    ) -> Optional[Tuple[Optional[possible_values.Value], Any]]:
        """
        Returns a set of possible values for the field provided, and the current value.

        Raw sequences stored in ``_possible_values`` are normalised: an all-numeric
        sequence is read as ``(min, max, step)`` and becomes
        ``possible_values.Number``; an all-string sequence becomes
        ``possible_values.String``.

        Args:
            field: the field
            value: the preliminary value of the field.
            type_annotation: Type Annotation of the field.
            mode: current mode, one of {"train", "test", "predict"}.
            dataset_fn: A function returning a tuple (dataset, value). Will be called
                if the possible values depend on the dataset.

        Returns:
            Possible values for the field, the current value.

        Raises:
            ValueError: if the possible values need a dataset but ``dataset_fn``
                is ``None``, or if a sequence mixes element types.
        """

        poss_values = self._possible_values.get(field, None)

        if isinstance(poss_values, possible_values.DatasetValue):
            if dataset_fn is None:
                raise ValueError(
                    f"{poss_values} needs a dataset to compute possible values!\n"
                    "`dataset_fn` must be provided."
                )

            dataset, value = dataset_fn(field, value)
            poss_values, value = poss_values.get_value(
                dataset=dataset, value=value, type_annotation=type_annotation, mode=mode
            )
        elif isinstance(poss_values, Sequence):
            # NOTE(review): a plain `str` is also a Sequence and would be split
            # into single characters by the String branch below.
            if all(isinstance(x, (float, int)) for x in poss_values):
                poss_values = possible_values.Number(
                    min=poss_values[0], max=poss_values[1], step=poss_values[2]
                )
            elif all(isinstance(x, str) for x in poss_values):
                poss_values = possible_values.String(tuple(poss_values))
            else:
                raise ValueError(
                    f"Could not interpret {poss_values} as any possible value class."
                )

        return poss_values, value

    def _get_tooltips(self, field: str, predict: bool = False) -> Optional[str]:
        """
        Returns the tooltip registered under ``experiments_<field>``, or ``None``.

        ``predict`` is accepted for interface compatibility but unused here.
        """
        return tooltips.get(f"experiments_{field}", None)

    def _get_visibility(self, field: str) -> Optional[int]:
        """Returns a visibility level for the field provided.
        0 -- visible in the Wave app
        -1 -- not visible in the Wave App
        -2 -- visible in Dataset Import, but not visible in Create Experiment
        Returns ``None`` for fields without a visibility entry.
        """

        return self._visibility.get(field, None)

    def _get_nesting_triggers(self) -> Set[str]:
        """Returns a Set of keys other elements are depending on"""

        return self._nesting.triggers

    def _get_nesting_dependencies(self, key: str) -> Optional[List[Dependency]]:
        """Returns all dependencies for a given key, or ``None`` if it has none"""

        if key in self._nesting.dependencies:
            dependencies = self._nesting.dependencies[key]
        else:
            dependencies = None
        return dependencies

    def _get_order(self, warn_if_unset=True) -> List[str]:
        """
        Returns the order in which to show the keys in the config.

        Keys listed in ``self._order`` come first, in that order. All remaining
        keys follow in attribute-insertion order, so the result is deterministic
        across processes.

        Args:
            warn_if_unset: Whether to log a warning if order is unset for multiple keys.

        Returns:
            A list of the same length and with same elements as `self.__dict__.keys()`.
        """

        keys = self.__dict__.keys()

        ordered_keys = [key for key in self._order if key in keys]
        # Preserve insertion order for the rest. A set difference would make the
        # order depend on string hash randomization (PYTHONHASHSEED), i.e. vary
        # from one run to the next.
        ordered_set = set(ordered_keys)
        unordered_keys = [key for key in keys if key not in ordered_set]

        unordered_ui_keys = [
            key
            for key in unordered_keys
            if not (key.startswith("_") or self._get_visibility(key) == -1)
        ]

        # A single key without an order is tolerated: it simply ends up after
        # all ordered keys in the returned list.
        if warn_if_unset and len(unordered_ui_keys) > 1:
            logger.warning(f"No order set for keys: {unordered_ui_keys}.")

        return ordered_keys + unordered_keys

    @classmethod
    def get_annotations(cls):
        """
        Returns type annotations through all the Parent config classes.

        The MRO is walked from the most generic class to ``cls``, so annotations
        declared in subclasses override those of their parents.
        """

        annotations: Dict[str, Any] = {}
        for klass in reversed(cls.mro()):
            # Classes without annotations (e.g. `object`) contribute nothing.
            annotations.update(getattr(klass, "__annotations__", {}))
        return annotations

    @classmethod
    def from_dict(cls, d: dict):
        """
        Creates a config object from a dictionary.

        Keys that are not annotated anywhere in the class hierarchy are dropped
        and reported with a warning.
        """
        # Computed once; it walks the whole MRO.
        annotations = cls.get_annotations()
        d_filtered = {k: v for k, v in d.items() if k in annotations}
        if len(d) != len(d_filtered):
            logger.warning(
                f"Keys {set(d.keys()) - set(d_filtered.keys())} are not in the config."
            )
        return cls(**d_filtered)
|
|
|
|
|
@dataclass
class DefaultConfigProblemBase(DefaultConfig):
    """
    Base class for all problem configs.
    Defines the interface for all problem configs.
    """

    experiment_name: str
    output_directory: str
    llm_backbone: str

    dataset: Any
    tokenizer: Any
    architecture: Any
    training: Any
    augmentation: Any
    prediction: Any
    environment: Any
    logging: Any

    @property
    def problem_type(self) -> str:
        """
        Parse problem_type from the name of the module that defines the config
        class, removing every "_config" substring,
        for example: text_causal_language_modeling_config -> text_causal_language_modeling
        """
        return type(self).__dict__["__module__"].split(".")[-1].replace("_config", "")

    @classmethod
    def from_dict(cls, cfg_dict: dict):
        """
        Creates a problem config from a (possibly nested) dictionary.

        Fields whose type provides ``from_dict`` are built from the matching
        sub-dictionary (an empty dict when the key is absent). Every other field
        is taken from ``cfg_dict`` when present. Otherwise it is left out, so the
        dataclass ``__init__`` applies its declared ``default`` or
        ``default_factory``. Fields declared with ``init=False`` are skipped.

        Args:
            cfg_dict: mapping from field name to value / nested dictionary.

        Returns:
            A new instance of ``cls``.

        Raises:
            TypeError: if a field without any default is missing from ``cfg_dict``.
        """
        init_args: Dict[str, Any] = {}
        for field_obj in dataclasses.fields(cls):
            # Fields excluded from __init__ cannot be passed as keyword arguments.
            if not field_obj.init:
                continue

            field_name = field_obj.name
            if hasattr(field_obj.type, "from_dict"):
                attr_value = cfg_dict.get(field_name, {})
                init_args[field_name] = field_obj.type.from_dict(attr_value)
            elif field_name in cfg_dict:
                init_args[field_name] = cfg_dict[field_name]
            # Absent keys are deliberately omitted. Falling back to
            # `field_obj.default` would pass the `dataclasses.MISSING` sentinel
            # for fields using `default_factory` or having no default.

        return cls(**init_args)

    def check(self) -> Dict[str, List]:
        """
        Checks for errors (incompatible settings) for the specific problem type.

        This base implementation reports no errors.

        Returns:
            A dictionary with two keys:
            - "title": A list of error titles.
            - "message": A list of error messages.
        """
        errors: Dict[str, List] = {"title": [], "message": []}
        return errors
|
|