import collections
import itertools
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple

from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
from relik.reader.data.relik_reader_sample import RelikReaderSample


@dataclass
class Window:
    doc_id: int
    window_id: int
    text: str
    tokens: List[str]
    doc_topic: Optional[str]
    offset: int
    token2char_start: dict
    token2char_end: dict
    window_candidates: Optional[List[str]] = None


class WindowManager:
    def __init__(self, tokenizer: BaseTokenizer) -> None:
        self.tokenizer = tokenizer

    def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
        """Tokenize `document`, returning the tokens and their (start, end) character offsets."""
        tokenized_document = self.tokenizer(document)
        tokens = []
        tokens_char_mapping = []
        for token in tokenized_document:
            tokens.append(token.text)
            tokens_char_mapping.append((token.start_char, token.end_char))
        return tokens, tokens_char_mapping

    def create_windows(
        self,
        document: str,
        window_size: int,
        stride: int,
        doc_id: int = 0,
        doc_topic: Optional[str] = None,
    ) -> List[RelikReaderSample]:
        """Split `document` into (possibly overlapping) windows of at most `window_size` tokens."""
        document_tokens, tokens_char_mapping = self.tokenize(document)
        if doc_topic is None:
            doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
        document_windows = []
        if len(document_tokens) <= window_size:
            # the whole document fits in a single window
            text = document
            document_windows.append(
                RelikReaderSample(
                    doc_id=doc_id,
                    window_id=0,
                    text=text,
                    tokens=document_tokens,
                    doc_topic=doc_topic,
                    offset=0,
                    token2char_start={
                        str(i): tokens_char_mapping[i][0]
                        for i in range(len(document_tokens))
                    },
                    token2char_end={
                        str(i): tokens_char_mapping[i][1]
                        for i in range(len(document_tokens))
                    },
                )
            )
        else:
            for window_id, i in enumerate(range(0, len(document_tokens), stride)):
                # if the last stride is smaller than the window size, then we can
                # include more tokens from the previous window.
                if i != 0 and i + window_size > len(document_tokens):
                    overflowing_tokens = i + window_size - len(document_tokens)
                    if overflowing_tokens >= stride:
                        break
                    i -= overflowing_tokens
                # `range` excludes its upper bound, so cap at len(document_tokens)
                # (not len(document_tokens) - 1) to keep the document's last token.
                involved_token_indices = list(
                    range(i, min(i + window_size, len(document_tokens)))
                )
                window_tokens = [document_tokens[j] for j in involved_token_indices]
                window_text_start = tokens_char_mapping[involved_token_indices[0]][0]
                window_text_end = tokens_char_mapping[involved_token_indices[-1]][1]
                text = document[window_text_start:window_text_end]
                document_windows.append(
                    RelikReaderSample(
                        doc_id=doc_id,
                        window_id=window_id,
                        text=text,
                        tokens=window_tokens,
                        doc_topic=doc_topic,
                        offset=window_text_start,
                        token2char_start={
                            str(i): tokens_char_mapping[ti][0]
                            for i, ti in enumerate(involved_token_indices)
                        },
                        token2char_end={
                            str(i): tokens_char_mapping[ti][1]
                            for i, ti in enumerate(involved_token_indices)
                        },
                    )
                )
        return document_windows

    def merge_windows(
        self, windows: List[RelikReaderSample]
    ) -> List[RelikReaderSample]:
        """Merge windows back into one sample per document, grouping them by `doc_id`."""
        windows_by_doc_id = collections.defaultdict(list)
        for window in windows:
            windows_by_doc_id[window.doc_id].append(window)
        merged_window_by_doc = {
            doc_id: self.merge_doc_windows(doc_windows)
            for doc_id, doc_windows in windows_by_doc_id.items()
        }
        return list(merged_window_by_doc.values())

    def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
        """Merge the windows of a single document, left to right, into one sample."""
        if len(windows) == 1:
            return windows[0]
        if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
            windows = sorted(windows, key=(lambda x: x.offset))
        window_accumulator = windows[0]
        for next_window in windows[1:]:
            window_accumulator = self._merge_window_pair(
                window_accumulator, next_window
            )
        return window_accumulator

    def _merge_tokens(
        self, window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
        w1_tokens = window1.tokens[1:-1]
        w2_tokens = window2.tokens[1:-1]

        # find the largest overlap: a suffix of window1's tokens that equals
        # a prefix of window2's tokens
        tokens_intersection = None
        for k in reversed(range(1, len(w1_tokens))):
            if w1_tokens[-k:] == w2_tokens[:k]:
                tokens_intersection = k
                break
        assert tokens_intersection is not None, (
            f"{window1.doc_id} - {window1.sent_id} - {window1.offset}"
            + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n"
            + f"w1 tokens: {w1_tokens}\n"
            + f"w2 tokens: {w2_tokens}\n"
        )

        final_tokens = (
            [window1.tokens[0]]  # CLS
            + w1_tokens
            + w2_tokens[tokens_intersection:]
            + [window1.tokens[-1]]  # SEP
        )

        w2_starting_offset = len(w1_tokens) - tokens_intersection

        def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
            final_t2c = dict()
            final_t2c.update(t2c1)
            for t, c in t2c2.items():
                t = int(t)
                if t < tokens_intersection:
                    continue
                final_t2c[str(t + w2_starting_offset)] = c
            return final_t2c

        return (
            final_tokens,
            merge_char_mapping(window1.token2char_start, window2.token2char_start),
            merge_char_mapping(window1.token2char_end, window2.token2char_end),
        )

    def _merge_span_annotation(
        self, span_annotation1: List[list], span_annotation2: List[list]
    ) -> List[list]:
        uniq_store = set()
        final_span_annotation_store = []
        for span_annotation in itertools.chain(span_annotation1, span_annotation2):
            span_annotation_id = tuple(span_annotation)
            if span_annotation_id not in uniq_store:
                uniq_store.add(span_annotation_id)
                final_span_annotation_store.append(span_annotation)
        return sorted(final_span_annotation_store, key=lambda x: x[0])

    def _merge_predictions(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> Tuple[Set[Tuple[int, int, str]], dict]:
        merged_predictions = window1.predicted_window_labels_chars.union(
            window2.predicted_window_labels_chars
        )

        span_title_probabilities = dict()
        # keep the probabilities of the first window that predicted each span
        for span_prediction, predicted_probs in itertools.chain(
            window1.probs_window_labels_chars.items(),
            window2.probs_window_labels_chars.items(),
        ):
            if span_prediction not in span_title_probabilities:
                span_title_probabilities[span_prediction] = predicted_probs

        return merged_predictions, span_title_probabilities

    def _merge_window_pair(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> RelikReaderSample:
        merging_output = dict()

        if getattr(window1, "doc_id", None) is not None:
            assert window1.doc_id == window2.doc_id

        if getattr(window1, "offset", None) is not None:
            assert (
                window1.offset < window2.offset
            ), f"window 2 offset ({window2.offset}) is smaller than window 1 offset ({window1.offset})"

        merging_output["doc_id"] = window1.doc_id
        merging_output["offset"] = window2.offset

        m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
            window1, window2
        )

        window_labels = None
        if getattr(window1, "window_labels", None) is not None:
            window_labels = self._merge_span_annotation(
                window1.window_labels, window2.window_labels
            )

        (
            predicted_window_labels_chars,
            probs_window_labels_chars,
        ) = self._merge_predictions(
            window1,
            window2,
        )

        merging_output.update(
            dict(
                tokens=m_tokens,
                token2char_start=m_token2char_start,
                token2char_end=m_token2char_end,
                window_labels=window_labels,
                predicted_window_labels_chars=predicted_window_labels_chars,
                probs_window_labels_chars=probs_window_labels_chars,
            )
        )

        return RelikReaderSample(**merging_output)
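

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a tiny
    # whitespace tokenizer that mimics the interface this module expects from
    # BaseTokenizer: calling it on a string yields objects exposing `text`,
    # `start_char`, and `end_char`. A real setup would pass one of the library's
    # tokenizers instead.
    @dataclass
    class _Token:
        text: str
        start_char: int
        end_char: int

    class _WhitespaceTokenizer:
        def __call__(self, document: str) -> List[_Token]:
            tokens, cursor = [], 0
            for part in document.split():
                start = document.index(part, cursor)
                tokens.append(_Token(part, start, start + len(part)))
                cursor = start + len(part)
            return tokens

    manager = WindowManager(tokenizer=_WhitespaceTokenizer())
    # Split a short document into overlapping windows and inspect the result.
    for window in manager.create_windows(
        "Rome is the capital of Italy and hosts the Vatican .",
        window_size=6,
        stride=3,
    ):
        print(window.window_id, window.offset, window.text)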