import collections
import itertools
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple

from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
from relik.reader.data.relik_reader_sample import RelikReaderSample


@dataclass
class Window:
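    """A single window of text extracted from a document, with its tokens and character offsets."""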
|
    doc_id: int
    window_id: int
    text: str
    tokens: List[str]
    doc_topic: Optional[str]
    offset: int
    token2char_start: dict
    token2char_end: dict
    window_candidates: Optional[List[str]] = None


class WindowManager:
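    """Splits documents into token windows and merges per-window samples back into documents."""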
|
    def __init__(self, tokenizer: BaseTokenizer) -> None:
        self.tokenizer = tokenizer

    def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
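        """Tokenize ``document`` and return the tokens together with their (start_char, end_char) offsets."""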
|
        tokenized_document = self.tokenizer(document)
        tokens = []
        tokens_char_mapping = []
        for token in tokenized_document:
            tokens.append(token.text)
            tokens_char_mapping.append((token.start_char, token.end_char))
        return tokens, tokens_char_mapping

    def create_windows(
        self,
        document: str,
        window_size: int,
        stride: int,
        doc_id: int = 0,
        doc_topic: Optional[str] = None,
    ) -> List[RelikReaderSample]:
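        """Split ``document`` into windows of at most ``window_size`` tokens, advancing by ``stride`` tokens."""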
|
        document_tokens, tokens_char_mapping = self.tokenize(document)
        if doc_topic is None:
            doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
        document_windows = []
        if len(document_tokens) <= window_size:
            text = document
            document_windows.append(
                RelikReaderSample(
                    doc_id=doc_id,
                    window_id=0,
                    text=text,
                    tokens=document_tokens,
                    doc_topic=doc_topic,
                    offset=0,
                    token2char_start={
                        str(i): tokens_char_mapping[i][0]
                        for i in range(len(document_tokens))
                    },
                    token2char_end={
                        str(i): tokens_char_mapping[i][1]
                        for i in range(len(document_tokens))
                    },
                )
            )
        else:
            for window_id, i in enumerate(range(0, len(document_tokens), stride)):
                # the final window may overrun the document: skip it if it would be
                # fully covered by the previous window, otherwise shift it back so
                # that it ends at the last token
                if i != 0 and i + window_size > len(document_tokens):
                    overflowing_tokens = i + window_size - len(document_tokens)
                    if overflowing_tokens >= stride:
                        break
                    i -= overflowing_tokens
|
                involved_token_indices = list(
                    # range() is end-exclusive, so the upper bound must be
                    # len(document_tokens) to keep the document's last token
                    range(i, min(i + window_size, len(document_tokens)))
                )
                window_tokens = [document_tokens[j] for j in involved_token_indices]
                window_text_start = tokens_char_mapping[involved_token_indices[0]][0]
                window_text_end = tokens_char_mapping[involved_token_indices[-1]][1]
                text = document[window_text_start:window_text_end]
                document_windows.append(
                    RelikReaderSample(
                        doc_id=doc_id,
                        window_id=window_id,
                        text=text,
                        tokens=window_tokens,
                        doc_topic=doc_topic,
                        offset=window_text_start,
                        token2char_start={
                            str(i): tokens_char_mapping[ti][0]
                            for i, ti in enumerate(involved_token_indices)
                        },
                        token2char_end={
                            str(i): tokens_char_mapping[ti][1]
                            for i, ti in enumerate(involved_token_indices)
                        },
                    )
                )
        return document_windows
|
    def merge_windows(
        self, windows: List[RelikReaderSample]
    ) -> List[RelikReaderSample]:
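        """Group windows by ``doc_id`` and merge each document's windows into a single sample."""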
|
        windows_by_doc_id = collections.defaultdict(list)
        for window in windows:
            windows_by_doc_id[window.doc_id].append(window)

        merged_window_by_doc = {
            doc_id: self.merge_doc_windows(doc_windows)
            for doc_id, doc_windows in windows_by_doc_id.items()
        }

        return list(merged_window_by_doc.values())

    def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
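        """Merge all windows of a single document, in offset order, into one sample."""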
|
        if len(windows) == 1:
            return windows[0]

        if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
            windows = sorted(windows, key=(lambda x: x.offset))

        window_accumulator = windows[0]

        for next_window in windows[1:]:
            window_accumulator = self._merge_window_pair(
                window_accumulator, next_window
            )

        return window_accumulator

    def _merge_tokens(
        self, window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
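        """Merge the token sequences of two consecutive windows, de-duplicating their overlap."""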
|
        w1_tokens = window1.tokens[1:-1]
        w2_tokens = window2.tokens[1:-1]

        # find the largest suffix of window1 that is also a prefix of window2
        tokens_intersection = None
        for k in reversed(range(1, len(w1_tokens))):
            if w1_tokens[-k:] == w2_tokens[:k]:
                tokens_intersection = k
                break
        assert tokens_intersection is not None, (
            f"{window1.doc_id} - {window1.sent_id} - {window1.offset}"
            + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n"
            + f"w1 tokens: {w1_tokens}\n"
            + f"w2 tokens: {w2_tokens}\n"
        )

        final_tokens = (
            [window1.tokens[0]]
            + w1_tokens
            + w2_tokens[tokens_intersection:]
            + [window1.tokens[-1]]
        )

        w2_starting_offset = len(w1_tokens) - tokens_intersection

        def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
            final_t2c = dict()
            final_t2c.update(t2c1)
            for t, c in t2c2.items():
                t = int(t)
                if t < tokens_intersection:
                    continue
                final_t2c[str(t + w2_starting_offset)] = c
            return final_t2c

        return (
            final_tokens,
            merge_char_mapping(window1.token2char_start, window2.token2char_start),
            merge_char_mapping(window1.token2char_end, window2.token2char_end),
        )

    def _merge_span_annotation(
        self, span_annotation1: List[list], span_annotation2: List[list]
    ) -> List[list]:
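        """Merge two lists of span annotations, dropping duplicates and sorting by span start."""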
|
        uniq_store = set()
        final_span_annotation_store = []
        for span_annotation in itertools.chain(span_annotation1, span_annotation2):
            span_annotation_id = tuple(span_annotation)
            if span_annotation_id not in uniq_store:
                uniq_store.add(span_annotation_id)
                final_span_annotation_store.append(span_annotation)
        return sorted(final_span_annotation_store, key=lambda x: x[0])

    def _merge_predictions(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> Tuple[Set[Tuple[int, int, str]], dict]:
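        """Union the predicted spans of two windows; for spans predicted by both, keep the first window's probabilities."""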
|
        merged_predictions = window1.predicted_window_labels_chars.union(
            window2.predicted_window_labels_chars
        )

        span_title_probabilities = dict()
        for span_prediction, predicted_probs in itertools.chain(
            window1.probs_window_labels_chars.items(),
            window2.probs_window_labels_chars.items(),
        ):
            if span_prediction not in span_title_probabilities:
                span_title_probabilities[span_prediction] = predicted_probs

        return merged_predictions, span_title_probabilities

    def _merge_window_pair(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> RelikReaderSample:
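        """Merge two consecutive windows of the same document into one ``RelikReaderSample``."""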
|
        merging_output = dict()

        if getattr(window1, "doc_id", None) is not None:
            assert window1.doc_id == window2.doc_id

        if getattr(window1, "offset", None) is not None:
            assert (
                window1.offset < window2.offset
            ), f"window 2 offset ({window2.offset}) is smaller than window 1 offset ({window1.offset})"

        merging_output["doc_id"] = window1.doc_id
        merging_output["offset"] = window2.offset

        m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
            window1, window2
        )

        window_labels = None
        if getattr(window1, "window_labels", None) is not None:
            window_labels = self._merge_span_annotation(
                window1.window_labels, window2.window_labels
            )
        (
            predicted_window_labels_chars,
            probs_window_labels_chars,
        ) = self._merge_predictions(
            window1,
            window2,
        )

        merging_output.update(
            dict(
                tokens=m_tokens,
                token2char_start=m_token2char_start,
                token2char_end=m_token2char_end,
                window_labels=window_labels,
                predicted_window_labels_chars=predicted_window_labels_chars,
                probs_window_labels_chars=probs_window_labels_chars,
            )
        )

        return RelikReaderSample(**merging_output)