|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import dataclasses
import itertools
import json
import types
from dataclasses import Field, MISSING
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple

import numpy as np
|
|
|
# Type variable for the `cls` argument of `load_dataclass`.
_X = TypeVar("_X")
|
|
|
|
|
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
    """
    Loads to a @dataclass or collection hierarchy including dataclasses
    from a json recursively.
    Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation])
    for a json list, or load_dataclass(f, FrameAnnotationAnnotation) for a
    single json object.

    Json keys that do not match a dataclass field are ignored; fields whose
    key is missing take the field's default, or None if it declares none.

    Args:
        f: A file object opened for reading (a path is not accepted).
        cls: The type to load. When the json top level is a list, this is
            expected to be a list annotation such as `typing.List[T]`, and
            each element is converted to `T` (an unparameterised annotation
            keeps the elements as parsed).
        binary: Set to True if `f` was opened in binary mode; its content is
            then decoded as UTF-8 before parsing.

    Returns:
        A list of converted elements for a json list, otherwise the single
        converted value.
    """
    if binary:
        asdict = json.loads(f.read().decode("utf8"))
    else:
        asdict = json.load(f)

    if isinstance(asdict, list):
        # A top-level list: convert all elements in one vectorised pass.
        type_args = get_args(cls)
        elem_type = type_args[0] if type_args else Any
        return list(_dataclass_list_from_dict_list(asdict, elem_type))

    # A single top-level value (e.g. one dataclass object): run the
    # vectorised converter on a one-element batch and unwrap the result.
    return _dataclass_list_from_dict_list([asdict], cls)[0]
