import dataclasses
import logging
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple

from llm_studio.src import possible_values
from llm_studio.src.nesting import Dependency, Nesting
from llm_studio.src.order import Order
from llm_studio.src.tooltips import tooltips

logger = logging.getLogger(__name__)

def _get_bases_below_parent(cls: type, parent: type, bases=None) -> Set[type]:
    """Recursively collects the classes in the hierarchy of `cls` that inherit directly from `parent`."""
    if bases is None:
        bases = set()
    if parent not in cls.__bases__:
        for base in cls.__bases__:
            bases.update(_get_bases_below_parent(base, parent, bases))
    else:
        # don't support multiple inheritance when
        # inheriting directly from the parent
        assert len(cls.__bases__) == 1
        bases.add(cls)
    return bases
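
# Illustrative sketch (hypothetical class names, not part of this module): given
#   class ConfigProblemBase(DefaultConfig): ...
#   class ConfigCustomProblem(ConfigProblemBase): ...
# `_get_bases_below_parent(ConfigCustomProblem, DefaultConfig)` returns
# `{ConfigProblemBase}` -- the single class sitting one level below `DefaultConfig`.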

@dataclass
class DefaultConfig:
    """
    Template for any configuration file
    """

    def __post_init__(self):
        self._possible_values: Dict[str, Any] = {k: None for k in self.__dict__}
        self._visibility = {k: 0 for k in self.__dict__}

        # go up the class hierarchy until we are one below the `DefaultConfig`
        bases = _get_bases_below_parent(self.__class__, DefaultConfig)

        # there must be exactly one unique class up the class hierarchy
        # which inherits directly from the `DefaultConfig`
        assert len(bases) == 1
        base = next(iter(bases))

        # initialize the order to the fields this class has
        self._order = Order([field.name for field in fields(base)])

        # initialize nesting dependencies
        self._nesting = Nesting()

    def _get_possible_values(
        self, field: str, value: Any, type_annotation: type, mode: str, dataset_fn=None
    ) -> Optional[Tuple[Optional[possible_values.Value], Any]]:
        """
        Returns a set of possible values for the field provided, and the current value.

        Args:
            field: the field
            value: the preliminary value of the field.
            type_annotation: type annotation of the field.
            mode: current mode, one of {"train", "test", "predict"}.
            dataset_fn: A function returning a tuple (dataset, value). Will be called
                if the possible values depend on the dataset.

        Returns:
            Possible values for the field, the current value.
        """
        poss_values = self._possible_values.get(field, None)

        if isinstance(poss_values, possible_values.DatasetValue):
            if dataset_fn is None:
                raise ValueError(
                    f"{poss_values} needs a dataset to compute possible values!\n"
                    "`dataset_fn` must be provided."
                )
            dataset, value = dataset_fn(field, value)
            poss_values, value = poss_values.get_value(
                dataset=dataset, value=value, type_annotation=type_annotation, mode=mode
            )
        elif isinstance(poss_values, Sequence):
            if all(isinstance(x, (float, int)) for x in poss_values):
                poss_values = possible_values.Number(
                    min=poss_values[0], max=poss_values[1], step=poss_values[2]
                )
            elif all(isinstance(x, str) for x in poss_values):
                poss_values = possible_values.String(tuple(poss_values))
            else:
                raise ValueError(
                    f"Could not interpret {poss_values} as any possible value class."
                )

        return poss_values, value
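
    # Illustrative sketch (field names assumed, not defined in this base class): a
    # config subclass typically registers possible values in its own __post_init__,
    # e.g. `self._possible_values["learning_rate"] = (1e-6, 1e-1, 1e-6)`. A numeric
    # 3-tuple like that is interpreted above as (min, max, step) and wrapped into
    # `possible_values.Number`, while a sequence of strings such as
    # `("float16", "bfloat16")` becomes `possible_values.String`.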

    def _get_tooltips(self, field: str, predict: bool = False) -> Optional[str]:
        """
        Returns a tooltip for the field provided
        """
        return tooltips.get(f"experiments_{field}", None)

    def _get_visibility(self, field: str) -> Optional[int]:
        """Returns a visibility level for the field provided.

        0 -- visible in the Wave app
        -1 -- not visible in the Wave app
        -2 -- visible in Dataset Import, but not visible in Create Experiment
        """
        return self._visibility.get(field, None)

    def _get_nesting_triggers(self) -> Set[str]:
        """Returns the set of keys that other elements depend on"""
        return self._nesting.triggers

    def _get_nesting_dependencies(self, key: str) -> List[Dependency] | None:
        """Returns all dependencies for a given key"""
        if key in self._nesting.dependencies:
            dependencies = self._nesting.dependencies[key]
        else:
            dependencies = None
        return dependencies
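
    # Note (interpretation based on this class alone): `self._nesting.dependencies`
    # maps a config key to the list of `Dependency` objects that gate it, and
    # `self._nesting.triggers` is the set of keys those dependencies point at;
    # subclasses are expected to populate both in their own __post_init__.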

    def _get_order(self, warn_if_unset=True) -> List[str]:
        """
        Returns the order in which to show the keys in the config.

        Args:
            warn_if_unset: Whether to log a warning if order is unset for multiple keys.

        Returns:
            A list of the same length and with the same elements as `self.__dict__.keys()`.
        """
        keys = self.__dict__.keys()
        ordered_keys = [key for key in self._order if key in keys]
        unordered_keys = list(set(keys) - set(ordered_keys))
        unordered_ui_keys = [
            key
            for key in unordered_keys
            if not (key.startswith("_") or self._get_visibility(key) == -1)
        ]

        # warn if there is more than one key without order.
        # one is not problematic since it will just always be last
        if warn_if_unset and len(unordered_ui_keys) > 1:
            logger.warning(f"No order set for keys: {unordered_ui_keys}.")

        return ordered_keys + unordered_keys

    @classmethod
    def get_annotations(cls):
        """Returns type annotations through all the parent config classes"""
        d: Dict[str, Any] = {}
        for c in cls.mro()[::-1]:
            try:
                d.update(**c.__annotations__)
            except AttributeError:
                # object, at least, has no __annotations__ attribute.
                pass
        return d

    @classmethod
    def from_dict(cls, d: dict):
        """Creates a config object from a dictionary"""
        d_filtered = {k: v for k, v in d.items() if k in cls.get_annotations()}
        if len(d) != len(d_filtered):
            logger.warning(
                f"Keys {set(d.keys()) - set(d_filtered.keys())} are not in the config."
            )
        return cls(**d_filtered)  # mypy: ignore
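
# Illustrative sketch (hypothetical subclass and keys, not part of this module):
# `DefaultConfig.from_dict` drops keys that are not annotated on the config, so
# `SomeConfig.from_dict({"learning_rate": 1e-4, "typo_key": 1})` would log a
# warning about `typo_key` and build `SomeConfig(learning_rate=1e-4)`.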

@dataclass
class DefaultConfigProblemBase(DefaultConfig):
    """
    Base class for all problem configs.
    Defines the interface for all problem configs.
    """

    experiment_name: str
    output_directory: str
    llm_backbone: str
    dataset: Any
    tokenizer: Any
    architecture: Any
    training: Any
    augmentation: Any
    prediction: Any
    environment: Any
    logging: Any

    @property
    def problem_type(self) -> str:
        """
        Parse problem_type from the config filename,
        for example: text_causal_language_modeling_config.py -> text_causal_language_modeling
        """
        return type(self).__dict__["__module__"].split(".")[-1].replace("_config", "")

    @classmethod
    def from_dict(cls, cfg_dict: dict):
        """Creates a problem config object from a dictionary, recursing into nested configs"""
        class_fields = {f.name: f for f in dataclasses.fields(cls)}

        # Prepare arguments for creating a new dataclass instance
        init_args = {}
        for field_name, field_obj in class_fields.items():
            if hasattr(field_obj.type, "from_dict"):
                attr_value = cfg_dict.get(field_name, {})
                init_args[field_name] = field_obj.type.from_dict(attr_value)
            else:
                # Use the value from cfg_dict,
                # or the field's default value if not available in cfg_dict
                init_args[field_name] = cfg_dict.get(field_name, field_obj.default)

        return cls(**init_args)
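
    # Illustrative sketch (hypothetical nested dict and field names): for a problem
    # config whose `training` field is annotated with a `DefaultConfig` subclass,
    # `SomeProblemConfig.from_dict({"experiment_name": "demo", "training": {"epochs": 1}})`
    # recurses into the `training` sub-config via its own `from_dict`, while plain
    # fields fall back to the `cfg_dict` value or the dataclass default.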

    def check(self) -> Dict[str, List]:
        """
        Checks for errors (incompatible settings) for the specific problem type.

        Returns:
            A dictionary with two keys:
            - "title": A list of error titles.
            - "message": A list of error messages.
        """
        errors: Dict[str, List] = {"title": [], "message": []}
        return errors
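
# Minimal self-contained usage sketch (illustrative only; `DemoConfig` and its
# fields are hypothetical, not part of LLM Studio). It shows how a `DefaultConfig`
# subclass picks up the ordering bookkeeping from `__post_init__` and how
# `from_dict` filters unknown keys.
if __name__ == "__main__":

    @dataclass
    class DemoConfig(DefaultConfig):
        learning_rate: float = 1e-4
        batch_size: int = 16

    cfg = DemoConfig.from_dict({"learning_rate": 3e-4, "unknown_key": 1})
    print(cfg.learning_rate)  # 0.0003 -- taken from the dict
    print(cfg.batch_size)  # 16 -- dataclass default; "unknown_key" was dropped
    print(cfg._get_order()[:2])  # ['learning_rate', 'batch_size']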