elineve's picture
Upload 301 files
07423df
import dataclasses
import logging
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
from llm_studio.src import possible_values
from llm_studio.src.nesting import Dependency, Nesting
from llm_studio.src.order import Order
from llm_studio.src.tooltips import tooltips
logger = logging.getLogger(__name__)
def _get_bases_below_parent(cls: type, parent: type, bases=None) -> Set[type]:
if bases is None:
bases = set()
if parent not in cls.__bases__:
for base in cls.__bases__:
bases.update(_get_bases_below_parent(base, parent, bases))
else:
# don't support multiple inheritance when
# inherting directly from the parent
assert len(cls.__bases__) == 1
bases.add(cls)
return bases
@dataclass
class DefaultConfig:
"""
Template for any configuration file
"""
def __post_init__(self):
self._possible_values: Dict[str, Any] = {k: None for k in self.__dict__}
self._visibility = {k: 0 for k in self.__dict__}
# go up the class hierarchy until we are one below the `DefaultConfig`
bases = _get_bases_below_parent(self.__class__, DefaultConfig)
# there must be exactly one unique class up the class hierarchy
# which inherits directly from the `DefaultConfig`
assert len(bases) == 1
base = next(iter(bases))
# initialize the order to the fields this class has
self._order = Order([field.name for field in fields(base)])
# initialize nesting dependencies
self._nesting = Nesting()
def _get_possible_values(
self, field: str, value: Any, type_annotation: type, mode: str, dataset_fn=None
) -> Optional[Tuple[Optional[possible_values.Value], Any]]:
"""
Returns a set of possible values for the field provided, and the current value.
Args:
field: the field
value: the preliminary value of the field.
type_annotation: Type Annotation of the field.
mode: current mode, one of {"train", "test", "predict"}.
dataset_fn: A function returning a tuple (dataset, value). Will be called
if the possible values depend on the dataset.
Returns:
Possible values for the field, the current value.
"""
poss_values = self._possible_values.get(field, None)
if isinstance(poss_values, possible_values.DatasetValue):
if dataset_fn is None:
raise ValueError(
f"{poss_values} needs a dataset to compute possible values!\n"
"`dataset_fn` must be provided."
)
dataset, value = dataset_fn(field, value)
poss_values, value = poss_values.get_value(
dataset=dataset, value=value, type_annotation=type_annotation, mode=mode
)
elif isinstance(poss_values, Sequence):
if all(isinstance(x, (float, int)) for x in poss_values):
poss_values = possible_values.Number(
min=poss_values[0], max=poss_values[1], step=poss_values[2]
)
elif all(isinstance(x, str) for x in poss_values):
poss_values = possible_values.String(tuple(poss_values))
else:
raise ValueError(
f"Could not interpret {poss_values} as any possible value class."
)
return poss_values, value
def _get_tooltips(self, field: str, predict: bool = False) -> Optional[str]:
"""
Returns a tooltip for the field provided
"""
return tooltips.get(f"experiments_{field}", None)
def _get_visibility(self, field: str) -> Optional[int]:
"""Returns a visibility level for the field provided.
0 -- visible in the Wave app
-1 -- not visible in the Wave App
-2 -- visible in Dataset Import, but not visible in Create Experiment
"""
return self._visibility.get(field, None)
def _get_nesting_triggers(self) -> Set[str]:
"""Returns a Set of keys other elements are depending on"""
return self._nesting.triggers
def _get_nesting_dependencies(self, key: str) -> List[Dependency] | None:
"""Returns a all dependencies for a given key"""
if key in self._nesting.dependencies:
dependencies = self._nesting.dependencies[key]
else:
dependencies = None
return dependencies
def _get_order(self, warn_if_unset=True) -> List[str]:
"""
Returns the order in which to show the keys in the config.
Args:
warn_if_unset: Whether to log a warning if order is unset for multiple keys.
Returns:
A list of the same length and with same elements as `self.__dict__.keys()`.
"""
keys = self.__dict__.keys()
ordered_keys = [key for key in self._order if key in keys]
unordered_keys = list(set(keys) - set(ordered_keys))
unordered_ui_keys = [
key
for key in unordered_keys
if not (key.startswith("_") or self._get_visibility(key) == -1)
]
# warn if there is more than one key without order.
# one is not problematic since it will just always be last
if warn_if_unset and len(unordered_ui_keys) > 1:
logger.warning(f"No order set for keys: {unordered_ui_keys}.")
return ordered_keys + unordered_keys
@classmethod
def get_annotations(cls):
"""Returns type annotations through all the Parent config classes"""
d: Dict[str, Any] = {}
for c in cls.mro()[::-1]:
try:
d.update(**c.__annotations__)
except AttributeError:
# object, at least, has no __annotations__ attribute.
pass
return d
@classmethod
def from_dict(cls, d: dict):
"""Creates a config object from a dictionary"""
d_filtered = {k: v for k, v in d.items() if k in cls.get_annotations()}
if len(d) != len(d_filtered):
logger.warning(
f"Keys {set(d.keys()) - set(d_filtered.keys())} are not in the config."
)
return cls(**d_filtered) # mypy: ignore
@dataclass
class DefaultConfigProblemBase(DefaultConfig):
"""
Base class for all problem configs.
Defines the interface for all problem configs.
"""
experiment_name: str
output_directory: str
llm_backbone: str
dataset: Any
tokenizer: Any
architecture: Any
training: Any
augmentation: Any
prediction: Any
environment: Any
logging: Any
@property
def problem_type(self) -> str:
"""
Parse problem_type from config filename,
for example: text_causal_language_modeling_config.py -> causal_language_modeling
"""
return type(self).__dict__["__module__"].split(".")[-1].replace("_config", "")
@classmethod
def from_dict(cls, cfg_dict: dict):
class_fields = {f.name: f for f in dataclasses.fields(cls)}
# Prepare arguments for creating a new dataclass instance
init_args = {}
for field_name, field_obj in class_fields.items():
if hasattr(field_obj.type, "from_dict"):
attr_value = cfg_dict.get(field_name, {})
init_args[field_name] = field_obj.type.from_dict(attr_value)
else:
# Use the value from cfg_dict,
# or the field's default value if not available in cfg_dict
init_args[field_name] = cfg_dict.get(field_name, field_obj.default)
return cls(**init_args)
def check(self) -> Dict[str, List]:
"""
Checks for errors (incompatible settings) for the specific problem type.
Returns:
A dictionary with two keys:
- "title": A list of error titles.
- "message": A list of error messages.
"""
errors: Dict[str, List] = {"title": [], "message": []}
return errors