|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
from typing import Any, TypeVar |
|
|
|
import attrs |
|
|
|
from omegaconf import DictConfig as LazyDict |
|
|
|
from .misc import Color |
|
|
|
T = TypeVar("T") |
|
|
|
|
|
def _is_attrs_instance(obj: object) -> bool: |
|
""" |
|
Helper function to check if an object is an instance of an attrs-defined class. |
|
|
|
Args: |
|
obj: The object to check. |
|
|
|
Returns: |
|
bool: True if the object is an instance of an attrs-defined class, False otherwise. |
|
""" |
|
return hasattr(obj, "__attrs_attrs__") |
|
|
|
|
|
def make_freezable(cls: T) -> T:
    """
    A decorator that adds the capability to freeze instances of an attrs-defined class.

    NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need
    to hack on a "_is_frozen" attribute.

    This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
    Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
    any attrs-defined objects that are attributes of the class.

    Only assignment through ``__setattr__`` is blocked. The ``_is_frozen`` flag itself stays
    assignable, and only attrs instances stored directly as attribute values are frozen
    recursively (values nested inside containers are not visited).

    Usage:
        @make_freezable
        @attrs.define(slots=False)
        class MyClass:
            attribute1: int
            attribute2: str

        obj = MyClass(1, 'a')
        obj.freeze()  # Freeze the instance
        obj.attribute1 = 2  # Raises AttributeError

    Args:
        cls: The class to be decorated.

    Returns:
        The decorated class with added freezing capability. The same class object is
        returned; ``__setattr__`` is replaced and a ``freeze`` method is added in place.

    Raises:
        TypeError: If instances of ``cls`` do not get a ``__dict__`` (e.g. a slotted class).
    """

    # Every class object has a ``__dict__`` of its own, so ``hasattr(cls, "__dict__")`` can
    # never detect a slotted class. What matters is whether *instances* get a ``__dict__``,
    # which is the case when some class in the MRO exposes a ``__dict__`` descriptor.
    instances_have_dict = any("__dict__" in vars(klass) for klass in getattr(cls, "__mro__", ()))
    if not instances_have_dict:
        raise TypeError(
            "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
            "class was defined with `@attrs.define(slots=False)`"
        )

    original_setattr = cls.__setattr__

    def setattr_override(self, key, value) -> None:
        """
        Override __setattr__ to allow modifications during initialization
        and prevent modifications once the instance is frozen.
        """
        # ``_is_frozen`` is absent until freeze() runs, so it defaults to "not frozen".
        # Writes to ``_is_frozen`` itself are always let through.
        if key != "_is_frozen" and getattr(self, "_is_frozen", False):
            raise AttributeError("Cannot modify frozen instance")
        original_setattr(self, key, value)

    cls.__setattr__ = setattr_override

    def freeze(self: object) -> None:
        """
        Freeze the instance and all its attrs-defined attributes.
        """
        # recurse=False keeps nested attrs instances as objects so they can be frozen too.
        for value in attrs.asdict(self, recurse=False).values():
            if _is_attrs_instance(value) and hasattr(value, "freeze"):
                value.freeze()
        # Set the flag last so the children are frozen before this instance locks.
        self._is_frozen = True

    cls.freeze = freeze

    return cls
|
|
|
|
|
def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
    """
    Recursively pretty prints attrs objects with color.

    Each attrs field becomes one ``* name: value`` line, prefixed by ``indent`` spaces.
    A field whose value is itself an attrs instance is rendered as a ``* name:`` header
    followed by that value's own fields at ``indent + 1``.

    Args:
        obj: Instance of an attrs-defined class (asserted).
        indent: Number of leading spaces for this nesting level.
        use_color: When True, the bullet, field name and leaf value are wrapped with
            ``Color.cyan``, ``Color.green`` and ``Color.yellow`` respectively.

    Returns:
        The rendered lines joined with newlines.
    """
    assert attrs.has(obj.__class__)

    padding = " " * indent
    rendered: list[str] = []
    for attr_def in attrs.fields(obj.__class__):
        field_value = getattr(obj, attr_def.name)
        is_nested = attrs.has(field_value.__class__)

        # The bullet and field name are shared by the nested and the leaf rendering.
        if use_color:
            head = padding + Color.cyan("* ") + Color.green(attr_def.name)
        else:
            head = padding + "* " + attr_def.name

        if is_nested:
            rendered.append(head + ":")
            rendered.append(_pretty_print_attrs_instance(field_value, indent + 1, use_color))
        elif use_color:
            # NOTE(review): the raw value is handed to Color.yellow, whereas the plain branch
            # uses str(value) — confirm Color.yellow accepts non-str input.
            rendered.append(head + ": " + Color.yellow(field_value))
        else:
            rendered.append(head + ": " + str(field_value))
    return "\n".join(rendered)
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class JobConfig:
    """Project, group and name of a job, plus a path derived from them.

    All three fields default to the empty string.
    """

    # First component of ``path``.
    project: str = ""

    # Second component of ``path``.
    group: str = ""

    # Last component of ``path``.
    name: str = ""

    @property
    def path(self) -> str:
        """Return ``project``, ``group`` and ``name`` joined with ``/``, in that order."""
        return f"{self.project}/{self.group}/{self.name}"
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class Config:
    """Config for a job.

    See /README.md/Configuration System for more info.
    """

    # Required field (no default), typed as omegaconf's DictConfig through the LazyDict alias.
    model: LazyDict

    # Job identification; the factory builds a fresh JobConfig for every Config instance.
    job: JobConfig = attrs.field(factory=JobConfig)

    def to_dict(self) -> dict[str, Any]:
        """Return this config converted with ``attrs.asdict`` using its default settings."""
        return attrs.asdict(self)

    def validate(self) -> None:
        """Validate that the config has all required fields.

        The checks are explicit ``raise`` statements rather than ``assert`` so they still
        run when Python is started with ``-O``. ``AssertionError`` is kept as the exception
        type so callers that catch it keep working.

        Raises:
            AssertionError: If ``job.project``, ``job.group`` or ``job.name`` is the empty
                string (checked in that order; the first failure is reported).
        """
        if self.job.project == "":
            raise AssertionError("Project name is required.")
        if self.job.group == "":
            raise AssertionError("Group name is required.")
        if self.job.name == "":
            raise AssertionError("Job name is required.")
|
|