CarlosMalaga's picture
Upload 201 files
2f044c1 verified
raw
history blame
16.8 kB
import collections
import itertools
from typing import Dict, List, Optional, Set, Tuple
from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter
from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
from relik.reader.data.relik_reader_sample import RelikReaderSample
class WindowManager:
    """Split documents into (possibly overlapping) windows and merge them back.

    ``create_windows`` tokenizes the input documents and cuts them into
    :obj:`RelikReaderSample` windows with the configured splitter;
    ``merge_windows`` groups windows by ``doc_id`` and folds each group back
    into a single sample.
    """

    def __init__(
        self, tokenizer: BaseTokenizer, splitter: BaseSentenceSplitter | None = None
    ) -> None:
        """
        Args:
            tokenizer (:obj:`BaseTokenizer`):
                Tokenizer applied to the documents in ``create_windows``.
            splitter (:obj:`BaseSentenceSplitter`, `optional`):
                Splitter that cuts a tokenized document into windows. When not
                given (or falsy), a :obj:`BlankSentenceSplitter` is used.
        """
        self.tokenizer = tokenizer
        self.splitter = splitter or BlankSentenceSplitter()

    def create_windows(
        self,
        documents: str | List[str],
        window_size: int | None = None,
        stride: int | None = None,
        max_length: int | None = None,
        doc_id: str | int | None = None,
        doc_topic: str | None = None,
        is_split_into_words: bool = False,
        mentions: List[List[List[int]]] = None,
    ) -> Tuple[List[RelikReaderSample], List[RelikReaderSample]]:
        """
        Create windows from a list of documents.

        Args:
            documents (:obj:`str` or :obj:`List[str]`):
                The document(s) to split in windows.
            window_size (:obj:`int`, `optional`):
                The size of the window. Only applied if the splitter exposes a
                ``window_size`` attribute; a falsy value keeps the splitter's
                current setting.
            stride (:obj:`int`, `optional`):
                The stride between two windows. Only applied if the splitter
                exposes a ``window_stride`` attribute; a falsy value keeps the
                splitter's current setting.
            max_length (:obj:`int`, `optional`):
                The maximum length of a window (forwarded to the splitter).
            doc_id (:obj:`str` or :obj:`int`, `optional`):
                The id assigned to the windows. If :obj:`None`, each document
                gets its own position in ``documents`` as id.
            doc_topic (:obj:`str`, `optional`):
                The topic of the document(s). If :obj:`None`, each document's
                topic is the first element of its own tokenized output (or
                ``""`` when the document has no tokens).
            is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether the input is already pre-tokenized (e.g., split into words). If :obj:`False`, the
                input will first be tokenized using the tokenizer, then the tokens will be split into words.
            mentions (:obj:`List[List[List[int]]]`, `optional`):
                The mentions of the document(s): for each document, a list of
                ``[start, end]`` character offsets.

        Returns:
            :obj:`Tuple[List[RelikReaderSample], List[RelikReaderSample]]`:
            The windows created from the documents, and the "blank" windows.
            Blank windows are only produced when ``mentions`` is given; they are
            the windows containing no mention and are excluded from the first list.
        """
        # normalize input: a single document (or a single pre-tokenized one)
        # becomes a batch of one
        if isinstance(documents, str) or is_split_into_words:
            documents = [documents]

        # batch tokenize
        documents_tokens = self.tokenizer(
            documents, is_split_into_words=is_split_into_words
        )

        # set splitter params. This overwrites attributes of the shared
        # splitter, so the values persist for later calls as well.
        if hasattr(self.splitter, "window_size"):
            self.splitter.window_size = window_size or self.splitter.window_size
        if hasattr(self.splitter, "window_stride"):
            self.splitter.window_stride = stride or self.splitter.window_stride

        windowed_documents, windowed_blank_documents = [], []

        if mentions is not None:
            assert len(documents) == len(
                mentions
            ), f"documents and mentions should have the same length, got {len(documents)} and {len(mentions)}"
            doc_iter = zip(documents, documents_tokens, mentions)
        else:
            doc_iter = zip(documents, documents_tokens, itertools.repeat([]))

        for inferred_doc_id, (document, document_tokens, document_mentions) in enumerate(
            doc_iter
        ):
            # Resolve id and topic per document. Caller-supplied values apply to
            # every document; inferred values must not leak from one document
            # into the next (the parameters themselves are never reassigned).
            current_doc_topic = doc_topic
            if current_doc_topic is None:
                # NOTE(review): this is the first element of the tokenizer output,
                # possibly a token object rather than a str — confirm what
                # downstream consumers of `doc_topic` expect.
                current_doc_topic = (
                    document_tokens[0] if len(document_tokens) > 0 else ""
                )
            current_doc_id = doc_id if doc_id is not None else inferred_doc_id

            splitted_document = self.splitter(document_tokens, max_length=max_length)

            document_windows = []
            for window_id, window in enumerate(splitted_document):
                window_text_start = window[0].idx
                window_text_end = window[-1].idx + len(window[-1].text)
                if isinstance(document, str):
                    # slice the raw text so the original whitespace is preserved
                    text = document[window_text_start:window_text_end]
                else:
                    # pre-tokenized input: rebuild the text by joining with spaces
                    text = " ".join([w.text for w in window])
                sample = RelikReaderSample(
                    doc_id=current_doc_id,
                    window_id=window_id,
                    text=text,
                    tokens=[w.text for w in window],
                    words=[w.text for w in window],
                    doc_topic=current_doc_topic,
                    offset=window_text_start,
                    # keep only mentions fully contained in this window
                    spans=[
                        [m[0], m[1]]
                        for m in document_mentions
                        if window_text_start <= m[0] < window_text_end
                        and window_text_start <= m[1] <= window_text_end
                    ],
                    # keys are window-local token positions (as strings)
                    token2char_start={str(i): w.idx for i, w in enumerate(window)},
                    token2char_end={
                        str(i): w.idx + len(w.text) for i, w in enumerate(window)
                    },
                    # NOTE(review): values use `w.i`, presumably the token index
                    # within the whole document (not the window) — confirm.
                    char2token_start={str(w.idx): w.i for w in window},
                    char2token_end={
                        str(w.idx + len(w.text)): w.i for w in window
                    },
                )
                if mentions is not None and len(sample.spans) == 0:
                    windowed_blank_documents.append(sample)
                else:
                    document_windows.append(sample)

            windowed_documents.extend(document_windows)

        return windowed_documents, windowed_blank_documents

    def merge_windows(
        self, windows: List[RelikReaderSample]
    ) -> List[RelikReaderSample]:
        """Merge windows that share a ``doc_id`` into one sample per document.

        Args:
            windows: The windows to merge, in any order.

        Returns:
            One merged sample per distinct ``doc_id``, in order of first
            appearance in ``windows``.
        """
        windows_by_doc_id = collections.defaultdict(list)
        for window in windows:
            windows_by_doc_id[window.doc_id].append(window)
        return [
            self._merge_doc_windows(doc_windows)
            for doc_windows in windows_by_doc_id.values()
        ]

    def _merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
        """Fold all windows of a single document into one sample.

        Windows are sorted by ``offset`` (when the first one has an offset) and
        then merged pairwise from left to right.
        """
        if len(windows) == 1:
            return windows[0]
        if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
            windows = sorted(windows, key=(lambda x: x.offset))
        window_accumulator = windows[0]
        for next_window in windows[1:]:
            window_accumulator = self._merge_window_pair(
                window_accumulator, next_window
            )
        return window_accumulator

    @staticmethod
    def _overlap_length(seq1: list, seq2: list) -> int:
        """Return the length of the longest suffix of ``seq1`` that is a prefix of ``seq2``.

        Every length up to ``min(len(seq1), len(seq2))`` is tested, including a
        full overlap; 0 is returned when the sequences do not overlap.
        """
        for k in range(min(len(seq1), len(seq2)), 0, -1):
            if seq1[-k:] == seq2[:k]:
                return k
        return 0

    @staticmethod
    def _merge_index_mapping(
        mapping1: Optional[dict],
        mapping2: Optional[dict],
        intersection: int,
        offset: int,
    ) -> dict:
        """Merge two ``position -> value`` mappings after concatenating their windows.

        Entries of ``mapping1`` are copied unchanged. Entries of ``mapping2`` whose
        position falls inside the overlap (``< intersection``) are dropped; the
        others are shifted by ``offset``. Keys are stored as strings, and a
        ``None`` mapping is treated as empty.
        """
        merged = dict(mapping1 or {})
        for position, value in (mapping2 or {}).items():
            position = int(position)
            if position < intersection:
                continue
            merged[str(position + offset)] = value
        return merged

    @staticmethod
    def _merge_tokens(
        window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
        """Concatenate the tokens of two windows, dropping their overlap.

        The first and last token of each window are treated as special tokens
        (labelled CLS / SEP below). They are stripped before searching for the
        overlap, and the merged list reuses window 1's.

        Returns:
            The merged tokens, and the merged ``token2char_start`` and
            ``token2char_end`` mappings.
        """
        w1_tokens = window1.tokens[1:-1]
        w2_tokens = window2.tokens[1:-1]

        tokens_intersection = WindowManager._overlap_length(w1_tokens, w2_tokens)

        final_tokens = (
            [window1.tokens[0]]  # CLS
            + w1_tokens
            + w2_tokens[tokens_intersection:]
            + [window1.tokens[-1]]  # SEP
        )

        # window 2's non-overlapping tokens start right after window 1's
        w2_starting_offset = len(w1_tokens) - tokens_intersection

        return (
            final_tokens,
            WindowManager._merge_index_mapping(
                window1.token2char_start,
                window2.token2char_start,
                tokens_intersection,
                w2_starting_offset,
            ),
            WindowManager._merge_index_mapping(
                window1.token2char_end,
                window2.token2char_end,
                tokens_intersection,
                w2_starting_offset,
            ),
        )

    @staticmethod
    def _merge_words(
        window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
        """Concatenate the words of two windows, dropping their overlap.

        Returns:
            The merged words, and the merged ``token2word_start`` and
            ``token2word_end`` mappings (missing mappings count as empty).
        """
        w1_words = window1.words
        w2_words = window2.words

        words_intersection = WindowManager._overlap_length(w1_words, w2_words)

        final_words = w1_words + w2_words[words_intersection:]

        # window 2's non-overlapping words start right after window 1's
        w2_starting_offset = len(w1_words) - words_intersection

        return (
            final_words,
            WindowManager._merge_index_mapping(
                window1.token2word_start,
                window2.token2word_start,
                words_intersection,
                w2_starting_offset,
            ),
            WindowManager._merge_index_mapping(
                window1.token2word_end,
                window2.token2word_end,
                words_intersection,
                w2_starting_offset,
            ),
        )

    @staticmethod
    def _merge_span_annotation(
        span_annotation1: List[list], span_annotation2: List[list]
    ) -> List[list]:
        """Union of two span-annotation lists, de-duplicated and sorted by start.

        The first occurrence of each annotation (compared as a tuple) is kept.
        """
        uniq_store = set()
        final_span_annotation_store = []
        for span_annotation in itertools.chain(span_annotation1, span_annotation2):
            span_annotation_id = tuple(span_annotation)
            if span_annotation_id not in uniq_store:
                uniq_store.add(span_annotation_id)
                final_span_annotation_store.append(span_annotation)
        return sorted(final_span_annotation_store, key=lambda x: x[0])

    @staticmethod
    def _merge_predictions(
        window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list | set, dict, list | set, dict]:
        """Merge span-level and triple-level predictions of two windows.

        Spans are merged only when both windows have ``predicted_spans``, and
        triples only when both have ``predicted_triples``. Triples refer to spans
        by position, so they are re-indexed against the merged span list.

        Returns:
            ``(spans, span_probabilities, triples, triple_probabilities)``. Merged
            spans and triples are sorted lists; when a kind is not merged, an
            empty set is returned for it. Triple probabilities are always empty.
        """
        # span predictions
        merged_span_predictions: list | set = set()
        merged_span_probabilities: dict = dict()
        # triple predictions
        merged_triplet_predictions: list | set = set()
        merged_triplet_probs: Dict = dict()

        if (
            getattr(window1, "predicted_spans", None) is not None
            and getattr(window2, "predicted_spans", None) is not None
        ):
            merged_span_predictions = sorted(
                set(window1.predicted_spans) | set(window2.predicted_spans)
            )
            # probabilities: the first value seen for a span wins (window 1 first)
            for span_prediction, predicted_probs in itertools.chain(
                window1.probs_window_labels_chars.items()
                if window1.probs_window_labels_chars is not None
                else [],
                window2.probs_window_labels_chars.items()
                if window2.probs_window_labels_chars is not None
                else [],
            ):
                merged_span_probabilities.setdefault(span_prediction, predicted_probs)

        if (
            getattr(window1, "predicted_triples", None) is not None
            and getattr(window2, "predicted_triples", None) is not None
        ):
            # position of every span in the merged list (spans are unique there),
            # so each lookup is O(1) instead of a list.index scan
            span_position = {
                span: position for position, span in enumerate(merged_span_predictions)
            }

            def reindex(window: RelikReaderSample) -> list:
                # (subject span idx, relation, object span idx, t[3]) with both
                # span indices translated to positions in the merged span list
                return [
                    (
                        span_position[window.predicted_spans[t[0]]],
                        t[1],
                        span_position[window.predicted_spans[t[2]]],
                        t[3],
                    )
                    for t in window.predicted_triples
                ]

            merged_triplet_predictions = sorted(
                set(reindex(window1)) | set(reindex(window2))
            )
            # for now no triplet probs, we don't need them for the moment

        return (
            merged_span_predictions,
            merged_span_probabilities,
            merged_triplet_predictions,
            merged_triplet_probs,
        )

    @staticmethod
    def _merge_candidates(window1: RelikReaderSample, window2: RelikReaderSample):
        """Union the candidate lists of two windows.

        Returns:
            ``(candidates, windows_candidates, span_candidates, triplet_candidates)``,
            each a new de-duplicated list in first-seen order (window 1 first).
            The input windows are never modified.
        """

        def collect(attribute: str) -> list:
            merged: list = []
            # TODO: retro-compatibility — the attribute may be missing or None
            for window in (window1, window2):
                values = getattr(window, attribute, None)
                if values is not None:
                    # extend a fresh list: `+=` on window 1's list would mutate it
                    merged.extend(values)
            # dict.fromkeys de-duplicates while keeping a deterministic order
            return list(dict.fromkeys(merged))

        return (
            collect("candidates"),
            collect("windows_candidates"),
            collect("span_candidates"),
            collect("triplet_candidates"),
        )

    def _merge_window_pair(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> RelikReaderSample:
        """Merge two consecutive windows of the same document into a new sample.

        Only the fields listed in the output below are carried over (e.g. ``text``
        and ``doc_topic`` are not).

        Raises:
            AssertionError: if the windows have different ``doc_id`` or window 2
                does not start after window 1.
        """
        merging_output = dict()

        if getattr(window1, "doc_id", None) is not None:
            assert window1.doc_id == window2.doc_id

        if getattr(window1, "offset", None) is not None:
            assert (
                window1.offset < window2.offset
            ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})"

        merging_output["doc_id"] = window1.doc_id
        # NOTE(review): the merged sample keeps window 2's offset rather than
        # window 1's (the start of the merged span) — confirm which one
        # downstream expects.
        merging_output["offset"] = window2.offset

        m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
            window1, window2
        )

        m_words, m_token2word_start, m_token2word_end = self._merge_words(
            window1, window2
        )

        (
            m_candidates,
            m_windows_candidates,
            m_span_candidates,
            m_triplet_candidates,
        ) = self._merge_candidates(window1, window2)

        window_labels = None
        if getattr(window1, "window_labels", None) is not None:
            # window 2 may lack labels; treat that as "no labels" instead of crashing
            window_labels = self._merge_span_annotation(
                window1.window_labels, getattr(window2, "window_labels", None) or []
            )

        (
            predicted_spans,
            predicted_spans_probs,
            predicted_triples,
            predicted_triples_probs,
        ) = self._merge_predictions(window1, window2)

        merging_output.update(
            dict(
                tokens=m_tokens,
                words=m_words,
                token2char_start=m_token2char_start,
                token2char_end=m_token2char_end,
                token2word_start=m_token2word_start,
                token2word_end=m_token2word_end,
                window_labels=window_labels,
                candidates=m_candidates,
                span_candidates=m_span_candidates,
                triplet_candidates=m_triplet_candidates,
                windows_candidates=m_windows_candidates,
                predicted_spans=predicted_spans,
                predicted_spans_probs=predicted_spans_probs,
                predicted_triples=predicted_triples,
                predicted_triples_probs=predicted_triples_probs,
            )
        )

        return RelikReaderSample(**merging_output)