# metisllm-dashboard/query_pipeline/evaluation_engine.py
import json
from typing import Optional
from proto.entity_pb2 import PredictedMovement
from tabulate import tabulate
from domain.entity_d import (
EntityD,
EntityKnowledgeGraphD,
EntityRelationshipD,
RelationshipD,
)
from llm_handler.openai_handler import (
ChatCompletionMessageParam,
ChatModelVersion,
OpenAIHandler,
)
from utils.dates import parse_date
FUZZY_MATCH_ENTITIES_PROMPT = '''
You are an expert in financial analysis. You will be given two lists of entities. Your task is to output a semantic mapping from the entities in list A to the entities in list B: an entity in list A that is semantically similar to an entity in list B should be mapped to it. If there is no reasonable semantic match for an entity in list A, output an empty string. Output should be a JSON object. Ensure the entity keys are in the same order as the input list A.
Input:
List A: ["BofA", "Bank of Amerca Corp" "GDP", "Inflation", "Yen"]
List B: ["Bank of America", "inflation", "Gross Domestic Product", "oil"]
Output:
{
"BofA": "Bank of America",
"Bank of America Corp": "Bank of America",
"GDP": "Gross Domestic Product",
"Inflation": "inflation",
"Yen": ""
}
'''
class EvaluationEngine:
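    """Evaluates a thesis knowledge graph against a ground-truth knowledge graph.

    Thesis entities are fuzzy-matched to ground-truth entities via an LLM call, and
    each thesis claim is checked against ground-truth relationships covering the
    same time period.
    """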
_handler: OpenAIHandler
_MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
_TEMPERATURE: float = 0.2
def __init__(self,
ground_truth_kg: EntityKnowledgeGraphD,
openai_handler: Optional[OpenAIHandler] = None,
model_version: Optional[ChatModelVersion] = None):
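        """Builds the evaluation engine around a ground-truth knowledge graph.

        Args:
            ground_truth_kg: knowledge graph of relationships to evaluate claims against.
            openai_handler: optional pre-configured handler; a default OpenAIHandler is used if omitted.
            model_version: optional chat model override; defaults to _MODEL_VERSION.
        """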
self._handler = openai_handler or OpenAIHandler()
self._model_version = model_version or self._MODEL_VERSION
        # Build an adjacency-list view of the ground-truth KG: to-entity name -> incoming relationships.
self.kg: dict[str, list[EntityRelationshipD]] = {}
for entity_relationship in ground_truth_kg.entity_relationships:
to_entity_name = entity_relationship.to_entity.entity_name
relationships = self.kg.get(to_entity_name, [])
relationships.append(entity_relationship)
self.kg[to_entity_name] = relationships
def _get_thesis_to_gt_entity_map(self, thesis_kg: EntityKnowledgeGraphD) -> dict[str, str]:
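        """Uses an LLM to fuzzy-match thesis entity names to ground-truth entity names.

        Returns a dict mapping each thesis entity name to the matched ground-truth
        entity name, or to an empty string when no reasonable match exists.
        """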
thesis_entities = []
for entity_relationship in thesis_kg.entity_relationships:
thesis_entities.append(entity_relationship.to_entity.entity_name)
        # LLM call to fuzzy-match thesis entities to ground-truth entities
messages: list[ChatCompletionMessageParam] = [
{
"role": "system", "content": FUZZY_MATCH_ENTITIES_PROMPT
}, {
"role": "user",
"content": f"List A: {thesis_entities}\nList B: {list(self.kg.keys())}"
}
]
completion_text = self._handler.get_chat_completion(messages=messages,
model=self._model_version,
temperature=self._TEMPERATURE,
response_format={"type": "json_object"})
thesis_to_gt_entity_mapping: dict[str, str] = json.loads(completion_text)
return thesis_to_gt_entity_mapping
def _get_relationships_matching_timeperiod(
self, gt_kg_to_node: str, relationship: RelationshipD) -> list[EntityRelationshipD]:
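        """Returns ground-truth relationships on `gt_kg_to_node` whose timeframe fully contains the thesis relationship's timeframe."""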
matching_relationships = []
thesis_relationship_start = parse_date(relationship.start_date)
thesis_relationship_end = parse_date(relationship.end_date)
for gt_relationship in self.kg[gt_kg_to_node]:
gt_relationship_start = parse_date(gt_relationship.relationship.start_date)
gt_relationship_end = parse_date(gt_relationship.relationship.end_date)
            if (gt_relationship_start <= thesis_relationship_start <= gt_relationship_end
                    and gt_relationship_start <= thesis_relationship_end <= gt_relationship_end):
                # thesis relationship timeframe falls entirely within the gt relationship timeframe
matching_relationships.append(gt_relationship)
return matching_relationships
def evaluate_thesis(
self, thesis_kg: EntityKnowledgeGraphD
) -> list[tuple[EntityRelationshipD, bool, Optional[EntityRelationshipD]]]:
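        """Evaluates each thesis claim against the ground-truth knowledge graph.

        Returns one tuple per (claim, matching ground-truth relationship) pair:
        (thesis relationship, whether the predicted movements agree, the ground-truth
        relationship used as evidence, or None when no entity or timeframe match exists).
        """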
thesis_to_kg_map = self._get_thesis_to_gt_entity_map(thesis_kg)
results = []
for thesis_relationship in thesis_kg.entity_relationships:
thesis_to_node = thesis_relationship.to_entity.entity_name
            kg_node = thesis_to_kg_map.get(thesis_to_node, "")  # treat entities the LLM omitted as unmatched
if not kg_node: # no matching entity in KG
results.append((thesis_relationship, False, None))
continue
matching_relationships = self._get_relationships_matching_timeperiod(
kg_node, thesis_relationship.relationship)
for entity_relationship in matching_relationships:
if entity_relationship.relationship.predicted_movement == thesis_relationship.relationship.predicted_movement:
results.append((thesis_relationship, True, entity_relationship))
else:
results.append((thesis_relationship, False, entity_relationship))
if len(matching_relationships) == 0:
results.append((thesis_relationship, False, None))
return results
    def evaluate_and_display_thesis(self, thesis_kg: EntityKnowledgeGraphD) -> str:
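        """Runs evaluate_thesis and renders the results as an HTML table via tabulate."""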
results = self.evaluate_thesis(thesis_kg)
        # Display names for PredictedMovement enum values.
        movement_to_str = {
            PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL: "Neutral",
            PredictedMovement.PREDICTED_MOVEMENT_INCREASE: "Increase",
            PredictedMovement.PREDICTED_MOVEMENT_DECREASE: "Decrease",
        }
        headers = ["Thesis Claim", "Supported by KG", "Related KG Relationship"]
        table_data = []
        for thesis_relationship, supported, gt_relationship in results:
            claim_entity = thesis_relationship.to_entity.entity_name
            claim_movement = movement_to_str[thesis_relationship.relationship.predicted_movement]
            claim = f'{claim_entity} {claim_movement}'
            if gt_relationship:
                evidence = movement_to_str[gt_relationship.relationship.predicted_movement]
                evidence += f' ({gt_relationship.from_entity.entity_name})'
            else:
                evidence = "No evidence in KG"
            table_data.append([claim, supported, evidence])
return tabulate(table_data, tablefmt="html", headers=headers)
if __name__ == '__main__':
# TODO: extract the cases into pytest tests
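    # Ground-truth KG: analyst predictions about GDP and USD over various timeframes.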
kg = EntityKnowledgeGraphD(entity_relationships=[
EntityRelationshipD(from_entity=EntityD(entity_id='3', entity_name="analyst A"),
relationship=RelationshipD(
relationship_id='2',
start_date='2021-01-01',
end_date='2024-12-31',
source_text='',
predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
to_entity=EntityD(entity_id='1', entity_name="GDP")),
EntityRelationshipD(from_entity=EntityD(entity_id='5', entity_name="analyst B"),
relationship=RelationshipD(
relationship_id='3',
start_date='2021-01-01',
end_date='2021-12-31',
source_text='',
predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_DECREASE),
to_entity=EntityD(entity_id='1', entity_name="GDP")),
EntityRelationshipD(from_entity=EntityD(entity_id='7', entity_name="analyst C"),
relationship=RelationshipD(
relationship_id='4',
start_date='2021-01-01',
end_date='2021-12-31',
source_text='',
predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
to_entity=EntityD(entity_id='1', entity_name="GDP")),
EntityRelationshipD(from_entity=EntityD(entity_id='9', entity_name="analyst D"),
relationship=RelationshipD(
relationship_id='5',
start_date='2021-01-01',
end_date='2021-12-31',
source_text='',
predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
to_entity=EntityD(entity_id='10', entity_name="USD")),
EntityRelationshipD( # out of time range for thesis
from_entity=EntityD(entity_id='9', entity_name="analyst E"),
relationship=RelationshipD(
relationship_id='5',
start_date='2024-01-01',
end_date='2024-12-31',
source_text='',
predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
to_entity=EntityD(entity_id='10', entity_name="USD")),
])
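    # Thesis claims to evaluate: GDP, USD, and Yen all predicted to increase over 2021.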
thesis_claims = [
EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
relationship=RelationshipD(
relationship_id='1',
start_date='2021-01-01',
end_date='2021-12-31',
source_text='',
predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
to_entity=EntityD(entity_id='1', entity_name="Gross Domestic Product")),
EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
relationship=RelationshipD(
relationship_id='1',
start_date='2021-01-01',
end_date='2021-12-31',
source_text='',
predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
to_entity=EntityD(entity_id='1', entity_name="US$")),
EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
relationship=RelationshipD(
relationship_id='1',
start_date='2021-01-01',
end_date='2021-12-31',
source_text='',
predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
to_entity=EntityD(entity_id='1', entity_name="Yen")),
]
thesis = EntityKnowledgeGraphD(entity_relationships=thesis_claims)
eval_engine = EvaluationEngine(kg)
    print(eval_engine.evaluate_and_display_thesis(thesis))