|
|
|
|
|
def _resolve_optional(type_: Any) -> Tuple[bool, Any]: |
|
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T.""" |
|
if get_origin(type_) is Union: |
|
args = get_args(type_) |
|
if len(args) == 2 and args[1] == type(None): |
|
return True, args[0] |
|
if type_ is Any: |
|
return True, Any |
|
|
|
return False, type_ |
|
|
|
|
|
def _unwrap_type(tp): |
|
|
|
if get_origin(tp) is Union: |
|
args = get_args(tp) |
|
if len(args) == 2 and any(a is type(None) for a in args): |
|
|
|
return args[0] if args[1] is type(None) else args[1] |
|
return tp |
|
|
|
|
|
def _get_dataclass_field_default(field: Field) -> Any: |
|
if field.default_factory is not MISSING: |
|
|
|
|
|
return field.default_factory() |
|
elif field.default is not MISSING: |
|
return field.default |
|
else: |
|
return None |
|
|
|
|
|
def _dataclass_list_from_dict_list(dlist, typeannot):
    """
    Vectorised version of `_dataclass_from_dict`.
    The output should be equivalent to
    `[_dataclass_from_dict(d, typeannot) for d in dlist]`: every element of
    `dlist` (a value parsed from json) is converted to `typeannot`, but values
    sharing a type are converted together in one batch.

    Handled annotations: `Optional[T]`; NamedTuples (stored as json lists);
    `List[T]` / `Tuple[T, ...]` (lengths may differ between elements);
    fixed-arity `Tuple[A, B, ...]`; `Dict[K, V]`; dataclasses. `Any`, plain
    classes such as int/str, and typing constructs without a single class
    (multi-type Union, Literal, TypeVar, string forward references) keep the
    values as parsed.

    Args:
        dlist: sequence of parsed json values to convert; entries may be None.
        typeannot: type annotation shared by all of those values.

    Returns:
        A list (or `dlist` itself when nothing needs converting) with one
        converted entry per element of `dlist`; None entries stay None.
    """
    if typeannot is Any:
        return dlist
    if all(obj is None for obj in dlist):
        # Nothing to convert (this also covers an empty `dlist`).
        return dlist
    if any(obj is None for obj in dlist):
        # Convert only the non-None entries, then put them back in place.
        idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
        idx, notnone = zip(*idx_notnone)
        converted = _dataclass_list_from_dict_list(notnone, typeannot)
        res = [None] * len(dlist)
        for i, obj in zip(idx, converted):
            res[i] = obj
        return res

    is_optional, contained_type = _resolve_optional(typeannot)
    if is_optional:
        # No None is left at this point, so Optional[T] converts like T.
        return _dataclass_list_from_dict_list(dlist, contained_type)

    cls = get_origin(typeannot) or typeannot
    if not isinstance(cls, type):
        # Multi-type Union, Literal, TypeVar, string forward reference, ...:
        # there is no single class to build, so keep the parsed values
        # (issubclass() below would raise TypeError on these).
        return dlist

    if issubclass(cls, tuple) and hasattr(cls, "_fields"):
        # NamedTuple stored as a json list: convert position by position.
        # Fields without an annotation (collections.namedtuple) are kept as is.
        annotations = getattr(cls, "__annotations__", {})
        field_types = [annotations.get(name, Any) for name in cls._fields]
        columns = [
            _dataclass_list_from_dict_list(column, tp)
            for column, tp in zip(zip(*dlist), field_types)
        ]
        return [cls(*row) for row in zip(*columns)]

    if issubclass(cls, (list, tuple)):
        type_args = get_args(typeannot)
        if len(type_args) == 2 and type_args[1] is Ellipsis:
            type_args = type_args[:1]  # Tuple[T, ...] is homogeneous
        if len(type_args) <= 1:
            # Homogeneous (List[T], Tuple[T, ...]) or unparameterised: convert
            # the items of all sequences in one batch, then cut the flat result
            # back into the original per-element lengths, which may differ.
            item_type = type_args[0] if type_args else Any
            lengths = [len(seq) for seq in dlist]
            flat = [item for seq in dlist for item in seq]
            items = iter(_dataclass_list_from_dict_list(flat, item_type))
            return [cls(itertools.islice(items, n)) for n in lengths]
        # Fixed-arity Tuple[A, B, ...]: convert position by position.
        columns = [
            _dataclass_list_from_dict_list(column, tp)
            for column, tp in zip(zip(*dlist), type_args)
        ]
        return [cls(row) for row in zip(*columns)]

    if issubclass(cls, dict):
        # Convert all keys and all values in two batches, then regroup them
        # into one mapping per element of `dlist` (an unparameterised dict
        # keeps keys and values as parsed).
        key_type, val_type = get_args(typeannot) or (Any, Any)
        lengths = [len(mapping) for mapping in dlist]
        keys = iter(
            _dataclass_list_from_dict_list(
                [k for mapping in dlist for k in mapping.keys()], key_type
            )
        )
        vals = iter(
            _dataclass_list_from_dict_list(
                [v for mapping in dlist for v in mapping.values()], val_type
            )
        )
        return [
            cls(zip(itertools.islice(keys, n), itertools.islice(vals, n)))
            for n in lengths
        ]

    if not dataclasses.is_dataclass(cls):
        # Plain classes (int, str, ...) keep the parsed json values.
        return dlist

    # Dataclass stored as a json object: convert each field across all
    # objects in one batch. Keys that are not fields are ignored.
    absent = object()  # marks a key missing from an object
    kwargs_per_obj = [{} for _ in dlist]
    for field in dataclasses.fields(cls):
        if not field.init:
            # Not an __init__ parameter, so it cannot be passed to cls(...).
            continue
        column = [obj.get(field.name, absent) for obj in dlist]
        present = [i for i, value in enumerate(column) if value is not absent]
        converted = _dataclass_list_from_dict_list(
            [column[i] for i in present], _unwrap_type(field.type)
        )
        for i, value in zip(present, converted):
            kwargs_per_obj[i][field.name] = value
        if len(present) < len(column):
            # Missing keys get the field default (None if it declares none),
            # created afresh per object so default_factory results such as
            # lists are never shared between instances.
            for i, value in enumerate(column):
                if value is absent:
                    kwargs_per_obj[i][field.name] = _get_dataclass_field_default(field)
    # Keyword construction also works for keyword-only fields.
    return [cls(**kwargs) for kwargs in kwargs_per_obj]
|
|