|
from typing import Dict, List |
|
|
|
from lightning.pytorch.callbacks import Callback |
|
|
|
from relik.reader.data.relik_reader_sample import RelikReaderSample |
|
from relik.reader.relik_reader_predictor import RelikReaderPredictor |
|
from relik.reader.utils.metrics import f1_measure, safe_divide |
|
from relik.reader.utils.special_symbols import NME_SYMBOL |
|
|
|
|
|
class StrongMatching:
    """Strong-matching evaluation of entity-linking predictions.

    An annotation is a ``(start, end, entity)`` triple. A prediction is
    "strongly" correct only when both its span and its entity equal those of a
    gold annotation. Gold annotations labelled ``NME_SYMBOL`` are excluded from
    every metric.
    """

    def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict:
        """Aggregate linking metrics over ``predicted_samples``.

        Each sample is read through these attributes:

        * ``window_labels``: gold ``(start, end, entity)`` triples.
        * ``predicted_window_labels_chars``: predicted ``(start, end, entity)``
          triples. ``.intersection`` is called on it, so it must be set-like.
        * ``probs_window_labels_chars``: mapping ``(start, end)`` -> container
          of entities, used for recall@k. Presumably these are the model's
          top-k alternatives; confirm against the predictor.
        * ``window_candidates``: list of candidate entities for the window.

        Returns:
            Dict mapping metric name -> value rounded to 5 decimal places:

            * ``span_*``: precision/recall/F1 on span boundaries only.
            * ``entities_*``: precision/recall/F1 on per-sample sets of
              entities, ignoring spans.
            * ``core_*``: strong-matching precision/recall/F1.
              ``core_recall-at-k`` additionally credits missed gold
              annotations whose entity appears in
              ``probs_window_labels_chars`` for the same span.
            * ``wrong-for-candidates``: missed gold annotations whose entity is
              absent from ``window_candidates``, divided by all gold
              annotations.
            * ``index_*``: statistics on the candidate-list position of
              predicted entities, for predicted spans that match a gold span.
        """
        # Strong-matching and span-level counters.
        correct_predictions = 0
        correct_predictions_at_k = 0
        correct_span_predictions = 0
        total_predictions = 0
        total_gold = 0
        miss_due_to_candidates = 0

        # Entity-level (span-agnostic) counters.
        correct_predicted_entities = 0
        total_entities_predictions = 0
        total_entities_gold = 0

        # Candidate-list index of predicted entities, split by correctness.
        avg_correct_predicted_index = []
        avg_wrong_predicted_index = []
        # For each wrong prediction: 1 if its candidate index is <= the gold
        # entity's index, else 0.
        less_index_predictions = []

        for sample in predicted_samples:
            predicted_annotations = sample.predicted_window_labels_chars
            predicted_annotations_probabilities = sample.probs_window_labels_chars
            candidates = sample.window_candidates

            gold_annotations = {
                (ss, se, entity)
                for ss, se, entity in sample.window_labels
                if entity != NME_SYMBOL
            }
            total_predictions += len(predicted_annotations)
            total_gold += len(gold_annotations)

            # Entity-level sets: which entities appear at all, regardless of span.
            gold_entities = {entity for _, _, entity in gold_annotations}
            pred_entities = {
                entity
                for _, _, entity in predicted_annotations
                if entity != NME_SYMBOL
            }
            total_entities_gold += len(gold_entities)
            total_entities_predictions += len(pred_entities)
            correct_predicted_entities += len(pred_entities.intersection(gold_entities))

            # Span detection: boundaries match, entity ignored.
            predicted_spans = {(s, e) for s, e, _ in predicted_annotations}
            gold_spans = {(s, e) for s, e, _ in gold_annotations}
            correct_span_predictions += len(predicted_spans.intersection(gold_spans))

            # Strong matching: span and entity both match.
            correct_predictions += len(
                predicted_annotations.intersection(gold_annotations)
            )

            for ss, se, ge in gold_annotations.difference(predicted_annotations):
                if ge not in candidates:
                    miss_due_to_candidates += 1
                if ge in predicted_annotations_probabilities.get((ss, se), set()):
                    correct_predictions_at_k += 1

            # Candidate-index statistics over predicted spans that hit a gold span.
            predicted_spans_index = {
                (ss, se): ent for ss, se, ent in predicted_annotations
            }
            gold_spans_index = {(ss, se): ent for ss, se, ent in gold_annotations}

            for pred_span, pred_ent in predicted_spans_index.items():
                if pred_span not in gold_spans_index:
                    continue
                gold_ent = gold_spans_index[pred_span]

                # Gold entity was never offered to the model: index is undefined.
                if gold_ent not in candidates:
                    continue
                # Predicted entity outside the candidate list (predictions may
                # carry NME_SYMBOL, see the pred_entities filter above): there is
                # no index to report, and list.index would raise ValueError.
                if pred_ent not in candidates:
                    continue

                gold_idx = candidates.index(gold_ent)
                pred_idx = candidates.index(pred_ent)

                if pred_ent == gold_ent:
                    avg_correct_predicted_index.append(pred_idx)
                else:
                    avg_wrong_predicted_index.append(pred_idx)
                    less_index_predictions.append(0 if pred_idx > gold_idx else 1)

        span_precision = safe_divide(correct_span_predictions, total_predictions)
        span_recall = safe_divide(correct_span_predictions, total_gold)
        span_f1 = f1_measure(span_precision, span_recall)

        precision = safe_divide(correct_predictions, total_predictions)
        recall = safe_divide(correct_predictions, total_gold)
        recall_at_k = safe_divide(
            correct_predictions + correct_predictions_at_k, total_gold
        )
        f1 = f1_measure(precision, recall)

        precision_entities = safe_divide(
            correct_predicted_entities, total_entities_predictions
        )
        recall_entities = safe_divide(correct_predicted_entities, total_entities_gold)
        entities_f1 = f1_measure(precision_entities, recall_entities)

        wrong_for_candidates = safe_divide(miss_due_to_candidates, total_gold)

        all_predicted_index = avg_correct_predicted_index + avg_wrong_predicted_index

        out_dict = {
            "span_precision": span_precision,
            "span_recall": span_recall,
            "span_f1": span_f1,
            "entities_precision": precision_entities,
            "entities_recall": recall_entities,
            "entities_f1": entities_f1,
            "core_precision": precision,
            "core_recall": recall,
            "core_recall-at-k": recall_at_k,
            "core_f1": f1,
            "wrong-for-candidates": wrong_for_candidates,
            "index_errors_avg-index": safe_divide(
                sum(avg_wrong_predicted_index), len(avg_wrong_predicted_index)
            ),
            "index_correct_avg-index": safe_divide(
                sum(avg_correct_predicted_index), len(avg_correct_predicted_index)
            ),
            "index_avg-index": safe_divide(
                sum(all_predicted_index), len(all_predicted_index)
            ),
            "index_percentage-favoured-smaller-idx": safe_divide(
                sum(less_index_predictions), len(less_index_predictions)
            ),
        }

        # Single, uniform rounding for every metric.
        return {k: round(v, 5) for k, v in out_dict.items()}
