Gateston Johns
first real commit
9041389
raw
history blame contribute delete
No virus
6.41 kB
from __future__ import annotations
import dataclasses
import hashlib
import logging
from typing import TypeAlias, Union
import proto.entity_pb2 as entity_pb2
from domain.domain_protocol import DomainProtocol
from utils.dates import parse_date
# Shape of the parameter mappings returned by the ``neo4j_create_args``
# properties in this module: string keys mapped to a str/int/bool scalar or to
# a list of one of those scalar types.
Neo4jDict: TypeAlias = dict[str, Union[str, int, bool, list[str], list[int], list[bool]]]
@dataclasses.dataclass(frozen=True)
class EntityD(DomainProtocol[entity_pb2.Entity]):
    """Immutable domain counterpart of an ``entity_pb2.Entity`` message.

    Holds the entity's identifier and name, converts to and from the protobuf
    message, and exposes the Cypher statement and parameters used to merge
    the entity as an ``Entity`` node.
    """

    entity_id: str
    entity_name: str

    @property
    def id(self) -> str:
        """Return ``entity_id``."""
        return self.entity_id

    @classmethod
    def _from_proto(cls, proto: entity_pb2.Entity) -> EntityD:
        """Build an ``EntityD`` from the id and name carried by ``proto``."""
        identifier = proto.entity_id
        name = proto.entity_name
        return EntityD(entity_id=identifier, entity_name=name)

    def to_proto(self) -> entity_pb2.Entity:
        """Return an ``entity_pb2.Entity`` carrying this entity's id and name."""
        message = entity_pb2.Entity(
            entity_id=self.entity_id,
            entity_name=self.entity_name,
        )
        return message

    @property
    def neo4j_create_cmd(self) -> str:
        """Return the Cypher statement that merges this entity by name."""
        # TODO store entity_id?
        # NOTE(review): ``$pdf_file`` is not among the parameters returned by
        # ``neo4j_create_args``; presumably the caller supplies it -- confirm.
        return (
            "MERGE (e:Entity {name: $name}) "
            "ON CREATE SET e.pdf_file = $pdf_file"
        )

    @property
    def neo4j_create_args(self) -> Neo4jDict:
        """Return the ``$name`` parameter for ``neo4j_create_cmd``."""
        params: Neo4jDict = {"name": self.entity_name}
        return params
@dataclasses.dataclass(frozen=True)
class RelationshipD(DomainProtocol[entity_pb2.Relationship]):
    """Immutable domain counterpart of an ``entity_pb2.Relationship`` message.

    Carries the relationship's identifier, date range (as strings), the source
    text and the predicted movement, plus the Cypher statement and parameters
    used to merge the relationship between two ``Entity`` nodes.
    """

    relationship_id: str
    start_date: str
    end_date: str
    source_text: str
    predicted_movement: entity_pb2.PredictedMovement

    @property
    def id(self) -> str:
        """Return ``relationship_id``."""
        return self.relationship_id

    def __post_init__(self) -> None:
        """Log a warning when both dates are set and the end precedes the start."""
        # Nothing to compare unless both dates are non-empty.
        if not (self.start_date and self.end_date):
            return

        start = parse_date(self.start_date)
        end = parse_date(self.end_date)
        if end < start:
            # An inverted range is only reported, not rejected: construction
            # still succeeds (a ValueError here was previously disabled).
            logging.warning(
                "end_date %s is before start_date %s",
                self.end_date,
                self.start_date,
            )

    @classmethod
    def _from_proto(cls, proto: entity_pb2.Relationship) -> RelationshipD:
        """Build a ``RelationshipD`` by copying each field from ``proto``."""
        return RelationshipD(
            relationship_id=proto.relationship_id,
            start_date=proto.start_date,
            end_date=proto.end_date,
            source_text=proto.source_text,
            predicted_movement=proto.predicted_movement,
        )

    def to_proto(self) -> entity_pb2.Relationship:
        """Return an ``entity_pb2.Relationship`` carrying this object's fields."""
        message = entity_pb2.Relationship(
            relationship_id=self.relationship_id,
            start_date=self.start_date,
            end_date=self.end_date,
            source_text=self.source_text,
            predicted_movement=self.predicted_movement,
        )
        return message

    @classmethod
    def from_string(cls, relationship: str) -> entity_pb2.PredictedMovement:
        """Map a movement name to its ``PredictedMovement`` enum value.

        Only the NEUTRAL, INCREASE and DECREASE names are recognised; any
        other input yields ``PREDICTED_MOVEMENT_UNSPECIFIED``.
        """
        movement = entity_pb2.PredictedMovement
        recognised = (
            ("PREDICTED_MOVEMENT_NEUTRAL", movement.PREDICTED_MOVEMENT_NEUTRAL),
            ("PREDICTED_MOVEMENT_INCREASE", movement.PREDICTED_MOVEMENT_INCREASE),
            ("PREDICTED_MOVEMENT_DECREASE", movement.PREDICTED_MOVEMENT_DECREASE),
        )
        for label, value in recognised:
            if relationship == label:
                return value
        return movement.PREDICTED_MOVEMENT_UNSPECIFIED

    @property
    def neo4j_create_cmd(self) -> str:
        """Return the Cypher statement that merges this relationship.

        The statement matches the two ``Entity`` nodes by ``$from_name`` and
        ``$to_name`` and merges a ``Relationship`` edge between them.
        """
        # Continuation lines are kept flush-left so the query text is
        # reproduced exactly.
        # NOTE(review): ``$from_name``, ``$to_name`` and ``$pdf_file`` are not
        # among the parameters returned by ``neo4j_create_args`` here.
        query = """MATCH (from:Entity {name: $from_name})
MATCH (to:Entity {name: $to_name})
MERGE (from) -[r:Relationship {start_date: $start_date, end_date: $end_date, predicted_movement: $predicted_movement}]-> (to) ON CREATE SET r.source_text = $source_text, r.pdf_file = $pdf_file"""
        return query

    @property
    def neo4j_create_args(self) -> Neo4jDict:
        """Return the relationship's own parameters for ``neo4j_create_cmd``."""
        # The enum is stored by its symbolic name rather than its number.
        movement_name = entity_pb2.PredictedMovement.Name(self.predicted_movement)
        params: Neo4jDict = {
            "start_date": self.start_date,
            "end_date": self.end_date,
            "predicted_movement": movement_name,
            "source_text": self.source_text,
        }
        return params
