import collections
import itertools
from typing import Dict, List, Tuple

from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter
from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
from relik.reader.data.relik_reader_sample import RelikReaderSample


class WindowManager:
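    """Split documents into (possibly overlapping) windows and merge the
    resulting per-window samples (and their predictions) back together."""
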
    def __init__(
        self, tokenizer: BaseTokenizer, splitter: BaseSentenceSplitter | None = None
    ) -> None:
        self.tokenizer = tokenizer
        self.splitter = splitter or BlankSentenceSplitter()

    def create_windows(
        self,
        documents: str | List[str],
        window_size: int | None = None,
        stride: int | None = None,
        max_length: int | None = None,
        doc_id: str | int | None = None,
        doc_topic: str | None = None,
        is_split_into_words: bool = False,
        mentions: List[List[List[int]]] | None = None,
    ) -> Tuple[List[RelikReaderSample], List[RelikReaderSample]]:
        """
        Create windows from a list of documents.

        Args:
            documents (:obj:`str` or :obj:`List[str]`):
                The document(s) to split into windows.
            window_size (:obj:`int`, `optional`):
                The size of each window, in tokens.
            stride (:obj:`int`, `optional`):
                The stride, in tokens, between two consecutive windows.
            max_length (:obj:`int`, `optional`):
                The maximum length of a window.
            doc_id (:obj:`str` or :obj:`int`, `optional`):
                The id of the document(s).
            doc_topic (:obj:`str`, `optional`):
                The topic of the document(s).
            is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether the input is already pre-tokenized (e.g., split into words).
                If :obj:`False`, the raw text is tokenized by this manager's tokenizer first.
            mentions (:obj:`List[List[List[int]]]`, `optional`):
                The mentions of the document(s), given as ``[start, end]`` character
                offsets, one list per document.

        Returns:
            :obj:`Tuple[List[RelikReaderSample], List[RelikReaderSample]]`:
                The windows created from the documents and, when ``mentions`` is
                provided, the windows containing none of the given mentions.
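
        Example::

            # illustrative sketch: ``my_tokenizer`` and ``my_splitter`` are
            # placeholders for any BaseTokenizer / BaseSentenceSplitter pair
            manager = WindowManager(tokenizer=my_tokenizer, splitter=my_splitter)
            windows, blank_windows = manager.create_windows(
                ["First document.", "Second document."],
                window_size=32,
                stride=16,
            )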
        """
        # normalize input: a single string (or a single pre-tokenized document)
        # becomes a one-element batch
        if isinstance(documents, str) or is_split_into_words:
            documents = [documents]

        # batch tokenize
        documents_tokens = self.tokenizer(
            documents, is_split_into_words=is_split_into_words
        )

        # set splitter params
        if hasattr(self.splitter, "window_size"):
            self.splitter.window_size = window_size or self.splitter.window_size
        if hasattr(self.splitter, "window_stride"):
            self.splitter.window_stride = stride or self.splitter.window_stride

        windowed_documents, windowed_blank_documents = [], []

        if mentions is not None:
            assert len(documents) == len(
                mentions
            ), f"documents and mentions should have the same length, got {len(documents)} and {len(mentions)}"
            doc_iter = zip(documents, documents_tokens, mentions)
        else:
            doc_iter = zip(documents, documents_tokens, itertools.repeat([]))

        for inferred_doc_id, (document, document_tokens, document_mentions) in enumerate(
            doc_iter
        ):
            # resolve per-document defaults without clobbering the
            # user-provided doc_id/doc_topic for the rest of the batch
            current_doc_topic = doc_topic
            if current_doc_topic is None:
                current_doc_topic = (
                    document_tokens[0].text if len(document_tokens) > 0 else ""
                )
            current_doc_id = doc_id if doc_id is not None else inferred_doc_id

            split_document = self.splitter(document_tokens, max_length=max_length)

            document_windows = []
            for window_id, window in enumerate(split_document):
                window_text_start = window[0].idx
                window_text_end = window[-1].idx + len(window[-1].text)
                if isinstance(document, str):
                    text = document[window_text_start:window_text_end]
                else:
                    # pre-tokenized input: rebuild the window text from its words
                    text = " ".join([w.text for w in window])
                sample = RelikReaderSample(
                    doc_id=current_doc_id,
                    window_id=window_id,
                    text=text,
                    tokens=[w.text for w in window],
                    words=[w.text for w in window],
                    doc_topic=current_doc_topic,
                    offset=window_text_start,
                    spans=[
                        [m[0], m[1]]
                        for m in document_mentions
                        if window_text_start <= m[0] < window_text_end
                        and window_text_start <= m[1] <= window_text_end
                    ],
                    token2char_start={str(i): w.idx for i, w in enumerate(window)},
                    token2char_end={
                        str(i): w.idx + len(w.text) for i, w in enumerate(window)
                    },
                    char2token_start={str(w.idx): w.i for w in window},
                    char2token_end={str(w.idx + len(w.text)): w.i for w in window},
                )
                if mentions is not None and len(sample.spans) == 0:
                    windowed_blank_documents.append(sample)
                else:
                    document_windows.append(sample)

            windowed_documents.extend(document_windows)
        return windowed_documents, windowed_blank_documents

    def merge_windows(
        self, windows: List[RelikReaderSample]
    ) -> List[RelikReaderSample]:
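        """Group windows by ``doc_id`` and merge each group into one sample."""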
        windows_by_doc_id = collections.defaultdict(list)
        for window in windows:
            windows_by_doc_id[window.doc_id].append(window)

        merged_window_by_doc = {
            doc_id: self._merge_doc_windows(doc_windows)
            for doc_id, doc_windows in windows_by_doc_id.items()
        }

        return list(merged_window_by_doc.values())

    def _merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
        if len(windows) == 1:
            return windows[0]

        if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
            windows = sorted(windows, key=(lambda x: x.offset))

        window_accumulator = windows[0]

        for next_window in windows[1:]:
            window_accumulator = self._merge_window_pair(
                window_accumulator, next_window
            )

        return window_accumulator

    @staticmethod
    def _merge_tokens(
        window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
        w1_tokens = window1.tokens[1:-1]
        w2_tokens = window2.tokens[1:-1]

        # find intersection if any
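        # e.g. (illustrative): if w1_tokens ends with ["the", "cat"] and
        # w2_tokens starts with ["the", "cat"], the longest shared boundary
        # is two tokens, so tokens_intersection = 2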
        tokens_intersection = 0
        for k in reversed(range(1, len(w1_tokens))):
            if w1_tokens[-k:] == w2_tokens[:k]:
                tokens_intersection = k
                break

        final_tokens = (
            [window1.tokens[0]]  # CLS
            + w1_tokens
            + w2_tokens[tokens_intersection:]
            + [window1.tokens[-1]]  # SEP
        )

        w2_starting_offset = len(w1_tokens) - tokens_intersection

        def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
            final_t2c = dict()
            final_t2c.update(t2c1)
            for t, c in t2c2.items():
                t = int(t)
                if t < tokens_intersection:
                    continue
                final_t2c[str(t + w2_starting_offset)] = c
            return final_t2c

        return (
            final_tokens,
            merge_char_mapping(window1.token2char_start, window2.token2char_start),
            merge_char_mapping(window1.token2char_end, window2.token2char_end),
        )

    @staticmethod
    def _merge_words(
        window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
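        # same overlap search as _merge_tokens, but word sequences carry no
        # special CLS/SEP markers, so the full lists are compared directly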
        w1_words = window1.words
        w2_words = window2.words

        # find intersection if any
        words_intersection = 0
        for k in reversed(range(1, len(w1_words))):
            if w1_words[-k:] == w2_words[:k]:
                words_intersection = k
                break

        final_words = w1_words + w2_words[words_intersection:]

        w2_starting_offset = len(w1_words) - words_intersection

        def merge_word_mapping(t2c1: dict, t2c2: dict) -> dict:
            final_t2c = dict()
            if t2c1 is None:
                t2c1 = dict()
            if t2c2 is None:
                t2c2 = dict()
            final_t2c.update(t2c1)
            for t, c in t2c2.items():
                t = int(t)
                if t < words_intersection:
                    continue
                final_t2c[str(t + w2_starting_offset)] = c
            return final_t2c

        return (
            final_words,
            merge_word_mapping(window1.token2word_start, window2.token2word_start),
            merge_word_mapping(window1.token2word_end, window2.token2word_end),
        )

    @staticmethod
    def _merge_span_annotation(
        span_annotation1: List[list], span_annotation2: List[list]
    ) -> List[list]:
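        # deduplicate annotations (keeping the first occurrence) and sort them
        # by span start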
        uniq_store = set()
        final_span_annotation_store = []
        for span_annotation in itertools.chain(span_annotation1, span_annotation2):
            span_annotation_id = tuple(span_annotation)
            if span_annotation_id not in uniq_store:
                uniq_store.add(span_annotation_id)
                final_span_annotation_store.append(span_annotation)
        return sorted(final_span_annotation_store, key=lambda x: x[0])

    @staticmethod
    def _merge_predictions(
        window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, list, dict]:
        # a RelikReaderSample should have a field called `predicted_spans`
        # that stores the span-level predictions, or a field called
        # `predicted_triples` that stores the triple-level predictions

        # span predictions
        merged_span_predictions: list = []
        merged_span_probabilities = dict()
        # triple predictions
        merged_triplet_predictions: list = []
        merged_triplet_probs: Dict = dict()

        if (
            getattr(window1, "predicted_spans", None) is not None
            and getattr(window2, "predicted_spans", None) is not None
        ):
            merged_span_predictions = set(window1.predicted_spans).union(
                set(window2.predicted_spans)
            )
            merged_span_predictions = sorted(merged_span_predictions)
            # probabilities
            for span_prediction, predicted_probs in itertools.chain(
                window1.probs_window_labels_chars.items()
                if window1.probs_window_labels_chars is not None
                else [],
                window2.probs_window_labels_chars.items()
                if window2.probs_window_labels_chars is not None
                else [],
            ):
                if span_prediction not in merged_span_probabilities:
                    merged_span_probabilities[span_prediction] = predicted_probs

            if (
                getattr(window1, "predicted_triples", None) is not None
                and getattr(window2, "predicted_triples", None) is not None
            ):
                # merge the triple predictions by re-indexing the subject and
                # object span positions (t[0] and t[2]) into the merged span list
                window1_triplets = [
                    (
                        merged_span_predictions.index(window1.predicted_spans[t[0]]),
                        t[1],
                        merged_span_predictions.index(window1.predicted_spans[t[2]]),
                        t[3]
                    )
                    for t in window1.predicted_triples
                ]
                window2_triplets = [
                    (
                        merged_span_predictions.index(window2.predicted_spans[t[0]]),
                        t[1],
                        merged_span_predictions.index(window2.predicted_spans[t[2]]),
                        t[3]
                    )
                    for t in window2.predicted_triples
                ]
                merged_triplet_predictions = set(window1_triplets).union(
                    set(window2_triplets)
                )
                merged_triplet_predictions = sorted(merged_triplet_predictions)
                # no triplet probabilities for now; they are not needed yet

        return (
            merged_span_predictions,
            merged_span_probabilities,
            merged_triplet_predictions,
            merged_triplet_probs,
        )

    @staticmethod
    def _merge_candidates(window1: RelikReaderSample, window2: RelikReaderSample):
        candidates = []
        windows_candidates = []

        # TODO: backward compatibility
        if getattr(window1, "candidates", None) is not None:
            candidates = window1.candidates
        if getattr(window2, "candidates", None) is not None:
            candidates += window2.candidates

        # TODO: backward compatibility
        if getattr(window1, "windows_candidates", None) is not None:
            windows_candidates = window1.windows_candidates
        if getattr(window2, "windows_candidates", None) is not None:
            windows_candidates += window2.windows_candidates

        # TODO: add programmatically
        span_candidates = []
        if getattr(window1, "span_candidates", None) is not None:
            span_candidates = window1.span_candidates
        if getattr(window2, "span_candidates", None) is not None:
            span_candidates += window2.span_candidates

        triplet_candidates = []
        if getattr(window1, "triplet_candidates", None) is not None:
            triplet_candidates = window1.triplet_candidates
        if getattr(window2, "triplet_candidates", None) is not None:
            triplet_candidates += window2.triplet_candidates

        # deduplicate (note: set() does not preserve the original order)
        candidates = list(set(candidates))
        windows_candidates = list(set(windows_candidates))

        span_candidates = list(set(span_candidates))
        triplet_candidates = list(set(triplet_candidates))

        return candidates, windows_candidates, span_candidates, triplet_candidates

    def _merge_window_pair(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> RelikReaderSample:
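        """Merge two consecutive windows of the same document into a single sample."""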
        merging_output = dict()

        if getattr(window1, "doc_id", None) is not None:
            assert window1.doc_id == window2.doc_id

        if getattr(window1, "offset", None) is not None:
            assert (
                window1.offset < window2.offset
            ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})"

        merging_output["doc_id"] = window1.doc_id
        merging_output["offset"] = window2.offset

        m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
            window1, window2
        )

        m_words, m_token2word_start, m_token2word_end = self._merge_words(
            window1, window2
        )

        (
            m_candidates,
            m_windows_candidates,
            m_span_candidates,
            m_triplet_candidates,
        ) = self._merge_candidates(window1, window2)

        window_labels = None
        if getattr(window1, "window_labels", None) is not None:
            window_labels = self._merge_span_annotation(
                window1.window_labels, window2.window_labels
            )

        (
            predicted_spans,
            predicted_spans_probs,
            predicted_triples,
            predicted_triples_probs,
        ) = self._merge_predictions(window1, window2)

        merging_output.update(
            dict(
                tokens=m_tokens,
                words=m_words,
                token2char_start=m_token2char_start,
                token2char_end=m_token2char_end,
                token2word_start=m_token2word_start,
                token2word_end=m_token2word_end,
                window_labels=window_labels,
                candidates=m_candidates,
                span_candidates=m_span_candidates,
                triplet_candidates=m_triplet_candidates,
                windows_candidates=m_windows_candidates,
                predicted_spans=predicted_spans,
                predicted_spans_probs=predicted_spans_probs,
                predicted_triples=predicted_triples,
                predicted_triples_probs=predicted_triples_probs,
            )
        )

        return RelikReaderSample(**merging_output)
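

# A minimal usage sketch (illustrative only). The ``SpacyTokenizer`` class and
# its import path are assumptions about the surrounding library; substitute
# whatever ``BaseTokenizer`` implementation is available. With the default
# ``BlankSentenceSplitter``, documents are split on blank lines.
if __name__ == "__main__":
    from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer

    manager = WindowManager(tokenizer=SpacyTokenizer(language="en"))
    windows, blank_windows = manager.create_windows(
        "Named entities are mentioned in the first paragraph.\n\n"
        "The second paragraph mentions a few more."
    )
    for window in windows:
        print(window.doc_id, window.window_id, window.offset, window.text)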