"""Domain dataclasses mirroring the ``entity_pb2`` protobuf messages.

Each class converts to/from its protobuf message; the entity, relationship
and entity-relationship classes also expose the Cypher statement(s) and
parameter dict(s) used to write them to Neo4j.
"""
from __future__ import annotations

import dataclasses
import hashlib
import logging
from typing import TypeAlias, Union

import proto.entity_pb2 as entity_pb2
from domain.domain_protocol import DomainProtocol
from utils.dates import parse_date

logger = logging.getLogger(__name__)

# Parameter mapping passed alongside a Cypher statement.
Neo4jDict: TypeAlias = dict[str, Union[str, int, bool, list[str], list[int], list[bool]]]

# Enum names recognised by RelationshipD.from_string; any other string maps to
# PREDICTED_MOVEMENT_UNSPECIFIED.
_PREDICTED_MOVEMENT_BY_NAME: dict[str, entity_pb2.PredictedMovement] = {
    "PREDICTED_MOVEMENT_NEUTRAL": entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL,
    "PREDICTED_MOVEMENT_INCREASE": entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_INCREASE,
    "PREDICTED_MOVEMENT_DECREASE": entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_DECREASE,
}


def _content_hash(message) -> str:
    """Return the SHA-256 hex digest of a protobuf message's serialized bytes.

    Deterministic serialization is requested so the digest does not depend
    on implementation-defined ordering (e.g. of map keys).
    """
    return hashlib.sha256(message.SerializeToString(deterministic=True)).hexdigest()


@dataclasses.dataclass(frozen=True)
class EntityD(DomainProtocol[entity_pb2.Entity]):
    """A named entity; mirrors ``entity_pb2.Entity``."""

    entity_id: str
    entity_name: str

    @property
    def id(self) -> str:
        """Return ``entity_id``."""
        return self.entity_id

    @classmethod
    def _from_proto(cls, proto: entity_pb2.Entity) -> EntityD:
        """Build an instance from an ``entity_pb2.Entity`` message."""
        return cls(entity_id=proto.entity_id, entity_name=proto.entity_name)

    def to_proto(self) -> entity_pb2.Entity:
        """Convert to an ``entity_pb2.Entity`` message."""
        return entity_pb2.Entity(entity_id=self.entity_id, entity_name=self.entity_name)

    @property
    def neo4j_create_cmd(self) -> str:
        """Cypher that MERGEs an ``:Entity`` node keyed on its name.

        ``pdf_file`` is only written when the node is first created.
        NOTE(review): ``$pdf_file`` is not part of ``neo4j_create_args``;
        presumably the caller adds it before running the query — confirm.
        """
        # TODO store entity_id?
        return "MERGE (e:Entity {name: $name}) ON CREATE SET e.pdf_file = $pdf_file"

    @property
    def neo4j_create_args(self) -> Neo4jDict:
        """Parameters for :attr:`neo4j_create_cmd` (without ``pdf_file``)."""
        return {
            "name": self.entity_name,
        }


@dataclasses.dataclass(frozen=True)
class RelationshipD(DomainProtocol[entity_pb2.Relationship]):
    """A dated relationship with a predicted movement; mirrors ``entity_pb2.Relationship``."""

    relationship_id: str
    start_date: str
    end_date: str
    source_text: str
    predicted_movement: entity_pb2.PredictedMovement

    @property
    def id(self) -> str:
        """Return ``relationship_id``."""
        return self.relationship_id

    def __post_init__(self) -> None:
        """Warn when ``end_date`` is earlier than ``start_date``.

        The check runs only when both dates are non-empty; both are parsed
        with ``utils.dates.parse_date`` (start first, then end).
        """
        if not (self.start_date and self.end_date):
            return
        start = parse_date(self.start_date)
        end = parse_date(self.end_date)
        if end < start:
            # Logged rather than raised, so such records are still constructed.
            logger.warning("end_date %s is before start_date %s", self.end_date, self.start_date)

    @classmethod
    def _from_proto(cls, proto: entity_pb2.Relationship) -> RelationshipD:
        """Build an instance from an ``entity_pb2.Relationship`` message."""
        return cls(
            relationship_id=proto.relationship_id,
            start_date=proto.start_date,
            end_date=proto.end_date,
            source_text=proto.source_text,
            predicted_movement=proto.predicted_movement,
        )

    def to_proto(self) -> entity_pb2.Relationship:
        """Convert to an ``entity_pb2.Relationship`` message."""
        return entity_pb2.Relationship(
            relationship_id=self.relationship_id,
            start_date=self.start_date,
            end_date=self.end_date,
            source_text=self.source_text,
            predicted_movement=self.predicted_movement,
        )

    @classmethod
    def from_string(cls, relationship: str) -> entity_pb2.PredictedMovement:
        """Map a ``PredictedMovement`` enum name to its value.

        Only ``PREDICTED_MOVEMENT_NEUTRAL``, ``PREDICTED_MOVEMENT_INCREASE`` and
        ``PREDICTED_MOVEMENT_DECREASE`` are recognised; any other string
        returns ``PREDICTED_MOVEMENT_UNSPECIFIED``.
        """
        return _PREDICTED_MOVEMENT_BY_NAME.get(
            relationship, entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_UNSPECIFIED
        )

    @property
    def neo4j_create_cmd(self) -> str:
        """Cypher that MERGEs a ``:Relationship`` edge between two named entities.

        The edge is keyed on start/end date and predicted movement;
        ``source_text`` and ``pdf_file`` are only written on creation.
        ``$from_name``/``$to_name`` are supplied by
        ``EntityRelationshipD.neo4j_create_args``. NOTE(review): ``$pdf_file``
        is not supplied by this module — presumably added by the caller.
        """
        return (
            "MATCH (from:Entity {name: $from_name}) "
            "MATCH (to:Entity {name: $to_name}) "
            "MERGE (from) -[r:Relationship {start_date: $start_date, end_date: $end_date, "
            "predicted_movement: $predicted_movement}]-> (to) "
            "ON CREATE SET r.source_text = $source_text, r.pdf_file = $pdf_file"
        )

    @property
    def neo4j_create_args(self) -> Neo4jDict:
        """Relationship-level parameters for :attr:`neo4j_create_cmd`.

        ``predicted_movement`` is sent as its enum name, not its integer value.
        """
        return {
            "start_date": self.start_date,
            "end_date": self.end_date,
            "predicted_movement": entity_pb2.PredictedMovement.Name(self.predicted_movement),
            "source_text": self.source_text,
        }