|
|
|
|
|
class ELStrongMatchingCallback(Callback):
    """Lightning callback that logs strong-matching metrics during validation.

    When a validation epoch starts, the module's ``relik_reader_core_model`` is
    wrapped in a ``RelikReaderPredictor``. The predictor is run over
    ``dataset_path``, and every metric returned by ``StrongMatching`` is logged
    through ``pl_module.log`` under the name ``val_<metric>``.
    """

    def __init__(self, dataset_path: str, dataset_conf) -> None:
        """Remember the evaluation dataset and build the metric.

        Args:
            dataset_path: Path handed to ``RelikReaderPredictor.predict``.
            dataset_conf: Dataset configuration forwarded unchanged to
                ``RelikReaderPredictor.predict``.
        """
        super().__init__()
        self.dataset_path = dataset_path
        self.dataset_conf = dataset_conf
        self.strong_matching_metric = StrongMatching()

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        """Predict over the stored dataset and log each resulting metric."""
        predictor = RelikReaderPredictor(pl_module.relik_reader_core_model)
        # Materialize the predictions so the metric can consume them as a list.
        samples = list(
            predictor.predict(
                self.dataset_path,
                samples=None,
                dataset_conf=self.dataset_conf,
            )
        )
        metrics = self.strong_matching_metric(samples)
        for metric_name, metric_value in metrics.items():
            pl_module.log(f"val_{metric_name}", metric_value)
|
|