# Spaces: Running on Zero
import dataclasses | |
from dataclasses import dataclass, _MISSING_TYPE | |
from munch import Munch | |
# Sentinel string defaults understood by Base: a DTO field whose default is
# EXPECTED must always be passed to the constructor, and one whose default is
# EXPECTED_TRAIN must be passed unless the DTO is built with training=False.
EXPECTED: str = "___REQUIRED___"
EXPECTED_TRAIN: str = "___REQUIRED_TRAIN___"
# pylint: disable=invalid-field-call | |
def nested_dto(x, raw=False): | |
return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x)) | |
class Base:
    """Validation mixin for keyword-only, dataclass-based DTOs.

    Subclasses are expected to be dataclasses, because the classmethods below
    call ``dataclasses.fields(cls)``. The default declared for each field
    decides how it may be supplied at construction time:

    * default ``None``, or no plain default at all (this includes
      ``default_factory`` fields): the field may be passed as a keyword.
    * default ``EXPECTED``: the field must be passed.
    * default ``EXPECTED_TRAIN``: the field must be passed unless
      ``training=False``.
    * any other default: the field is fixed, and passing it is rejected.

    A field with neither a ``default`` nor a ``default_factory`` is also
    mandatory. Validation runs in ``__new__``, before the dataclass
    ``__init__`` assigns the values.
    """

    # NOTE(review): this copy of Base carries no ``@dataclass`` decorator. A
    # subclass only inherits ``training`` as a dataclass field if Base is
    # itself a dataclass; confirm whether a decorator line was lost here.
    training: bool = None

    def __new__(cls, **kwargs):
        """Validate constructor keywords, then allocate the instance.

        Only keyword arguments are accepted.

        Raises:
            AssertionError: if a keyword is not settable on ``cls``, if its
                value is an unresolved ``EXPECTED``/``EXPECTED_TRAIN``
                marker, or if a mandatory field is missing. The error is
                raised explicitly rather than through ``assert``, so the
                checks still run under ``python -O``.
        """
        # NOTE(review): copy.copy()/pickle presumably re-create instances via
        # cls.__new__(cls) with no kwargs, which would trip the
        # mandatory-field check below. Confirm DTOs are never copied/pickled.
        # Anything other than an explicit False counts as training mode.
        training = kwargs.get('training', True)
        setteable_fields = cls.setteable_fields(**kwargs)
        mandatory_fields = cls.mandatory_fields(**kwargs)

        invalid_kwargs = [
            {k: v}
            for k, v in kwargs.items()
            if k not in setteable_fields
            or cls._is_marker(v, EXPECTED)
            or (cls._is_marker(v, EXPECTED_TRAIN) and training is not False)
        ]
        if invalid_kwargs:
            raise AssertionError(
                f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\n"
                "Declare this field and set it to None or EXPECTED in order to make it setteable."
            )

        missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
        if missing_kwargs:
            raise AssertionError(f"Required fields missing initializing this DTO: {missing_kwargs}.")

        return object.__new__(cls)

    @staticmethod
    def _is_marker(value, marker):
        """Return True if ``value`` is the sentinel string ``marker``.

        The isinstance guard keeps ``==`` away from objects (e.g. array-like
        values) whose comparison with a str may not produce a plain bool.
        """
        return isinstance(value, str) and value == marker

    @classmethod
    def setteable_fields(cls, **kwargs):
        """Return the names of fields that may be passed to the constructor.

        ``kwargs`` is accepted for signature symmetry with
        ``mandatory_fields`` and is not used.
        """
        return [
            f.name
            for f in dataclasses.fields(cls)
            if f.default is None
            or f.default is dataclasses.MISSING
            or cls._is_marker(f.default, EXPECTED)
            or cls._is_marker(f.default, EXPECTED_TRAIN)
        ]

    @classmethod
    def mandatory_fields(cls, **kwargs):
        """Return the names of fields that must be passed to the constructor.

        ``kwargs['training']`` (default True) controls whether fields that
        default to ``EXPECTED_TRAIN`` are mandatory. Only an explicit False
        exempts them.
        """
        training = kwargs.get('training', True)
        return [
            f.name
            for f in dataclasses.fields(cls)
            if (f.default is dataclasses.MISSING and f.default_factory is dataclasses.MISSING)
            or cls._is_marker(f.default, EXPECTED)
            or (cls._is_marker(f.default, EXPECTED_TRAIN) and training is not False)
        ]

    @classmethod
    def from_dict(cls, kwargs):
        """Build an instance from a mapping of field name to value.

        Dict, list and tuple values are converted with ``Munch.fromDict``
        first. The caller's mapping is left unmodified.
        """
        converted = {
            k: Munch.fromDict(v) if isinstance(v, (dict, list, tuple)) else v
            for k, v in kwargs.items()
        }
        return cls(**converted)

    def to_dict(self):
        """Return a shallow ``{field_name: value}`` dict of this DTO.

        Munch values are converted back with ``Munch.toDict()``. All other
        values are returned as-is, without copying.
        """
        # dataclasses.asdict() deep-copies every value, which does not
        # support some more complex classes, so fields are read one by one.
        selfdict = {}
        for f in dataclasses.fields(self):
            value = getattr(self, f.name)
            if isinstance(value, Munch):
                value = value.toDict()
            selfdict[f.name] = value
        return selfdict