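# Spaces: Sleeping Sleeping
# NOTE(review): the line above appears to be page chrome captured with this
# file rather than part of the module -- confirm and drop if so.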
from __future__ import annotations

import dataclasses
import hashlib
import logging
from typing import TypeAlias, Union

import proto.entity_pb2 as entity_pb2
from domain.domain_protocol import DomainProtocol
from utils.dates import parse_date

Neo4jDict: TypeAlias = dict[str, Union[str, int, bool, list[str], list[int], list[bool]]] | |
# NOTE(review): the decorators below are restored from how this file uses the
# class: instances are built with keyword arguments, ``_from_proto`` is called
# on the class itself, and ``neo4j_create_cmd`` / ``neo4j_create_args`` are
# read as attributes (never called). ``id`` is left as a plain method because
# nothing in this file shows how it is accessed -- confirm against
# ``DomainProtocol``.
@dataclasses.dataclass
class EntityD(DomainProtocol[entity_pb2.Entity]):
    """Domain object mirroring the ``entity_pb2.Entity`` message."""

    entity_id: str
    entity_name: str

    def id(self) -> str:
        """Return this entity's ``entity_id``."""
        return self.entity_id

    @classmethod
    def _from_proto(cls, proto: entity_pb2.Entity) -> EntityD:
        """Build an ``EntityD`` from the fields of an ``Entity`` message."""
        return EntityD(entity_id=proto.entity_id, entity_name=proto.entity_name)

    def to_proto(self) -> entity_pb2.Entity:
        """Return an ``Entity`` message carrying this object's fields."""
        return entity_pb2.Entity(entity_id=self.entity_id, entity_name=self.entity_name)

    @property
    def neo4j_create_cmd(self) -> str:
        """Cypher statement that merges an ``Entity`` node keyed on its name."""
        # TODO store entity_id?
        # NOTE(review): ``$pdf_file`` is not part of ``neo4j_create_args``;
        # presumably the caller that runs the query supplies it -- confirm.
        return "MERGE (e:Entity {name: $name}) ON CREATE SET e.pdf_file = $pdf_file"

    @property
    def neo4j_create_args(self) -> Neo4jDict:
        """Parameters for ``neo4j_create_cmd`` that this object can provide."""
        return {
            "name": self.entity_name,
        }
# NOTE(review): the decorators below are restored from how this file uses the
# class: it defines ``__post_init__`` and is built with keyword arguments,
# ``_from_proto`` / ``from_string`` take ``cls``, and ``neo4j_create_cmd`` /
# ``neo4j_create_args`` are read as attributes elsewhere in the file. ``id`` is
# left as a plain method because nothing in this file shows how it is
# accessed -- confirm against ``DomainProtocol``.
@dataclasses.dataclass
class RelationshipD(DomainProtocol[entity_pb2.Relationship]):
    """Domain object mirroring the ``entity_pb2.Relationship`` message."""

    relationship_id: str
    start_date: str
    end_date: str
    source_text: str
    predicted_movement: entity_pb2.PredictedMovement

    def id(self) -> str:
        """Return this relationship's ``relationship_id``."""
        return self.relationship_id

    def __post_init__(self) -> None:
        """Log a warning when both dates are set and the end precedes the start."""
        if not (self.start_date and self.end_date):
            # Nothing to compare when either date is empty.
            return
        start = parse_date(self.start_date)
        end = parse_date(self.end_date)
        if end < start:
            # Deliberately lenient: an inverted range is logged, not raised
            # (a ValueError here was previously disabled).
            logging.warning(
                "end_date %s is before start_date %s",
                self.end_date,
                self.start_date,
            )

    @classmethod
    def _from_proto(cls, proto: entity_pb2.Relationship) -> RelationshipD:
        """Build a ``RelationshipD`` from the fields of a ``Relationship`` message."""
        return RelationshipD(
            relationship_id=proto.relationship_id,
            start_date=proto.start_date,
            end_date=proto.end_date,
            source_text=proto.source_text,
            predicted_movement=proto.predicted_movement,
        )

    def to_proto(self) -> entity_pb2.Relationship:
        """Return a ``Relationship`` message carrying this object's fields."""
        return entity_pb2.Relationship(
            relationship_id=self.relationship_id,
            start_date=self.start_date,
            end_date=self.end_date,
            source_text=self.source_text,
            predicted_movement=self.predicted_movement,
        )

    @classmethod
    def from_string(cls, relationship: str) -> entity_pb2.PredictedMovement:
        """Map a movement name to its ``PredictedMovement`` enum value.

        Only the NEUTRAL, INCREASE and DECREASE names are recognised; any
        other string yields ``PREDICTED_MOVEMENT_UNSPECIFIED``.
        """
        movement = entity_pb2.PredictedMovement
        known = {
            "PREDICTED_MOVEMENT_NEUTRAL": movement.PREDICTED_MOVEMENT_NEUTRAL,
            "PREDICTED_MOVEMENT_INCREASE": movement.PREDICTED_MOVEMENT_INCREASE,
            "PREDICTED_MOVEMENT_DECREASE": movement.PREDICTED_MOVEMENT_DECREASE,
        }
        return known.get(relationship, movement.PREDICTED_MOVEMENT_UNSPECIFIED)

    @property
    def neo4j_create_cmd(self) -> str:
        """Cypher statement that merges a ``Relationship`` edge between two entities.

        Both endpoint ``Entity`` nodes are matched by name, so they must
        already exist for the ``MERGE`` to run.
        """
        # NOTE(review): ``$from_name``, ``$to_name`` and ``$pdf_file`` are not
        # part of ``neo4j_create_args``; the first two are added by
        # ``EntityRelationshipD.neo4j_create_args``, ``$pdf_file`` presumably
        # by the caller that runs the query -- confirm.
        return (
            "MATCH (from:Entity {name: $from_name})\n"
            "MATCH (to:Entity {name: $to_name})\n"
            "MERGE (from) -[r:Relationship {start_date: $start_date, end_date: $end_date, "
            "predicted_movement: $predicted_movement}]-> (to) "
            "ON CREATE SET r.source_text = $source_text, r.pdf_file = $pdf_file"
        )

    @property
    def neo4j_create_args(self) -> Neo4jDict:
        """Parameters for ``neo4j_create_cmd`` that this object can provide.

        ``predicted_movement`` is passed as the enum value's name, not its number.
        """
        return {
            "start_date": self.start_date,
            "end_date": self.end_date,
            "predicted_movement": entity_pb2.PredictedMovement.Name(self.predicted_movement),
            "source_text": self.source_text,
        }
# NOTE(review): ``@dataclasses.dataclass`` and ``@classmethod`` are restored
# from how this file uses the class (keyword construction; ``_from_proto``
# called on the class by ``EntityKnowledgeGraphD``). ``id``,
# ``neo4j_create_cmds`` and ``neo4j_create_args`` are left as plain methods
# because nothing in this file shows how they are accessed -- confirm against
# the callers.
@dataclasses.dataclass
class EntityRelationshipD(DomainProtocol[entity_pb2.EntityRelationship]):
    """A directed ``from_entity -> to_entity`` link with its relationship details."""

    from_entity: EntityD
    relationship: RelationshipD
    to_entity: EntityD

    def id(self) -> str:
        """Return the SHA-256 hex digest of this object's serialized proto."""
        serialized = self.to_proto().SerializeToString()
        return hashlib.sha256(serialized).hexdigest()

    @classmethod
    def _from_proto(cls, proto: entity_pb2.EntityRelationship) -> EntityRelationshipD:
        """Build an ``EntityRelationshipD`` by converting each nested message."""
        return EntityRelationshipD(
            from_entity=EntityD._from_proto(proto.from_entity),
            relationship=RelationshipD._from_proto(proto.relationship),
            to_entity=EntityD._from_proto(proto.to_entity),
        )

    def to_proto(self) -> entity_pb2.EntityRelationship:
        """Return an ``EntityRelationship`` message built from the three parts."""
        return entity_pb2.EntityRelationship(
            from_entity=self.from_entity.to_proto(),
            relationship=self.relationship.to_proto(),
            to_entity=self.to_entity.to_proto(),
        )

    def neo4j_create_cmds(self) -> list[str]:
        """Return the Cypher statements for both entities, then the relationship.

        The order matches ``neo4j_create_args``: source entity, target entity,
        relationship.
        """
        parts = (self.from_entity, self.to_entity, self.relationship)
        return [part.neo4j_create_cmd for part in parts]

    def neo4j_create_args(self) -> list[Neo4jDict]:
        """Return one parameter mapping per statement in ``neo4j_create_cmds``.

        The relationship's mapping is extended with ``from_name`` and
        ``to_name`` taken from the two entities' names.
        """
        relationship_args: Neo4jDict = {
            **self.relationship.neo4j_create_args,
            'from_name': self.from_entity.entity_name,
            'to_name': self.to_entity.entity_name,
        }
        return [
            self.from_entity.neo4j_create_args,
            self.to_entity.neo4j_create_args,
            relationship_args,
        ]
# NOTE(review): ``@dataclasses.dataclass`` and ``@classmethod`` are restored
# from how this file uses the class (keyword construction; ``_from_proto``
# takes ``cls``). ``id`` is left as a plain method because nothing in this
# file shows how it is accessed -- confirm against ``DomainProtocol``.
@dataclasses.dataclass
class EntityKnowledgeGraphD(DomainProtocol[entity_pb2.EntityKnowledgeGraph]):
    """Domain object mirroring ``entity_pb2.EntityKnowledgeGraph``."""

    entity_relationships: list[EntityRelationshipD]

    def id(self) -> str:
        """Return the SHA-256 hex digest of this object's serialized proto."""
        serialized = self.to_proto().SerializeToString()
        return hashlib.sha256(serialized).hexdigest()

    @classmethod
    def _from_proto(cls, proto: entity_pb2.EntityKnowledgeGraph) -> EntityKnowledgeGraphD:
        """Build an ``EntityKnowledgeGraphD`` by converting each relationship message."""
        converted = [
            EntityRelationshipD._from_proto(item)
            for item in proto.entity_relationships
        ]
        return EntityKnowledgeGraphD(entity_relationships=converted)

    def to_proto(self) -> entity_pb2.EntityKnowledgeGraph:
        """Return an ``EntityKnowledgeGraph`` message holding every relationship's proto."""
        messages = [item.to_proto() for item in self.entity_relationships]
        return entity_pb2.EntityKnowledgeGraph(entity_relationships=messages)