from __future__ import annotations

from typing import Optional, Protocol, Tuple, Type, TypeVar, get_args

from google.protobuf import json_format, message

MessageType = TypeVar("MessageType", bound=message.Message)
DomainProtocolType = TypeVar("DomainProtocolType", bound='DomainProtocol')


class ProtoDeserializationError(Exception):
    """Raised when a proto message or JSON string cannot be converted into a domain object."""


class DomainProtocol(Protocol[MessageType]):
    """Protocol for domain objects that round-trip through a protobuf message.

    Concrete classes subclass ``DomainProtocol[SomeMessage]`` directly; the
    type argument is recovered at runtime by :meth:`message_cls` and used by
    :meth:`from_json` to instantiate the right message type.

    Subclasses provide ``id``, ``_from_proto`` and ``to_proto``; the public
    ``from_proto`` / ``from_json`` / ``to_json`` helpers are built on top.
    """

    @property
    def id(self) -> str:
        """Identifier of this domain object (provided by subclasses)."""
        ...

    @classmethod
    def _from_proto(cls: Type[DomainProtocolType], proto: MessageType) -> DomainProtocolType:
        """Build an instance from ``proto`` (provided by subclasses).

        Called by :meth:`from_proto`, which performs the emptiness check and
        wraps any exception raised here in ``ProtoDeserializationError``.
        """
        ...

    def to_proto(self) -> MessageType:
        """Convert this object into its protobuf message (provided by subclasses)."""
        ...

    @classmethod
    def message_cls(cls: Type[DomainProtocolType]) -> Type[MessageType]:
        """Return the message class this domain class is parameterised with.

        Reads the generic base recorded in ``__orig_bases__`` (PEP 560), so
        ``class Foo(DomainProtocol[FooProto])`` yields ``FooProto``.

        Raises:
            ValueError: if the class has no ``__orig_bases__``, has more than
                one base, or its base carries no type argument.
        """
        # Entries are generic aliases such as DomainProtocol[FooProto], not classes.
        orig_bases: Optional[Tuple[object, ...]] = getattr(cls, "__orig_bases__", None)
        if not orig_bases:
            raise ValueError(f"Class {cls} does not have __orig_bases__")
        if len(orig_bases) != 1:
            raise ValueError(f"Class {cls} has unexpected number of bases: {orig_bases}")
        type_args = get_args(orig_bases[0])
        if not type_args:
            # An unparameterised base used to surface as a bare IndexError.
            raise ValueError(f"Class {cls} base {orig_bases[0]} has no type arguments")
        return type_args[0]

    @classmethod
    def from_proto(cls: Type[DomainProtocolType], proto: MessageType, allow_empty: bool = False) -> DomainProtocolType:
        """Convert ``proto`` into an instance of ``cls``.

        Args:
            proto: The message to convert.
            allow_empty: When False (the default), messages for which
                :meth:`is_empty` returns True are rejected.

        Raises:
            ProtoDeserializationError: Wraps any exception raised by the
                emptiness check or by ``_from_proto``; the original exception
                is chained as ``__cause__``.
        """
        try:
            if not allow_empty:
                cls.validate_proto_not_empty(proto)
            return cls._from_proto(proto)
        except Exception as e:
            # Deliberately broad: failures inside subclass conversion code are
            # reported uniformly to callers.
            error_str = f"Failed to convert {cls} - {e}"
            raise ProtoDeserializationError(error_str) from e

    @classmethod
    def from_json(cls: Type[DomainProtocolType], json_str: str) -> DomainProtocolType:
        """Parse ``json_str`` into the message from :meth:`message_cls`, then convert it.

        The parsed message goes through :meth:`from_proto` with its default
        ``allow_empty=False``, so an empty message is rejected.

        Raises:
            ProtoDeserializationError: If ``json_format.Parse`` fails, or if
                the conversion in :meth:`from_proto` fails.
            ValueError: If :meth:`message_cls` cannot determine the message class.
        """
        proto_cls = cls.message_cls()
        proto = proto_cls()
        try:
            json_format.Parse(json_str, proto)
        except json_format.ParseError as e:
            error_str = f"{cls} failed to parse json string: {json_str} - {e}"
            raise ProtoDeserializationError(error_str) from e
        return cls.from_proto(proto)

    def to_json(self) -> str:
        """Serialise ``to_proto()`` with ``json_format.MessageToJson`` on a single line.

        Newlines in the JSON output are replaced by single spaces; indentation
        spaces are kept. JSON escapes newlines inside string values, so only
        layout newlines are affected.
        """
        return json_format.MessageToJson(self.to_proto()).replace("\n", " ")

    @classmethod
    def validate_proto_not_empty(cls, proto: message.Message) -> None:
        """Raise ``ValueError("Proto is empty")`` if :meth:`is_empty` reports ``proto`` as empty."""
        if cls.is_empty(proto):
            raise ValueError("Proto is empty")

    @classmethod
    def is_empty(cls, proto: message.Message) -> bool:
        """Return True if ``proto`` carries no data.

        A message is empty when every one of its fields is empty:

        * singular scalar/enum fields equal their default value;
        * singular message fields are (recursively) empty;
        * repeated scalar fields have no elements;
        * repeated message fields contain only (recursively) empty messages;
        * map fields have no entries.

        Objects without a ``DESCRIPTOR`` attribute (e.g. ``None`` or a plain
        scalar) are treated as empty.
        """
        descriptor = getattr(proto, 'DESCRIPTOR', None)
        if not descriptor:
            return True
        for field in descriptor.fields:
            value = getattr(proto, field.name)
            if field.label == field.LABEL_REPEATED:
                is_message = field.type == field.TYPE_MESSAGE
                # Map fields are repeated message fields whose entry type is
                # flagged map_entry. Iterating a map container yields keys,
                # not entries, so they must not go through the per-element
                # recursion below (a key has no DESCRIPTOR and would always
                # look empty).
                is_map = is_message and field.message_type.GetOptions().map_entry
                if is_message and not is_map:
                    if not all(cls.is_empty(item) for item in value):
                        return False
                elif len(value) > 0:
                    # Any element of a repeated scalar field, or any map entry,
                    # counts as data.
                    return False
            elif field.type == field.TYPE_MESSAGE:
                if not cls.is_empty(value):
                    return False
            elif value != field.default_value:
                return False
        return True