Spaces:
Sleeping
Sleeping
import json | |
from typing import Optional | |
from proto.entity_pb2 import PredictedMovement | |
from tabulate import tabulate | |
from domain.entity_d import ( | |
EntityD, | |
EntityKnowledgeGraphD, | |
EntityRelationshipD, | |
RelationshipD, | |
) | |
from llm_handler.openai_handler import ( | |
ChatCompletionMessageParam, | |
ChatModelVersion, | |
OpenAIHandler, | |
) | |
from utils.dates import parse_date | |
# System prompt for the LLM call that fuzzy-matches thesis entity names (list A)
# to ground-truth knowledge-graph entity names (list B). The few-shot example
# must stay self-consistent: every entry of list A appears, in order, as a key
# of the example output, and unmatched entries map to an empty string.
FUZZY_MATCH_ENTITIES_PROMPT = '''
You are an expert in financial analysis. You will be given two lists of entities. Your task is to output a semantic mapping from the entities in list A to the entities in list B. This means an entity in list A that is semantically similar to an entity in list B should be mapped together. If there is no reasonable semantic match for an entity in list A, output an empty string. Output should be in the format of a JSON object. Ensure the entity keys are in the same order as the input list A.
Input:
List A: ["BofA", "Bank of America Corp", "GDP", "Inflation", "Yen"]
List B: ["Bank of America", "inflation", "Gross Domestic Product", "oil"]
Output:
{
"BofA": "Bank of America",
"Bank of America Corp": "Bank of America",
"GDP": "Gross Domestic Product",
"Inflation": "inflation",
"Yen": ""
}
'''
class EvaluationEngine:
    """Check thesis claims against a ground-truth entity knowledge graph.

    The ground-truth graph is indexed by the name of each relationship's target
    (``to_entity``). Thesis entity names are aligned to those names with an LLM
    call, then each thesis claim is compared with the ground-truth relationships
    whose time window contains the claim's time window.
    """

    _handler: OpenAIHandler
    _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
    _TEMPERATURE: float = 0.2
    # Display labels keyed by the integer value of ``predicted_movement``.
    # NOTE(review): presumably mirrors the PredictedMovement proto enum - confirm.
    _MOVEMENT_LABELS: dict[int, str] = {1: "Neutral", 2: 'Increase', 3: 'Decrease'}
    # Label used when a movement value has no entry in ``_MOVEMENT_LABELS``.
    _UNKNOWN_MOVEMENT_LABEL: str = "Unknown"

    def __init__(self,
                 ground_truth_kg: EntityKnowledgeGraphD,
                 openai_handler: Optional[OpenAIHandler] = None,
                 model_version: Optional[ChatModelVersion] = None):
        """Build the engine and index the ground-truth graph.

        Args:
            ground_truth_kg: Graph whose relationships serve as evidence.
            openai_handler: Handler used for the entity-matching call; a new
                ``OpenAIHandler`` is created when omitted.
            model_version: Chat model to use; defaults to ``_MODEL_VERSION``.
        """
        self._handler = openai_handler or OpenAIHandler()
        self._model_version = model_version or self._MODEL_VERSION

        # Adjacency-list view of the ground truth: target entity name -> all
        # relationships pointing at that entity.
        self.kg: dict[str, list[EntityRelationshipD]] = {}
        for entity_relationship in ground_truth_kg.entity_relationships:
            to_entity_name = entity_relationship.to_entity.entity_name
            self.kg.setdefault(to_entity_name, []).append(entity_relationship)

    def _get_thesis_to_gt_entity_map(self, thesis_kg: EntityKnowledgeGraphD) -> dict[str, str]:
        """Ask the LLM to map thesis entity names to ground-truth entity names.

        Returns:
            The parsed JSON object returned by the model. Per the prompt, keys
            are thesis entity names and values are ground-truth entity names,
            with an empty string meaning "no match". The model's output is not
            validated here, so callers must tolerate missing or unknown names.
        """
        # JSON object keys must be unique, so send each name once (first-seen order).
        thesis_entities = list(
            dict.fromkeys(entity_relationship.to_entity.entity_name
                          for entity_relationship in thesis_kg.entity_relationships))
        if not thesis_entities:
            # Nothing to match; skip the LLM round trip.
            return {}

        messages: list[ChatCompletionMessageParam] = [
            {
                "role": "system",
                "content": FUZZY_MATCH_ENTITIES_PROMPT
            },
            {
                "role": "user",
                "content": f"List A: {thesis_entities}\nList B: {list(self.kg.keys())}"
            },
        ]
        completion_text = self._handler.get_chat_completion(messages=messages,
                                                            model=self._model_version,
                                                            temperature=self._TEMPERATURE,
                                                            response_format={"type": "json_object"})
        thesis_to_gt_entity_mapping: dict[str, str] = json.loads(completion_text)
        return thesis_to_gt_entity_mapping

    def _get_relationships_matching_timeperiod(
            self, gt_kg_to_node: str, relationship: RelationshipD) -> list[EntityRelationshipD]:
        """Return ground-truth relationships whose time window contains the thesis window.

        Args:
            gt_kg_to_node: Ground-truth target entity name to look up.
            relationship: Thesis relationship supplying the start and end dates.

        Returns:
            Relationships targeting ``gt_kg_to_node`` for which both the thesis
            start date and the thesis end date fall inside the ground-truth
            window (bounds inclusive). Empty when the entity is not in the graph.
        """
        thesis_relationship_start = parse_date(relationship.start_date)
        thesis_relationship_end = parse_date(relationship.end_date)

        matching_relationships = []
        # The node name comes from the LLM, so it may not exist in the graph.
        for gt_relationship in self.kg.get(gt_kg_to_node, []):
            gt_relationship_start = parse_date(gt_relationship.relationship.start_date)
            gt_relationship_end = parse_date(gt_relationship.relationship.end_date)
            start_in_window = gt_relationship_start <= thesis_relationship_start <= gt_relationship_end
            end_in_window = gt_relationship_start <= thesis_relationship_end <= gt_relationship_end
            if start_in_window and end_in_window:
                # Thesis timeframe lies entirely within the ground-truth timeframe.
                matching_relationships.append(gt_relationship)
        return matching_relationships

    def evaluate_thesis(
        self, thesis_kg: EntityKnowledgeGraphD
    ) -> list[tuple[EntityRelationshipD, bool, Optional[EntityRelationshipD]]]:
        """Compare every thesis claim with the ground-truth graph.

        Returns:
            One ``(claim, supported, evidence)`` tuple per time-matching
            ground-truth relationship, where ``supported`` is True when the
            predicted movements are equal. A claim with no matched entity or no
            time-matching relationship yields a single ``(claim, False, None)``.
        """
        thesis_to_kg_map = self._get_thesis_to_gt_entity_map(thesis_kg)
        results = []
        for thesis_relationship in thesis_kg.entity_relationships:
            thesis_to_node = thesis_relationship.to_entity.entity_name
            # The LLM may omit a key; treat that the same as "no match".
            kg_node = thesis_to_kg_map.get(thesis_to_node)
            matching_relationships = []
            if kg_node:
                matching_relationships = self._get_relationships_matching_timeperiod(
                    kg_node, thesis_relationship.relationship)

            if not matching_relationships:
                # No matching entity in KG, or no relationship covering the timeframe.
                results.append((thesis_relationship, False, None))
                continue

            claimed_movement = thesis_relationship.relationship.predicted_movement
            for entity_relationship in matching_relationships:
                is_supported = entity_relationship.relationship.predicted_movement == claimed_movement
                results.append((thesis_relationship, is_supported, entity_relationship))
        return results

    def evaluate_and_display_thesis(self, thesis_kg: EntityKnowledgeGraphD) -> str:
        """Evaluate the thesis and render the results as an HTML table.

        Returns:
            The table produced by ``tabulate`` with ``tablefmt="html"``; one row
            per result of ``evaluate_thesis``.
        """
        results = self.evaluate_thesis(thesis_kg)
        headers = ["Thesis Claim", "Supported by KG", "Related KG Relationship"]
        table_data = []
        for claim_relationship, is_supported, kg_relationship in results:
            claim_entity = claim_relationship.to_entity.entity_name
            claim_movement = self._movement_label(claim_relationship.relationship.predicted_movement)
            claim = f'{claim_entity} {claim_movement}'
            if kg_relationship is not None:
                evidence = self._movement_label(kg_relationship.relationship.predicted_movement)
                evidence += f' ({kg_relationship.from_entity.entity_name}) '
            else:
                evidence = "No evidence in KG"
            table_data.append([claim, is_supported, evidence])
        return tabulate(table_data, tablefmt="html", headers=headers)

    def _movement_label(self, movement) -> str:
        """Return the display label for a movement value, tolerating unmapped values."""
        return self._MOVEMENT_LABELS.get(movement, self._UNKNOWN_MOVEMENT_LABEL)
