import json
from typing import Optional

from proto.entity_pb2 import PredictedMovement
from tabulate import tabulate

from domain.entity_d import (
    EntityD,
    EntityKnowledgeGraphD,
    EntityRelationshipD,
    RelationshipD,
)
from llm_handler.openai_handler import (
    ChatCompletionMessageParam,
    ChatModelVersion,
    OpenAIHandler,
)
from utils.dates import parse_date

FUZZY_MATCH_ENTITIES_PROMPT = '''
You are an expert in financial analysis. You will be given two lists of entities.
Your task is to output a semantic mapping from the entities in list A to the entities in list B.
An entity in list A that is semantically similar to an entity in list B should be mapped to that entity.
If there is no reasonable semantic match for an entity in list A, output an empty string.
Output should be in the format of a JSON object.
Ensure the entity keys are in the same order as the input list A.

Input:
List A: ["BofA", "Bank of America Corp", "GDP", "Inflation", "Yen"]
List B: ["Bank of America", "inflation", "Gross Domestic Product", "oil"]

Output:
{
    "BofA": "Bank of America",
    "Bank of America Corp": "Bank of America",
    "GDP": "Gross Domestic Product",
    "Inflation": "inflation",
    "Yen": ""
}
'''


class EvaluationEngine:
    _handler: OpenAIHandler
    _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
    _TEMPERATURE: float = 0.2

    def __init__(self,
                 ground_truth_kg: EntityKnowledgeGraphD,
                 openai_handler: Optional[OpenAIHandler] = None,
                 model_version: Optional[ChatModelVersion] = None):
        self._handler = openai_handler or OpenAIHandler()
        self._model_version = model_version or self._MODEL_VERSION

        # Adjacency-list representation of the ground truth knowledge graph,
        # keyed by the name of the entity each relationship points to.
        self.kg: dict[str, list[EntityRelationshipD]] = {}
        for entity_relationship in ground_truth_kg.entity_relationships:
            to_entity_name = entity_relationship.to_entity.entity_name
            relationships = self.kg.get(to_entity_name, [])
            relationships.append(entity_relationship)
            self.kg[to_entity_name] = relationships

    def _get_thesis_to_gt_entity_map(self, thesis_kg: EntityKnowledgeGraphD) -> dict[str, str]:
        """Fuzzy-matches thesis entity names to ground truth entity names via the LLM."""
        thesis_entities = []
        for entity_relationship in thesis_kg.entity_relationships:
            thesis_entities.append(entity_relationship.to_entity.entity_name)

        # LLM call to map each thesis entity to its closest ground truth entity
        messages: list[ChatCompletionMessageParam] = [{
            "role": "system",
            "content": FUZZY_MATCH_ENTITIES_PROMPT
        }, {
            "role": "user",
            "content": f"List A: {thesis_entities}\nList B: {list(self.kg.keys())}"
        }]
        completion_text = self._handler.get_chat_completion(messages=messages,
                                                            model=self._model_version,
                                                            temperature=self._TEMPERATURE,
                                                            response_format={"type": "json_object"})
        thesis_to_gt_entity_mapping: dict[str, str] = json.loads(completion_text)
        return thesis_to_gt_entity_mapping

    def _get_relationships_matching_timeperiod(
            self, gt_kg_to_node: str, relationship: RelationshipD) -> list[EntityRelationshipD]:
        """Returns ground truth relationships for `gt_kg_to_node` whose timeframe contains the thesis relationship's timeframe."""
        matching_relationships = []
        thesis_relationship_start = parse_date(relationship.start_date)
        thesis_relationship_end = parse_date(relationship.end_date)
        for gt_relationship in self.kg[gt_kg_to_node]:
            gt_relationship_start = parse_date(gt_relationship.relationship.start_date)
            gt_relationship_end = parse_date(gt_relationship.relationship.end_date)
            if (gt_relationship_start <= thesis_relationship_start <= gt_relationship_end and
                    gt_relationship_start <= thesis_relationship_end <= gt_relationship_end):
                # thesis relationship timeframe falls entirely within the gt relationship timeframe
                matching_relationships.append(gt_relationship)
        return matching_relationships
    def evaluate_thesis(
        self, thesis_kg: EntityKnowledgeGraphD
    ) -> list[tuple[EntityRelationshipD, bool, Optional[EntityRelationshipD]]]:
        """Returns one (thesis relationship, supported, ground truth relationship) tuple per piece
        of matching evidence; the third element is None when no evidence is found."""
        thesis_to_kg_map = self._get_thesis_to_gt_entity_map(thesis_kg)
        results = []
        for thesis_relationship in thesis_kg.entity_relationships:
            thesis_to_node = thesis_relationship.to_entity.entity_name
            kg_node = thesis_to_kg_map[thesis_to_node]
            if not kg_node:
                # no matching entity in KG
                results.append((thesis_relationship, False, None))
                continue
            matching_relationships = self._get_relationships_matching_timeperiod(
                kg_node, thesis_relationship.relationship)
            for entity_relationship in matching_relationships:
                if entity_relationship.relationship.predicted_movement == thesis_relationship.relationship.predicted_movement:
                    results.append((thesis_relationship, True, entity_relationship))
                else:
                    results.append((thesis_relationship, False, entity_relationship))
            if len(matching_relationships) == 0:
                results.append((thesis_relationship, False, None))
        return results

    def evaluate_and_display_thesis(self, thesis_kg: EntityKnowledgeGraphD) -> str:
        """Renders the evaluation results as an HTML table."""
        results = self.evaluate_thesis(thesis_kg)
        # display labels for PredictedMovement values
        int_to_str = {1: "Neutral", 2: "Increase", 3: "Decrease"}
        headers = ["Thesis Claim", "Supported by KG", "Related KG Relationship"]
        table_data = []
        for triplet in results:
            claim_entity = triplet[0].to_entity.entity_name
            claim_movement = int_to_str[triplet[0].relationship.predicted_movement]
            claim = f'{claim_entity} {claim_movement}'
            if triplet[2]:
                evidence = int_to_str[triplet[2].relationship.predicted_movement]
                evidence += f' ({triplet[2].from_entity.entity_name})'
            else:
                evidence = "No evidence in KG"
            table_data.append([claim, triplet[1], evidence])
        return tabulate(table_data, tablefmt="html", headers=headers)


if __name__ == '__main__':
    # TODO: extract the cases into pytest tests
    kg = EntityKnowledgeGraphD(entity_relationships=[
        EntityRelationshipD(from_entity=EntityD(entity_id='3', entity_name="analyst A"),
                            relationship=RelationshipD(
                                relationship_id='2',
                                start_date='2021-01-01',
                                end_date='2024-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='5', entity_name="analyst B"),
                            relationship=RelationshipD(
                                relationship_id='3',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_DECREASE),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='7', entity_name="analyst C"),
                            relationship=RelationshipD(
                                relationship_id='4',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='9', entity_name="analyst D"),
                            relationship=RelationshipD(
                                relationship_id='5',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
                            to_entity=EntityD(entity_id='10', entity_name="USD")),
        EntityRelationshipD(  # out of time range for thesis
            from_entity=EntityD(entity_id='9', entity_name="analyst E"),
            relationship=RelationshipD(
                relationship_id='5',
                start_date='2024-01-01',
                end_date='2024-12-31',
                source_text='',
                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
            to_entity=EntityD(entity_id='10', entity_name="USD")),
    ])
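    # The thesis claims below assert an increase for "Gross Domestic Product", "US$", and "Yen"
    # over 2021. Assuming the LLM fuzzy match maps these to "GDP", "USD", and no match
    # respectively (the mapping is model-dependent, so this is an expectation rather than a
    # guarantee): the GDP claim should be supported by analyst A and contradicted by analysts
    # B and C, the USD claim should find only analyst D's neutral prediction in range, and the
    # Yen claim should have no evidence in the KG.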
    thesis_claims = [
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="Gross Domestic Product")),
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="US$")),
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="Yen")),
    ]
    thesis = EntityKnowledgeGraphD(entity_relationships=thesis_claims)

    eval_engine = EvaluationEngine(kg)
    print(eval_engine.evaluate_and_display_thesis(thesis))