|
from typing import List |
|
|
|
from relik.reader.data.relik_reader_sample import RelikReaderSample |
|
from relik.reader.utils.special_symbols import NME_SYMBOL |
|
|
|
|
|
def merge_patches_predictions(sample) -> None:
    """Fold the per-patch predictions of ``sample`` into sample-level fields.

    Patches are visited in ascending order of their key in ``sample.patches``.
    Each patch entry is read for two mappings:

    * ``"predicted_window_labels"``: title -> iterable of spans. Every span is
      converted to a tuple. The first title seen for a span is kept, except
      that a title equal to ``NME_SYMBOL`` is replaced by whatever title a
      later patch assigns to the same span.
    * ``"span_title_probabilities"``: span -> probabilities. The entry from
      the first patch that mentions a span is kept; later ones are ignored.

    NOTE(review): the two mappings use different merge rules. The kept
    probabilities for a span can therefore come from an earlier patch than
    its kept title (e.g. when that earlier patch predicted ``NME_SYMBOL``).
    Confirm this is intended.

    Args:
        sample: Object exposing a mutable ``_d`` dict and a ``patches``
            mapping (presumably a ``RelikReaderSample``; verify at callers).

    Returns:
        None. ``sample._d["predicted_window_labels"]`` is replaced by a new
        dict mapping each title to the list of span tuples assigned to it.
        ``sample._d["span_title_probabilities"]`` is replaced by a new dict
        holding the merged probabilities.
    """
    store = sample._d
    store["predicted_window_labels"] = {}
    store["span_title_probabilities"] = {}
    merged_labels = store["predicted_window_labels"]
    merged_probabilities = store["span_title_probabilities"]

    patches = sample.patches
    chosen_title = {}
    for patch_key in sorted(patches):
        patch_info = patches[patch_key]

        for title, spans in patch_info["predicted_window_labels"].items():
            # Tuples so spans can serve as dict keys.
            for span in map(tuple, spans):
                current = chosen_title.get(span)
                # Keep the earliest title; an NME placeholder may be overridden.
                if current is None or current == NME_SYMBOL:
                    chosen_title[span] = title

        for span, probabilities in patch_info["span_title_probabilities"].items():
            # The first patch to report a span wins.
            merged_probabilities.setdefault(span, probabilities)

    # Invert span -> title into title -> [spans], preserving span first-seen order.
    for span, title in chosen_title.items():
        merged_labels.setdefault(title, []).append(span)
|
|
|
|
|
def remove_duplicate_samples(
    samples: List[RelikReaderSample],
) -> List[RelikReaderSample]:
    """Return ``samples`` with repeated (doc_id, sent_id, offset) entries dropped.

    The first sample for each identity is kept. The result preserves the
    order in which identities first appear and is a new list; the input is
    not modified.

    NOTE(review): the identity is the string ``"doc_id#sent_id#offset"``.
    IDs containing ``'#'`` could therefore collide, and ``1`` vs ``"1"`` are
    treated as equal. Confirm this is acceptable for the ID formats in use.

    Args:
        samples: Samples exposing ``doc_id``, ``sent_id`` and ``offset``.

    Returns:
        The de-duplicated samples.
    """
    # Dict insertion order keeps first occurrences in their original order.
    first_by_key = {}
    for sample in samples:
        first_by_key.setdefault(
            f"{sample.doc_id}#{sample.sent_id}#{sample.offset}", sample
        )
    return list(first_by_key.values())
|
|