if __name__ == '__main__':
    # Manual smoke run: builds a small ground-truth graph and a three-claim
    # thesis, evaluates the thesis (this performs a real LLM call through
    # OpenAIHandler) and prints the resulting HTML table.
    # TODO: extract the cases into pytest tests
    # NOTE(review): "analyst D" and "analyst E" share entity_id '9' and
    # relationship_id '5' - confirm whether the duplicate ids are intended.
    kg = EntityKnowledgeGraphD(entity_relationships=[
        EntityRelationshipD(from_entity=EntityD(entity_id='3', entity_name="analyst A"),
                            relationship=RelationshipD(
                                relationship_id='2',
                                start_date='2021-01-01',
                                end_date='2024-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='5', entity_name="analyst B"),
                            relationship=RelationshipD(
                                relationship_id='3',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_DECREASE),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='7', entity_name="analyst C"),
                            relationship=RelationshipD(
                                relationship_id='4',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='9', entity_name="analyst D"),
                            relationship=RelationshipD(
                                relationship_id='5',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
                            to_entity=EntityD(entity_id='10', entity_name="USD")),
        EntityRelationshipD(  # out of time range for thesis
            from_entity=EntityD(entity_id='9', entity_name="analyst E"),
            relationship=RelationshipD(
                relationship_id='5',
                start_date='2024-01-01',
                end_date='2024-12-31',
                source_text='',
                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
            to_entity=EntityD(entity_id='10', entity_name="USD")),
    ])
    thesis_claims = [
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="Gross Domestic Product")),
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="US$")),
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="Yen")),
    ]
    thesis = EntityKnowledgeGraphD(entity_relationships=thesis_claims)
    eval_engine = EvaluationEngine(kg)
    # The method only returns the rendered table; print it so the run shows a result.
    print(eval_engine.evaluate_and_display_thesis(thesis))