@dataclasses.dataclass(frozen=True)
class EntityRelationshipD(DomainProtocol[entity_pb2.EntityRelationship]):
    """Immutable triple of source entity, relationship and target entity.

    Domain counterpart of ``entity_pb2.EntityRelationship``.
    """

    from_entity: EntityD
    relationship: RelationshipD
    to_entity: EntityD

    @property
    def id(self) -> str:
        """Return the SHA-256 hex digest of the serialized protobuf form."""
        # NOTE(review): this assumes SerializeToString yields the same bytes
        # for equal messages -- confirm if these ids are persisted.
        payload = self.to_proto().SerializeToString()
        return hashlib.sha256(payload).hexdigest()

    @classmethod
    def _from_proto(cls, proto: entity_pb2.EntityRelationship) -> EntityRelationshipD:
        """Build an ``EntityRelationshipD`` by converting each part of ``proto``."""
        source = EntityD._from_proto(proto.from_entity)
        link = RelationshipD._from_proto(proto.relationship)
        target = EntityD._from_proto(proto.to_entity)
        return EntityRelationshipD(from_entity=source, relationship=link, to_entity=target)

    def to_proto(self) -> entity_pb2.EntityRelationship:
        """Return an ``entity_pb2.EntityRelationship`` built from the three parts."""
        source_msg = self.from_entity.to_proto()
        link_msg = self.relationship.to_proto()
        target_msg = self.to_entity.to_proto()
        return entity_pb2.EntityRelationship(
            from_entity=source_msg,
            relationship=link_msg,
            to_entity=target_msg,
        )

    @property
    def neo4j_create_cmds(self) -> list[str]:
        """Return the Cypher statements: source entity, target entity, relationship."""
        parts = (self.from_entity, self.to_entity, self.relationship)
        return [part.neo4j_create_cmd for part in parts]

    @property
    def neo4j_create_args(self) -> list[Neo4jDict]:
        """Return parameter dicts in the same order as ``neo4j_create_cmds``.

        The relationship's parameters are extended with ``from_name`` and
        ``to_name`` taken from the two entities' names.
        """
        # Copy first so the two endpoint names are appended after the
        # relationship's own keys.
        relationship_args = dict(self.relationship.neo4j_create_args)
        relationship_args['from_name'] = self.from_entity.entity_name
        relationship_args['to_name'] = self.to_entity.entity_name

        source_args = self.from_entity.neo4j_create_args
        target_args = self.to_entity.neo4j_create_args
        return [source_args, target_args, relationship_args]
@dataclasses.dataclass(frozen=True)
class EntityKnowledgeGraphD(DomainProtocol[entity_pb2.EntityKnowledgeGraph]):
    """Immutable collection of entity relationships.

    Domain counterpart of ``entity_pb2.EntityKnowledgeGraph``.
    """

    entity_relationships: list[EntityRelationshipD]

    @property
    def id(self) -> str:
        """Return the SHA-256 hex digest of the serialized protobuf form."""
        serialized = self.to_proto().SerializeToString()
        digest = hashlib.sha256(serialized)
        return digest.hexdigest()

    @classmethod
    def _from_proto(cls, proto: entity_pb2.EntityKnowledgeGraph) -> EntityKnowledgeGraphD:
        """Build an ``EntityKnowledgeGraphD`` by converting every relationship in ``proto``."""
        converted = list(map(EntityRelationshipD._from_proto, proto.entity_relationships))
        return EntityKnowledgeGraphD(entity_relationships=converted)

    def to_proto(self) -> entity_pb2.EntityKnowledgeGraph:
        """Return an ``entity_pb2.EntityKnowledgeGraph`` holding every relationship."""
        messages = list(map(EntityRelationshipD.to_proto, self.entity_relationships))
        return entity_pb2.EntityKnowledgeGraph(entity_relationships=messages)