from collections import defaultdict
from typing import Dict, List

from lightning.pytorch.callbacks import Callback

from relik.reader.data.relik_reader_re_data import RelikREDataset
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.relik_reader_predictor import RelikReaderPredictor
from relik.reader.utils.metrics import compute_metrics


class StrongMatching:
    """Strict-matching evaluation for relation extraction: a predicted triplet is
    correct only if the subject span, object span and relation name (and, for the
    *-strict metrics, the entity types) exactly match a gold triplet."""

    def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict:
        # accumulators
        correct_predictions, total_predictions, total_gold = 0, 0, 0
        correct_predictions_strict, total_predictions_strict = 0, 0
        correct_predictions_bound, total_predictions_bound = 0, 0
        correct_span_predictions, total_span_predictions, total_gold_spans = 0, 0, 0
        (
            correct_span_in_triplets_predictions,
            total_span_in_triplets_predictions,
            total_gold_spans_in_triplets,
        ) = (0, 0, 0)

        # collect data from samples
        for sample in predicted_samples:
            if sample.triplets is None:
                sample.triplets = []

            if sample.span_candidates:
                predicted_annotations_strict = set(
                    [
                        (
                            triplet["subject"]["start"],
                            triplet["subject"]["end"],
                            triplet["subject"]["type"],
                            triplet["relation"]["name"],
                            triplet["object"]["start"],
                            triplet["object"]["end"],
                            triplet["object"]["type"],
                        )
                        for triplet in sample.predicted_relations
                    ]
                )
                gold_annotations_strict = set(
                    [
                        (
                            triplet["subject"]["start"],
                            triplet["subject"]["end"],
                            triplet["subject"]["type"],
                            triplet["relation"]["name"],
                            triplet["object"]["start"],
                            triplet["object"]["end"],
                            triplet["object"]["type"],
                        )
                        for triplet in sample.triplets
                    ]
                )
                predicted_spans_strict = set(
                    (ss, se, st) for (ss, se, st) in sample.predicted_entities
                )
                gold_spans_strict = set(sample.entities)
                predicted_spans_in_triplets = set(
                    [
                        (
                            triplet["subject"]["start"],
                            triplet["subject"]["end"],
                            triplet["subject"]["type"],
                        )
                        for triplet in sample.predicted_relations
                    ]
                    + [
                        (
                            triplet["object"]["start"],
                            triplet["object"]["end"],
                            triplet["object"]["type"],
                        )
                        for triplet in sample.predicted_relations
                    ]
                )
                gold_spans_in_triplets = set(
                    [
                        (
                            triplet["subject"]["start"],
                            triplet["subject"]["end"],
                            triplet["subject"]["type"],
                        )
                        for triplet in sample.triplets
                    ]
                    + [
                        (
                            triplet["object"]["start"],
                            triplet["object"]["end"],
                            triplet["object"]["type"],
                        )
                        for triplet in sample.triplets
                    ]
                )

                # strict: spans and triplets must also match on entity types
                correct_span_predictions += len(
                    predicted_spans_strict.intersection(gold_spans_strict)
                )
                total_span_predictions += len(predicted_spans_strict)

                correct_span_in_triplets_predictions += len(
                    predicted_spans_in_triplets.intersection(gold_spans_in_triplets)
                )
                total_span_in_triplets_predictions += len(predicted_spans_in_triplets)
                total_gold_spans_in_triplets += len(gold_spans_in_triplets)

                correct_predictions_strict += len(
                    predicted_annotations_strict.intersection(gold_annotations_strict)
                )
                total_predictions_strict += len(predicted_annotations_strict)

            # boundary-only variants: entity types are replaced by -1
            predicted_annotations = set(
                [
                    (
                        triplet["subject"]["start"],
                        triplet["subject"]["end"],
                        -1,
                        triplet["relation"]["name"],
                        triplet["object"]["start"],
                        triplet["object"]["end"],
                        -1,
                    )
                    for triplet in sample.predicted_relations
                ]
            )
            gold_annotations = set(
                [
                    (
                        triplet["subject"]["start"],
                        triplet["subject"]["end"],
                        -1,
                        triplet["relation"]["name"],
                        triplet["object"]["start"],
                        triplet["object"]["end"],
                        -1,
                    )
                    for triplet in sample.triplets
                ]
            )
            predicted_spans = set(
                [(ss, se) for (ss, se, _) in sample.predicted_entities]
            )
            gold_spans = set([(ss, se) for (ss, se, _) in sample.entities])
            total_gold_spans += len(gold_spans)
            correct_predictions_bound += len(predicted_spans.intersection(gold_spans))
            total_predictions_bound += len(predicted_spans)
            total_predictions += len(predicted_annotations)
            total_gold += len(gold_annotations)
            # correct relation extraction
            correct_predictions += len(
                predicted_annotations.intersection(gold_annotations)
            )

        span_precision, span_recall, span_f1 = compute_metrics(
            correct_span_predictions, total_span_predictions, total_gold_spans
        )
        bound_precision, bound_recall, bound_f1 = compute_metrics(
            correct_predictions_bound, total_predictions_bound, total_gold_spans
        )
        precision, recall, f1 = compute_metrics(
            correct_predictions, total_predictions, total_gold
        )

        if sample.span_candidates:
            precision_strict, recall_strict, f1_strict = compute_metrics(
                correct_predictions_strict, total_predictions_strict, total_gold
            )
            (
                span_in_triplet_precision,
                span_in_triplet_recall,
                span_in_triplet_f1,
            ) = compute_metrics(
                correct_span_in_triplets_predictions,
                total_span_in_triplets_predictions,
                total_gold_spans_in_triplets,
            )
            return {
                "span-precision-strict": span_precision,
                "span-recall-strict": span_recall,
                "span-f1-strict": span_f1,
                "span-precision": bound_precision,
                "span-recall": bound_recall,
                "span-f1": bound_f1,
                "span-in-triplet-precision": span_in_triplet_precision,
                "span-in-triplet-recall": span_in_triplet_recall,
                "span-in-triplet-f1": span_in_triplet_f1,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "precision-strict": precision_strict,
                "recall-strict": recall_strict,
                "f1-strict": f1_strict,
            }
        else:
            return {
                "span-precision": bound_precision,
                "span-recall": bound_recall,
                "span-f1": bound_f1,
                "precision": precision,
                "recall": recall,
                "f1": f1,
            }


class StrongMatchingPerRelation:
    """Strict-matching evaluation broken down per relation name, reporting (and
    printing) precision/recall/F1 for every gold relation label."""

    def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict:
        # per-relation accumulators
        correct_predictions, total_predictions, total_gold = (
            defaultdict(int),
            defaultdict(int),
            defaultdict(int),
        )
        correct_predictions_strict, total_predictions_strict = (
            defaultdict(int),
            defaultdict(int),
        )

        # collect data from samples
        for sample in predicted_samples:
            if sample.triplets is None:
                sample.triplets = []

            if sample.span_candidates:
                gold_annotations_strict = set(
                    [
                        (
                            triplet["subject"]["start"],
                            triplet["subject"]["end"],
                            triplet["subject"]["type"],
                            triplet["relation"]["name"],
                            triplet["object"]["start"],
                            triplet["object"]["end"],
                            triplet["object"]["type"],
                        )
                        for triplet in sample.triplets
                    ]
                )
                # count correct strict predictions per relation name
                for triplet in sample.predicted_relations:
                    predicted_annotations_strict = (
                        triplet["subject"]["start"],
                        triplet["subject"]["end"],
                        triplet["subject"]["type"],
                        triplet["relation"]["name"],
                        triplet["object"]["start"],
                        triplet["object"]["end"],
                        triplet["object"]["type"],
                    )
                    if predicted_annotations_strict in gold_annotations_strict:
                        correct_predictions_strict[triplet["relation"]["name"]] += 1
                    total_predictions_strict[triplet["relation"]["name"]] += 1

            gold_annotations = set(
                [
                    (
                        triplet["subject"]["start"],
                        triplet["subject"]["end"],
                        -1,
                        triplet["relation"]["name"],
                        triplet["object"]["start"],
                        triplet["object"]["end"],
                        -1,
                    )
                    for triplet in sample.triplets
                ]
            )
            for triplet in sample.predicted_relations:
                predicted_annotations = (
                    triplet["subject"]["start"],
                    triplet["subject"]["end"],
                    -1,
                    triplet["relation"]["name"],
                    triplet["object"]["start"],
                    triplet["object"]["end"],
                    -1,
                )
                if predicted_annotations in gold_annotations:
                    correct_predictions[triplet["relation"]["name"]] += 1
                total_predictions[triplet["relation"]["name"]] += 1

            for triplet in sample.triplets:
                total_gold[triplet["relation"]["name"]] += 1

        metrics = {}
        metrics_non_zero = 0
        for relation in total_gold.keys():
            precision, recall, f1 = compute_metrics(
                correct_predictions[relation],
                total_predictions[relation],
                total_gold[relation],
            )
            metrics[f"{relation}-precision"] = precision
            metrics[f"{relation}-recall"] = recall
            metrics[f"{relation}-f1"] = f1

            precision_strict, recall_strict, f1_strict = compute_metrics(
                correct_predictions_strict[relation],
                total_predictions_strict[relation],
                total_gold[relation],
            )
            metrics[f"{relation}-precision-strict"] = precision_strict
            metrics[f"{relation}-recall-strict"] = recall_strict
            metrics[f"{relation}-f1-strict"] = f1_strict
            if metrics[f"{relation}-f1-strict"] > 0:
                metrics_non_zero += 1

            # print in a readable way
            print(
                f"{relation} precision: {precision:.4f} recall: {recall:.4f} "
                f"f1: {f1:.4f} precision_strict: {precision_strict:.4f} "
                f"recall_strict: {recall_strict:.4f} f1_strict: {f1_strict:.4f} "
                f"support: {total_gold[relation]}"
            )

        print(f"metrics_non_zero: {metrics_non_zero}")
        return metrics


class REStrongMatchingCallback(Callback):
    """Lightning callback that, at the start of each validation epoch, runs the
    reader on the given dataset and logs the strong-matching metrics."""

    def __init__(
        self, dataset_path: str, dataset_conf, log_metric: str = "val_"
    ) -> None:
        super().__init__()
        self.dataset_path = dataset_path
        self.dataset_conf = dataset_conf
        self.strong_matching_metric = StrongMatching()
        self.log_metric = log_metric

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        # reuse the validation dataloader when it already points at the target dataset
        dataloader = trainer.val_dataloaders
        if (
            self.dataset_path == dataloader.dataset.dataset_path
            and dataloader.dataset.samples is not None
            and len(dataloader.dataset.samples) > 0
        ):
            relik_reader_predictor = RelikReaderPredictor(
                pl_module.relik_reader_re_model, dataloader=trainer.val_dataloaders
            )
        else:
            relik_reader_predictor = RelikReaderPredictor(
                pl_module.relik_reader_re_model
            )
        predicted_samples = relik_reader_predictor._predict(
            self.dataset_path,
            None,
            self.dataset_conf,
        )
        predicted_samples = list(predicted_samples)
        for sample in predicted_samples:
            RelikREDataset._convert_annotations(sample)
        for k, v in self.strong_matching_metric(predicted_samples).items():
            pl_module.log(f"{self.log_metric}{k}", v)
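

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the evaluation code): how
# `StrongMatching` consumes predicted samples. The object below is a
# `SimpleNamespace` stand-in carrying only the attributes the metric reads;
# in real runs these are `RelikReaderSample` instances produced by
# `RelikReaderPredictor` and post-processed by
# `RelikREDataset._convert_annotations`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    triplet = {
        "subject": {"start": 0, "end": 2, "type": "PER"},
        "relation": {"name": "works_for"},
        "object": {"start": 5, "end": 7, "type": "ORG"},
    }
    sample = SimpleNamespace(
        span_candidates=["PER", "ORG"],  # only checked for truthiness here
        triplets=[triplet],  # gold triplets
        predicted_relations=[triplet],  # perfect prediction, for illustration
        entities=[(0, 2, "PER"), (5, 7, "ORG")],  # gold spans as (start, end, type)
        predicted_entities=[(0, 2, "PER"), (5, 7, "ORG")],
    )
    # Should report 1.0 for every precision/recall/F1 key, strict and non-strict.
    print(StrongMatching()([sample]))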