Spaces:
Sleeping
Sleeping
from __future__ import annotations

from typing import Optional, Protocol, Tuple, Type, TypeVar, get_args, get_origin

from google.protobuf import json_format, message
# Type variable ranging over concrete protobuf message classes.
MessageType = TypeVar("MessageType", bound=message.Message)
# Type variable ranging over concrete DomainProtocol implementations. The bound
# is a string forward reference because DomainProtocol is defined further down.
DomainProtocolType = TypeVar("DomainProtocolType", bound="DomainProtocol")
class ProtoDeserializationError(Exception):
    """Raised when a protobuf message or JSON string cannot be turned into a domain object."""
class DomainProtocol(Protocol[MessageType]):
    """Structural interface for domain objects that map to a protobuf message.

    A concrete class declares its message type by subclassing a parameterised
    form of this protocol, e.g. ``class Foo(DomainProtocol[FooProto])``, and
    implements :meth:`id`, :meth:`_from_proto` and :meth:`to_proto`. The other
    methods build proto/JSON (de)serialization on top of those three hooks.
    """

    def id(self) -> str:
        """Return this object's id; implemented by concrete classes."""
        ...

    @classmethod
    def _from_proto(cls: Type[DomainProtocolType], proto: MessageType) -> DomainProtocolType:
        """Build an instance from *proto*; implemented by concrete classes.

        Callers should go through :meth:`from_proto`, which adds the emptiness
        check and wraps failures in :class:`ProtoDeserializationError`.
        """
        ...

    def to_proto(self) -> MessageType:
        """Return this object as its protobuf message; implemented by concrete classes."""
        ...

    @classmethod
    def message_cls(cls: Type[DomainProtocolType]) -> Type[MessageType]:
        """Return the protobuf message class this domain class is parameterised with.

        The type argument is read from ``__orig_bases__``. That attribute is
        looked up through the MRO, so a plain subclass of a concrete domain
        class also works. Bases that are not a parameterised ``DomainProtocol``
        (or a subclass of it), such as mixins, are ignored.

        Raises:
            ValueError: if the class has no ``__orig_bases__``, or if not
                exactly one parameterised ``DomainProtocol`` base is found.
        """
        orig_bases: Optional[Tuple[object, ...]] = getattr(cls, "__orig_bases__", None)
        if not orig_bases:
            raise ValueError(f"Class {cls} does not have __orig_bases__")
        # Membership in the origin's MRO is checked directly because
        # issubclass() against a non-runtime-checkable Protocol raises TypeError.
        candidates = [
            base
            for base in orig_bases
            if get_args(base)
            and DomainProtocol in getattr(get_origin(base), "__mro__", ())
        ]
        if len(candidates) != 1:
            raise ValueError(f"Class {cls} has unexpected number of bases: {orig_bases}")
        return get_args(candidates[0])[0]

    @classmethod
    def from_proto(cls: Type[DomainProtocolType],
                   proto: MessageType,
                   allow_empty: bool = False) -> DomainProtocolType:
        """Build an instance from *proto*.

        Args:
            proto: the message to convert.
            allow_empty: when False (the default), a message for which
                :meth:`is_empty` is True is rejected.

        Raises:
            ProtoDeserializationError: wrapping any exception raised by the
                emptiness check or by :meth:`_from_proto`. The original
                exception is kept as ``__cause__``.
        """
        try:
            if not allow_empty:
                cls.validate_proto_not_empty(proto)
            return cls._from_proto(proto)
        except Exception as e:
            error_str = f"Failed to convert {cls} - {e}"
            raise ProtoDeserializationError(error_str) from e

    @classmethod
    def from_json(cls: Type[DomainProtocolType],
                  json_str: str,
                  allow_empty: bool = False) -> DomainProtocolType:
        """Parse *json_str* into this class's message type, then convert via :meth:`from_proto`.

        Args:
            json_str: protobuf-JSON text for :meth:`message_cls`.
            allow_empty: forwarded to :meth:`from_proto`. Pass True to accept
                e.g. ``"{}"``.

        Raises:
            ProtoDeserializationError: if the JSON cannot be parsed, or if
                :meth:`from_proto` fails.
            ValueError: from :meth:`message_cls` if the message type cannot be
                determined.
        """
        proto = cls.message_cls()()
        # Only the parse step can raise ParseError; keep the try body minimal.
        try:
            json_format.Parse(json_str, proto)
        except json_format.ParseError as e:
            error_str = f"{cls} failed to parse json string: {json_str} - {e}"
            raise ProtoDeserializationError(error_str) from e
        return cls.from_proto(proto, allow_empty=allow_empty)

    def to_json(self) -> str:
        """Return ``to_proto()`` as JSON on a single line.

        The pretty-printer's newlines are replaced by spaces.
        """
        return json_format.MessageToJson(self.to_proto()).replace("\n", " ")

    @classmethod
    def validate_proto_not_empty(cls, proto: message.Message) -> None:
        """Raise ``ValueError("Proto is empty")`` if :meth:`is_empty` is True for *proto*."""
        if cls.is_empty(proto):
            raise ValueError("Proto is empty")

    @classmethod
    def is_empty(cls, proto: message.Message) -> bool:
        """Return True if *proto* carries no non-default data.

        A message counts as empty when every one of these holds:

        - singular scalars equal their descriptor default;
        - repeated scalar and map fields have no elements;
        - repeated message fields contain only empty messages;
        - singular sub-messages are unset or themselves empty.

        Objects without a ``DESCRIPTOR`` are treated as empty.
        """
        descriptor = getattr(proto, "DESCRIPTOR", None)
        if descriptor is None:
            return True
        for field in descriptor.fields:
            value = getattr(proto, field.name)
            # cpp_type is CPPTYPE_MESSAGE for both TYPE_MESSAGE and TYPE_GROUP,
            # so groups are recursed into instead of being compared against
            # their None default (which would never match).
            is_message = field.cpp_type == field.CPPTYPE_MESSAGE
            if field.label == field.LABEL_REPEATED:
                if is_message and not field.message_type.GetOptions().map_entry:
                    # A repeated sub-message field carries data only if some
                    # element does.
                    if not all(cls.is_empty(item) for item in value):
                        return False
                elif len(value):
                    # Repeated scalars and maps carry data as soon as they hold
                    # anything. Iterating a map yields its keys, not entry
                    # messages, so maps must not be fed to is_empty().
                    return False
            elif is_message:
                # An unset sub-message is always empty, so recurse only into set
                # ones. Recursing unconditionally never terminates for
                # self-referential types: reading an unset sub-message yields a
                # default instance whose own sub-messages read the same way.
                if proto.HasField(field.name) and not cls.is_empty(value):
                    return False
            elif value != field.default_value:
                return False
        return True