@dataclasses.dataclass(frozen=True)
class EntityRelationshipD(DomainProtocol[entity_pb2.EntityRelationship]):
    """A ``from_entity -[relationship]-> to_entity`` triple; mirrors ``entity_pb2.EntityRelationship``."""

    from_entity: EntityD
    relationship: RelationshipD
    to_entity: EntityD

    @property
    def id(self) -> str:
        """Content hash: SHA-256 hex digest of the (deterministically) serialized proto."""
        return _content_hash(self.to_proto())

    @classmethod
    def _from_proto(cls, proto: entity_pb2.EntityRelationship) -> EntityRelationshipD:
        """Build an instance (and its nested entities/relationship) from a proto message."""
        return cls(
            from_entity=EntityD._from_proto(proto.from_entity),
            relationship=RelationshipD._from_proto(proto.relationship),
            to_entity=EntityD._from_proto(proto.to_entity),
        )

    def to_proto(self) -> entity_pb2.EntityRelationship:
        """Convert to an ``entity_pb2.EntityRelationship`` message."""
        return entity_pb2.EntityRelationship(
            from_entity=self.from_entity.to_proto(),
            relationship=self.relationship.to_proto(),
            to_entity=self.to_entity.to_proto(),
        )

    @property
    def neo4j_create_cmds(self) -> list[str]:
        """Cypher statements to run in order: both entity nodes, then the edge.

        Index-aligned with :attr:`neo4j_create_args`.
        """
        return [
            self.from_entity.neo4j_create_cmd,
            self.to_entity.neo4j_create_cmd,
            self.relationship.neo4j_create_cmd,
        ]

    @property
    def neo4j_create_args(self) -> list[Neo4jDict]:
        """Parameter dicts index-aligned with :attr:`neo4j_create_cmds`.

        The relationship's parameters are extended with the endpoint entity
        names that its MATCH clauses look up.
        """
        relationship_args = {
            **self.relationship.neo4j_create_args,
            "from_name": self.from_entity.entity_name,
            "to_name": self.to_entity.entity_name,
        }
        return [
            self.from_entity.neo4j_create_args,
            self.to_entity.neo4j_create_args,
            relationship_args,
        ]


@dataclasses.dataclass(frozen=True)
class EntityKnowledgeGraphD(DomainProtocol[entity_pb2.EntityKnowledgeGraph]):
    """A collection of entity relationships; mirrors ``entity_pb2.EntityKnowledgeGraph``."""

    # NOTE(review): frozen=True makes dataclasses generate __hash__ from the
    # fields, but a list is unhashable, so hash() on an instance raises
    # TypeError; the list itself also remains mutable.
    entity_relationships: list[EntityRelationshipD]

    @property
    def id(self) -> str:
        """Content hash: SHA-256 hex digest of the (deterministically) serialized proto."""
        return _content_hash(self.to_proto())

    @classmethod
    def _from_proto(cls, proto: entity_pb2.EntityKnowledgeGraph) -> EntityKnowledgeGraphD:
        """Build an instance from a proto message, converting each relationship."""
        return cls(
            entity_relationships=[
                EntityRelationshipD._from_proto(entity_relationship)
                for entity_relationship in proto.entity_relationships
            ]
        )

    def to_proto(self) -> entity_pb2.EntityKnowledgeGraph:
        """Convert to an ``entity_pb2.EntityKnowledgeGraph`` message."""
        return entity_pb2.EntityKnowledgeGraph(
            entity_relationships=[
                entity_relationship.to_proto() for entity_relationship in self.entity_relationships
            ]
        )