File size: 1,949 Bytes
626eca0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import List

from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.utils.special_symbols import NME_SYMBOL


def merge_patches_predictions(sample) -> None:
    """Merge the per-patch predictions of ``sample`` into sample-level fields.

    Patches are visited in ascending order of their key in ``sample.patches``.
    Two entries of ``sample._d`` are reset to fresh dicts and then filled:

    * ``"predicted_window_labels"``: maps each title to the list of spans
      (converted to tuples) assigned to it. A span keeps the first title
      predicted for it, unless that title is ``NME_SYMBOL``; in that case a
      title predicted by a later patch replaces it.
    * ``"span_title_probabilities"``: maps each span to the probabilities
      reported by the first patch (in key order) that lists that span.

    Args:
        sample: object exposing a ``_d`` dict and a ``patches`` mapping whose
            values are dicts holding ``"predicted_window_labels"``
            (title -> iterable of spans) and ``"span_title_probabilities"``
            (span -> probabilities) entries.

    Returns:
        None; results are written into ``sample._d``.
    """
    for field in ("predicted_window_labels", "span_title_probabilities"):
        sample._d[field] = {}
    merged_labels = sample._d["predicted_window_labels"]
    merged_probabilities = sample._d["span_title_probabilities"]

    # span (tuple) -> title chosen so far across patches
    title_of_span = {}
    for patch_key in sorted(sample.patches):
        patch_info = sample.patches[patch_key]

        for title, spans in patch_info["predicted_window_labels"].items():
            for span in map(tuple, spans):
                previous = title_of_span.get(span)
                # A concrete (non-NME) title from an earlier patch wins;
                # an unseen span or an NME placeholder gets overwritten.
                keep_previous = previous is not None and previous != NME_SYMBOL
                if not keep_previous:
                    title_of_span[span] = title

        # First patch to report probabilities for a span wins.
        for span, probabilities in patch_info["span_title_probabilities"].items():
            merged_probabilities.setdefault(span, probabilities)

    # Invert span -> title into title -> [spans].
    for span, title in title_of_span.items():
        merged_labels.setdefault(title, []).append(span)


def remove_duplicate_samples(
    samples: List[RelikReaderSample],
) -> List[RelikReaderSample]:
    """Return ``samples`` without duplicates, keeping each first occurrence.

    Two samples are considered duplicates when their ``doc_id``, ``sent_id``
    and ``offset`` attributes are all equal. The relative order of the kept
    samples matches the input order.

    Args:
        samples: samples to deduplicate. ``doc_id``, ``sent_id`` and
            ``offset`` must be hashable.

    Returns:
        A new list containing the first sample for each distinct
        ``(doc_id, sent_id, offset)`` triple.
    """
    seen_keys = set()
    unique_samples = []
    for sample in samples:
        # A tuple key compares each component separately. The former
        # "#"-joined string key could make distinct samples collide when an
        # id itself contains "#" (e.g. doc_id "a#1"/sent_id 2 vs. doc_id "a"/
        # sent_id "1#2"). NOTE: values that only looked equal after str()
        # (e.g. 1 vs "1") are now treated as distinct.
        key = (sample.doc_id, sample.sent_id, sample.offset)
        if key in seen_keys:
            continue
        seen_keys.add(key)
        unique_samples.append(sample)
    return unique_samples