from typing import Dict, List
from collections import defaultdict
from lightning.pytorch.callbacks import Callback
from relik.reader.data.relik_reader_re_data import RelikREDataset
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.relik_reader_predictor import RelikReaderPredictor
from relik.reader.utils.metrics import compute_metrics
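
# NOTE (assumption, not from the original file): compute_metrics(correct,
# predicted, gold) is expected to return the usual (precision, recall, f1)
# computed from raw counts, roughly:
#   precision = correct / predicted   (0 when predicted == 0)
#   recall    = correct / gold        (0 when gold == 0)
#   f1        = 2 * precision * recall / (precision + recall)   (0 when both are 0)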
class StrongMatching:
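    """Micro-averaged strong-matching evaluation for relation extraction.

    A predicted triplet is correct only if its subject span, relation label and
    object span exactly match a gold triplet; the "strict" variant additionally
    requires the subject and object entity types to match.
    """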
def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict:
# accumulators
        correct_predictions, total_predictions, total_gold = 0, 0, 0
        correct_predictions_strict, total_predictions_strict = 0, 0
        correct_predictions_bound, total_predictions_bound = 0, 0
        correct_span_predictions, total_span_predictions, total_gold_spans = 0, 0, 0
        (
            correct_span_in_triplets_predictions,
            total_span_in_triplets_predictions,
            total_gold_spans_in_triplets,
        ) = (0, 0, 0)
# collect data from samples
for sample in predicted_samples:
if sample.triplets is None:
sample.triplets = []
if sample.span_candidates:
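                # entity span candidates are available, so typed ("strict") span
                # and triplet metrics are computed in addition to boundary-only ones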
predicted_annotations_strict = set(
[
(
triplet["subject"]["start"],
triplet["subject"]["end"],
triplet["subject"]["type"],
triplet["relation"]["name"],
triplet["object"]["start"],
triplet["object"]["end"],
triplet["object"]["type"],
)
for triplet in sample.predicted_relations
]
)
gold_annotations_strict = set(
[
(
triplet["subject"]["start"],
triplet["subject"]["end"],
triplet["subject"]["type"],
triplet["relation"]["name"],
triplet["object"]["start"],
triplet["object"]["end"],
triplet["object"]["type"],
)
for triplet in sample.triplets
]
)
                predicted_spans_strict = set(
                    (ss, se, st) for (ss, se, st) in sample.predicted_entities
                )
gold_spans_strict = set(sample.entities)
predicted_spans_in_triplets = set(
[
(
triplet["subject"]["start"],
triplet["subject"]["end"],
triplet["subject"]["type"],
)
for triplet in sample.predicted_relations
]
+ [
(
triplet["object"]["start"],
triplet["object"]["end"],
triplet["object"]["type"],
)
for triplet in sample.predicted_relations
]
)
gold_spans_in_triplets = set(
[
(
triplet["subject"]["start"],
triplet["subject"]["end"],
triplet["subject"]["type"],
)
for triplet in sample.triplets
]
+ [
(
triplet["object"]["start"],
triplet["object"]["end"],
triplet["object"]["type"],
)
for triplet in sample.triplets
]
)
                # strict: spans and triplets must also match on entity types
correct_span_predictions += len(
predicted_spans_strict.intersection(gold_spans_strict)
)
total_span_predictions += len(predicted_spans_strict)
correct_span_in_triplets_predictions += len(
predicted_spans_in_triplets.intersection(gold_spans_in_triplets)
)
total_span_in_triplets_predictions += len(predicted_spans_in_triplets)
total_gold_spans_in_triplets += len(gold_spans_in_triplets)
correct_predictions_strict += len(
predicted_annotations_strict.intersection(gold_annotations_strict)
)
total_predictions_strict += len(predicted_annotations_strict)
predicted_annotations = set(
[
(
triplet["subject"]["start"],
triplet["subject"]["end"],
-1,
triplet["relation"]["name"],
triplet["object"]["start"],
triplet["object"]["end"],
-1,
)
for triplet in sample.predicted_relations
]
)
gold_annotations = set(
[
(
triplet["subject"]["start"],
triplet["subject"]["end"],
-1,
triplet["relation"]["name"],
triplet["object"]["start"],
triplet["object"]["end"],
-1,
)
for triplet in sample.triplets
]
)
predicted_spans = set(
[(ss, se) for (ss, se, _) in sample.predicted_entities]
)
gold_spans = set([(ss, se) for (ss, se, _) in sample.entities])
total_gold_spans += len(gold_spans)
correct_predictions_bound += len(predicted_spans.intersection(gold_spans))
total_predictions_bound += len(predicted_spans)
total_predictions += len(predicted_annotations)
total_gold += len(gold_annotations)
            # relation extraction with boundary-only matching (entity types ignored)
correct_predictions += len(
predicted_annotations.intersection(gold_annotations)
)
span_precision, span_recall, span_f1 = compute_metrics(
correct_span_predictions, total_span_predictions, total_gold_spans
)
bound_precision, bound_recall, bound_f1 = compute_metrics(
correct_predictions_bound, total_predictions_bound, total_gold_spans
)
precision, recall, f1 = compute_metrics(
correct_predictions, total_predictions, total_gold
)
        # `sample` here is the last element of the loop above; strict metrics are
        # reported only if the (last) sample provides entity span candidates
        if sample.span_candidates:
precision_strict, recall_strict, f1_strict = compute_metrics(
correct_predictions_strict, total_predictions_strict, total_gold
)
            (
                span_in_triplet_precision,
                span_in_triplet_recall,
                span_in_triplet_f1,
            ) = compute_metrics(
                correct_span_in_triplets_predictions,
                total_span_in_triplets_predictions,
                total_gold_spans_in_triplets,
            )
return {
"span-precision-strict": span_precision,
"span-recall-strict": span_recall,
"span-f1-strict": span_f1,
"span-precision": bound_precision,
"span-recall": bound_recall,
"span-f1": bound_f1,
"span-in-triplet-precision": span_in_triplet_precisiion,
"span-in-triplet-recall": span_in_triplet_recall,
"span-in-triplet-f1": span_in_triplet_f1,
"precision": precision,
"recall": recall,
"f1": f1,
"precision-strict": precision_strict,
"recall-strict": recall_strict,
"f1-strict": f1_strict,
}
else:
return {
"span-precision": bound_precision,
"span-recall": bound_recall,
"span-f1": bound_f1,
"precision": precision,
"recall": recall,
"f1": f1,
}
class StrongMatchingPerRelation:
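    """Strong-matching precision/recall/F1 broken down per relation label.

    Correct, predicted and gold counts are accumulated separately for every
    relation name, both with entity types ("strict") and without them, and the
    resulting per-relation metrics are printed and returned as a flat dict.
    """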
def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict:
correct_predictions, total_predictions, total_gold = (
defaultdict(int),
defaultdict(int),
defaultdict(int),
)
correct_predictions_strict, total_predictions_strict = (
defaultdict(int),
defaultdict(int),
)
# collect data from samples
for sample in predicted_samples:
if sample.triplets is None:
sample.triplets = []
if sample.span_candidates:
gold_annotations_strict = set(
[
(
triplet["subject"]["start"],
triplet["subject"]["end"],
triplet["subject"]["type"],
triplet["relation"]["name"],
triplet["object"]["start"],
triplet["object"]["end"],
triplet["object"]["type"],
)
for triplet in sample.triplets
]
)
                # count correct strict (typed) predictions per relation label
for triplet in sample.predicted_relations:
predicted_annotations_strict = (
triplet["subject"]["start"],
triplet["subject"]["end"],
triplet["subject"]["type"],
triplet["relation"]["name"],
triplet["object"]["start"],
triplet["object"]["end"],
triplet["object"]["type"],
)
if predicted_annotations_strict in gold_annotations_strict:
correct_predictions_strict[triplet["relation"]["name"]] += 1
total_predictions_strict[triplet["relation"]["name"]] += 1
gold_annotations = set(
[
(
triplet["subject"]["start"],
triplet["subject"]["end"],
-1,
triplet["relation"]["name"],
triplet["object"]["start"],
triplet["object"]["end"],
-1,
)
for triplet in sample.triplets
]
)
for triplet in sample.predicted_relations:
predicted_annotations = (
triplet["subject"]["start"],
triplet["subject"]["end"],
-1,
triplet["relation"]["name"],
triplet["object"]["start"],
triplet["object"]["end"],
-1,
)
if predicted_annotations in gold_annotations:
correct_predictions[triplet["relation"]["name"]] += 1
total_predictions[triplet["relation"]["name"]] += 1
for triplet in sample.triplets:
total_gold[triplet["relation"]["name"]] += 1
metrics = {}
metrics_non_zero = 0
for relation in total_gold.keys():
precision, recall, f1 = compute_metrics(
correct_predictions[relation],
total_predictions[relation],
total_gold[relation],
)
metrics[f"{relation}-precision"] = precision
metrics[f"{relation}-recall"] = recall
metrics[f"{relation}-f1"] = f1
precision_strict, recall_strict, f1_strict = compute_metrics(
correct_predictions_strict[relation],
total_predictions_strict[relation],
total_gold[relation],
)
metrics[f"{relation}-precision-strict"] = precision_strict
metrics[f"{relation}-recall-strict"] = recall_strict
metrics[f"{relation}-f1-strict"] = f1_strict
if metrics[f"{relation}-f1-strict"] > 0:
metrics_non_zero += 1
            # print in a readable way
            print(
                f"{relation} precision: {precision:.4f} recall: {recall:.4f} "
                f"f1: {f1:.4f} precision_strict: {precision_strict:.4f} "
                f"recall_strict: {recall_strict:.4f} f1_strict: {f1_strict:.4f} "
                f"support: {total_gold[relation]}"
            )
print(f"metrics_non_zero: {metrics_non_zero}")
return metrics
class REStrongMatchingCallback(Callback):
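    """Lightning callback that evaluates the RE reader with strong matching.

    At the start of each validation epoch the reader is run over the dataset at
    `dataset_path`, predictions are converted to the shared annotation format,
    and the StrongMatching metrics are logged with the `log_metric` prefix.
    """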
    def __init__(
        self, dataset_path: str, dataset_conf, log_metric: str = "val_"
    ) -> None:
super().__init__()
self.dataset_path = dataset_path
self.dataset_conf = dataset_conf
self.strong_matching_metric = StrongMatching()
self.log_metric = log_metric
def on_validation_epoch_start(self, trainer, pl_module) -> None:
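        """Predict on the validation dataset and log strong-matching metrics.

        The trainer's validation dataloader is reused when it already wraps the
        dataset at `self.dataset_path`; otherwise the predictor loads the data
        from the path and configuration itself.
        """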
dataloader = trainer.val_dataloaders
if (
self.dataset_path == dataloader.dataset.dataset_path
and dataloader.dataset.samples is not None
and len(dataloader.dataset.samples) > 0
):
relik_reader_predictor = RelikReaderPredictor(
pl_module.relik_reader_re_model, dataloader=trainer.val_dataloaders
)
else:
relik_reader_predictor = RelikReaderPredictor(
pl_module.relik_reader_re_model
)
predicted_samples = relik_reader_predictor._predict(
self.dataset_path,
None,
self.dataset_conf,
)
predicted_samples = list(predicted_samples)
for sample in predicted_samples:
RelikREDataset._convert_annotations(sample)
for k, v in self.strong_matching_metric(predicted_samples).items():
pl_module.log(f"{self.log_metric}{k}", v)
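

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module): the
    # callback is normally passed to a Lightning Trainer together with a
    # RelikReader RE module. The dataset path and config below are placeholders,
    # not guaranteed relik APIs.
    from lightning.pytorch import Trainer

    callback = REStrongMatchingCallback(
        dataset_path="data/validation.jsonl",  # hypothetical path
        dataset_conf=None,  # normally the dataset config used elsewhere in relik
        log_metric="val_",
    )
    trainer = Trainer(callbacks=[callback], max_epochs=1)
    # trainer.fit(reader_module, train_dataloader, val_dataloader)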