from typing import Dict, List

from lightning.pytorch.callbacks import Callback

from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.relik_reader_predictor import RelikReaderPredictor
from relik.reader.utils.metrics import f1_measure, safe_divide
from relik.reader.utils.special_symbols import NME_SYMBOL


class StrongMatching:
    """Strong-matching evaluation of predicted reader samples.

    Calling an instance aggregates counts over every sample and returns a
    flat dictionary of metrics:

    * ``span_*``: matching on ``(start, end)`` only.
    * ``entities_*``: matching on the per-sample sets of entity labels,
      ignoring spans.
    * ``core_*``: matching on the full ``(start, end, entity)`` triple.
      ``core_recall-at-k`` additionally credits a missed gold annotation
      when its entity appears in the per-span probability map.
    * ``wrong-for-candidates``: share of gold annotations that were missed
      and whose entity is absent from the sample's candidates.
    * ``index_*``: statistics on the candidate-list position of predicted
      entities, restricted to predicted spans that are also gold spans.
    """

    def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict:
        """Compute the metrics over ``predicted_samples``.

        Each sample is read through four attributes:

        * ``predicted_window_labels_chars``: a set of
          ``(start, end, entity)`` triples (set operations are applied).
        * ``probs_window_labels_chars``: a mapping from ``(start, end)`` to
          a container of entities (queried with ``.get`` and ``in``).
        * ``window_labels``: an iterable of gold ``(start, end, entity)``
          triples; entries labelled ``NME_SYMBOL`` are ignored.
        * ``window_candidates``: an ordered sequence of candidate entities.
          NOTE(review): candidates are assumed hashable, like the entity
          labels they are compared with - confirm against the data classes.

        Args:
            predicted_samples: samples already annotated by the predictor.

        Returns:
            Mapping from metric name to its value rounded to 5 decimals
            (``core_f1`` is first rounded to 4 decimals, as before).
        """
        # accumulators
        correct_predictions = 0
        correct_predicted_entities = 0
        correct_predictions_at_k = 0
        total_predictions = 0
        total_gold = 0
        total_entities_predictions = 0
        total_entities_gold = 0
        correct_span_predictions = 0
        miss_due_to_candidates = 0

        # prediction index stats
        correct_predicted_indices = []
        wrong_predicted_indices = []
        favoured_smaller_index = []

        # collect data from samples
        for sample in predicted_samples:
            predicted_annotations = sample.predicted_window_labels_chars
            span_probabilities = sample.probs_window_labels_chars
            gold_annotations = {
                (ss, se, entity)
                for ss, se, entity in sample.window_labels
                if entity != NME_SYMBOL
            }
            total_predictions += len(predicted_annotations)
            total_gold += len(gold_annotations)

            gold_entities = {entity for _, _, entity in gold_annotations}
            total_entities_gold += len(gold_entities)
            pred_entities = {
                entity
                for _, _, entity in predicted_annotations
                if entity != NME_SYMBOL
            }
            total_entities_predictions += len(pred_entities)

            # First-occurrence position of every candidate, built once per
            # sample so membership tests and index lookups are O(1).
            candidate_rank = {}
            for position, candidate in enumerate(sample.window_candidates):
                candidate_rank.setdefault(candidate, position)

            # correct named entity detection
            predicted_spans = {(s, e) for s, e, _ in predicted_annotations}
            gold_spans = {(s, e) for s, e, _ in gold_annotations}
            correct_span_predictions += len(predicted_spans & gold_spans)

            # correct entity linking
            correct_predictions += len(
                predicted_annotations.intersection(gold_annotations)
            )

            for ss, se, gold_ent in gold_annotations.difference(
                predicted_annotations
            ):
                if gold_ent not in candidate_rank:
                    miss_due_to_candidates += 1
                if gold_ent in span_probabilities.get((ss, se), set()):
                    correct_predictions_at_k += 1

            # correct entity disambiguation
            correct_predicted_entities += len(pred_entities & gold_entities)

            # indices metrics
            predicted_spans_index = {
                (ss, se): ent for ss, se, ent in predicted_annotations
            }
            gold_spans_index = {(ss, se): ent for ss, se, ent in gold_annotations}

            for pred_span, pred_ent in predicted_spans_index.items():
                if pred_span not in gold_spans_index:
                    continue
                gold_ent = gold_spans_index[pred_span]

                # missing candidate: the gold entity has no position to
                # compare against
                gold_idx = candidate_rank.get(gold_ent)
                if gold_idx is None:
                    continue

                # A prediction outside the candidate list has no position
                # either; skip it instead of letting list.index raise
                # ValueError as the previous implementation did.
                pred_idx = candidate_rank.get(pred_ent)
                if pred_idx is None:
                    continue

                if pred_ent == gold_ent:
                    correct_predicted_indices.append(pred_idx)
                else:
                    wrong_predicted_indices.append(pred_idx)
                    # 1 when the wrong prediction sits at or before the gold
                    # entity in the candidate list, 0 when it sits after it
                    favoured_smaller_index.append(0 if pred_idx > gold_idx else 1)

        # compute NED metrics
        span_precision = safe_divide(correct_span_predictions, total_predictions)
        span_recall = safe_divide(correct_span_predictions, total_gold)
        span_f1 = f1_measure(span_precision, span_recall)

        # compute EL metrics
        precision = safe_divide(correct_predictions, total_predictions)
        recall = safe_divide(correct_predictions, total_gold)
        recall_at_k = safe_divide(
            (correct_predictions + correct_predictions_at_k), total_gold
        )
        f1 = f1_measure(precision, recall)

        # compute ED metrics
        precision_entities = safe_divide(
            correct_predicted_entities, total_entities_predictions
        )
        recall_entities = safe_divide(correct_predicted_entities, total_entities_gold)
        span_entities_f1 = f1_measure(precision_entities, recall_entities)

        wrong_for_candidates = safe_divide(miss_due_to_candidates, total_gold)

        all_predicted_indices = correct_predicted_indices + wrong_predicted_indices

        out_dict = {
            "span_precision": span_precision,
            "span_recall": span_recall,
            "span_f1": span_f1,
            "entities_precision": precision_entities,
            "entities_recall": recall_entities,
            "entities_f1": span_entities_f1,
            "core_precision": precision,
            "core_recall": recall,
            "core_recall-at-k": recall_at_k,
            "core_f1": round(f1, 4),
            "wrong-for-candidates": wrong_for_candidates,
            "index_errors_avg-index": safe_divide(
                sum(wrong_predicted_indices), len(wrong_predicted_indices)
            ),
            "index_correct_avg-index": safe_divide(
                sum(correct_predicted_indices), len(correct_predicted_indices)
            ),
            "index_avg-index": safe_divide(
                sum(all_predicted_indices), len(all_predicted_indices)
            ),
            "index_percentage-favoured-smaller-idx": safe_divide(
                sum(favoured_smaller_index), len(favoured_smaller_index)
            ),
        }

        return {k: round(v, 5) for k, v in out_dict.items()}


class ELStrongMatchingCallback(Callback):
    """Lightning callback that logs strong-matching metrics for a dataset."""

    def __init__(self, dataset_path: str, dataset_conf) -> None:
        """Store the dataset location/configuration and create the metric.

        Args:
            dataset_path: path handed to the predictor on each evaluation.
            dataset_conf: dataset configuration forwarded to the predictor
                unchanged.
        """
        super().__init__()
        self.dataset_path = dataset_path
        self.dataset_conf = dataset_conf
        self.strong_matching_metric = StrongMatching()

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        """Predict on the configured dataset and log every metric.

        A predictor is built around ``pl_module.relik_reader_core_model``,
        run on ``self.dataset_path``, and each resulting metric is logged
        through ``pl_module.log`` under the name ``val_<metric>``.
        """
        relik_reader_predictor = RelikReaderPredictor(pl_module.relik_reader_core_model)
        predicted_samples = relik_reader_predictor.predict(
            self.dataset_path,
            samples=None,
            dataset_conf=self.dataset_conf,
        )
        predicted_samples = list(predicted_samples)
        for k, v in self.strong_matching_metric(predicted_samples).items():
            pl_module.log(f"val_{k}", v)