import logging

from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import numpy as np
import torch
from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizer

from relik.reader.data.relik_reader_data_utils import (
    add_noise_to_value,
    batchify,
    chunks,
    flatten,
)
from relik.reader.data.relik_reader_sample import (
    RelikReaderSample,
    load_relik_reader_samples,
)
from relik.reader.utils.special_symbols import NME_SYMBOL

logger = logging.getLogger(__name__)


def preprocess_dataset(
    input_dataset: Iterable[dict],
    transformer_model: str,
    add_topic: bool,
) -> Iterable[dict]:
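    """Tokenize each window with the transformer tokenizer, rebuild the
    subword-to-character offset mappings, and optionally prepend the document
    topic tokens right after the CLS token. Windows with no tokens, or whose
    labels do not align with the rebuilt mappings, are skipped."""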
    tokenizer = AutoTokenizer.from_pretrained(transformer_model)
    for dataset_elem in tqdm(input_dataset, desc="Preprocessing input dataset"):
        if len(dataset_elem["tokens"]) == 0:
            print(
                f"Dataset element with doc id: {dataset_elem['doc_id']}",
                f"and offset {dataset_elem['offset']} does not contain any token",
                "Skipping it",
            )
            continue

        new_dataset_elem = dict(
            doc_id=dataset_elem["doc_id"],
            offset=dataset_elem["offset"],
        )

        tokenization_out = tokenizer(
            dataset_elem["tokens"],
            return_offsets_mapping=True,
            add_special_tokens=False,
        )

        window_tokens = tokenization_out.input_ids
        window_tokens = flatten(window_tokens)
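
        # Remap each sub-token's (start, end) offsets, which the tokenizer returns
        # relative to its word, to character offsets in the original document using
        # the word-level token2char_start mapping.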
        offsets_mapping = [
            [
                (
                    ss + dataset_elem["token2char_start"][str(i)],
                    se + dataset_elem["token2char_start"][str(i)],
                )
                for ss, se in tokenization_out.offset_mapping[i]
            ]
            for i in range(len(dataset_elem["tokens"]))
        ]

        offsets_mapping = flatten(offsets_mapping)

        assert len(offsets_mapping) == len(window_tokens)

        window_tokens = (
            [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
        )

        topic_offset = 0
        if add_topic:
            topic_tokens = tokenizer(
                dataset_elem["doc_topic"], add_special_tokens=False
            ).input_ids
            topic_offset = len(topic_tokens)
            new_dataset_elem["topic_tokens"] = topic_offset
            window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]

        new_dataset_elem.update(
            dict(
                tokens=window_tokens,
                token2char_start={
                    str(i): s
                    for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
                },
                token2char_end={
                    str(i): e
                    for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
                },
                window_candidates=dataset_elem["window_candidates"],
                window_candidates_scores=dataset_elem.get("window_candidates_scores"),
            )
        )

        if "window_labels" in dataset_elem:
            window_labels = [
                (s, e, l.replace("_", " ")) for s, e, l in dataset_elem["window_labels"]
            ]

            new_dataset_elem["window_labels"] = window_labels

            if not all(
                [
                    s in new_dataset_elem["token2char_start"].values()
                    for s, _, _ in new_dataset_elem["window_labels"]
                ]
            ):
                print(
                    "Mismatching token start char mapping with labels",
                    new_dataset_elem["token2char_start"],
                    new_dataset_elem["window_labels"],
                    dataset_elem["tokens"],
                )
                continue

            if not all(
                [
                    e in new_dataset_elem["token2char_end"].values()
                    for _, e, _ in new_dataset_elem["window_labels"]
                ]
            ):
                print(
                    "Mismatching token end char mapping with labels",
                    new_dataset_elem["token2char_end"],
                    new_dataset_elem["window_labels"],
                    dataset_elem["tokens"],
                )
                continue

        yield new_dataset_elem


def preprocess_sample(
    relik_sample: RelikReaderSample,
    tokenizer,
    lowercase_policy: float,
    add_topic: bool = False,
) -> None:
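    """In-place variant of `preprocess_dataset` for a single RelikReaderSample:
    optionally lowercases a random fraction of the words (lowercase_policy),
    re-tokenizes them, rebuilds the subword-to-character offset mappings, and
    optionally prepends the document topic tokens."""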
    if len(relik_sample.tokens) == 0:
        return

    if lowercase_policy > 0:
        lc_tokens = np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy
        relik_sample.tokens = [
            t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens)
        ]

    tokenization_out = tokenizer(
        relik_sample.tokens,
        return_offsets_mapping=True,
        add_special_tokens=False,
    )

    window_tokens = tokenization_out.input_ids
    window_tokens = flatten(window_tokens)

    offsets_mapping = [
        [
            (
                ss + relik_sample.token2char_start[str(i)],
                se + relik_sample.token2char_start[str(i)],
            )
            for ss, se in tokenization_out.offset_mapping[i]
        ]
        for i in range(len(relik_sample.tokens))
    ]

    offsets_mapping = flatten(offsets_mapping)

    assert len(offsets_mapping) == len(window_tokens)

    window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]

    topic_offset = 0
    if add_topic:
        topic_tokens = tokenizer(
            relik_sample.doc_topic, add_special_tokens=False
        ).input_ids
        topic_offset = len(topic_tokens)
        relik_sample.topic_tokens = topic_offset
        window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]

    relik_sample._d.update(
        dict(
            tokens=window_tokens,
            token2char_start={
                str(i): s
                for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
            },
            token2char_end={
                str(i): e
                for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
            },
        )
    )

    if "window_labels" in relik_sample._d:
        relik_sample.window_labels = [
            (s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels
        ]


class TokenizationOutput(NamedTuple):
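    """Tensors produced for a single encoded sample, ready to be batched."""
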
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    token_type_ids: torch.Tensor
    prediction_mask: torch.Tensor
    special_symbols_mask: torch.Tensor


class RelikDataset(IterableDataset):
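    """Iterable dataset that turns RelikReaderSamples into padded tensor batches
    for the ReLiK reader: each window is encoded together with its candidate
    entities (marked by special symbols), start/end labels are built when gold
    annotations are available, and batches are formed either by a token budget
    or by a fixed batch size.

    Minimal usage sketch (illustrative only; the model name, dataset path, and
    special-symbol format below are placeholders, not values mandated by this
    module):

        special_symbols = [f"[E-{i}]" for i in range(100)]
        dataset = RelikDataset(
            dataset_path="data/train.jsonl",
            materialize_samples=False,
            transformer_model="microsoft/deberta-v3-base",
            special_symbols=special_symbols,
            for_inference=False,
        )
        for batch in dataset:
            ...  # dict of padded tensors plus bookkeeping fields
    """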

    def __init__(
        self,
        dataset_path: Optional[str],
        materialize_samples: bool,
        transformer_model: Union[str, PreTrainedTokenizer],
        special_symbols: List[str],
        shuffle_candidates: Optional[Union[bool, float]] = False,
        for_inference: bool = False,
        noise_param: float = 0.1,
        sorting_fields: Optional[str] = None,
        tokens_per_batch: int = 2048,
        batch_size: Optional[int] = None,
        max_batch_size: int = 128,
        section_size: int = 50_000,
        prebatch: bool = True,
        random_drop_gold_candidates: float = 0.0,
        use_nme: bool = True,
        max_subwords_per_candidate: int = 22,
        mask_by_instances: bool = False,
        min_length: int = 5,
        max_length: int = 2048,
        model_max_length: int = 1000,
        split_on_cand_overload: bool = True,
        skip_empty_training_samples: bool = False,
        drop_last: bool = False,
        samples: Optional[Iterator[RelikReaderSample]] = None,
        lowercase_policy: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dataset_path = dataset_path
        self.materialize_samples = materialize_samples
        self.samples: Optional[List[RelikReaderSample]] = None
        if self.materialize_samples:
            self.samples = list()

        if isinstance(transformer_model, str):
            self.tokenizer = self._build_tokenizer(transformer_model, special_symbols)
        else:
            self.tokenizer = transformer_model
        self.special_symbols = special_symbols
        self.shuffle_candidates = shuffle_candidates
        self.for_inference = for_inference
        self.noise_param = noise_param
        self.batching_fields = ["input_ids"]
        self.sorting_fields = (
            sorting_fields if sorting_fields is not None else self.batching_fields
        )

        self.tokens_per_batch = tokens_per_batch
        self.batch_size = batch_size
        self.max_batch_size = max_batch_size
        self.section_size = section_size
        self.prebatch = prebatch

        self.random_drop_gold_candidates = random_drop_gold_candidates
        self.use_nme = use_nme
        self.max_subwords_per_candidate = max_subwords_per_candidate
        self.mask_by_instances = mask_by_instances
        self.min_length = min_length
        self.max_length = max_length
        self.model_max_length = (
            model_max_length
            if model_max_length < self.tokenizer.model_max_length
            else self.tokenizer.model_max_length
        )

        self.transformer_model = (
            transformer_model
            if isinstance(transformer_model, str)
            else transformer_model.name_or_path
        )
        self.split_on_cand_overload = split_on_cand_overload
        self.skip_empty_training_samples = skip_empty_training_samples
        self.drop_last = drop_last
        self.lowercase_policy = lowercase_policy
        self.samples = samples

    def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
        return AutoTokenizer.from_pretrained(
            transformer_model,
            additional_special_tokens=[ss for ss in special_symbols],
            add_prefix_space=True,
        )

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
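        # Per-field collation functions: a callable pads/stacks the field's values
        # into a batch tensor, while None means the raw list is passed through.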
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "start_labels": lambda x: batchify(x, padding_value=-100),
            "end_labels": lambda x: batchify(x, padding_value=-100),
            "predictable_candidates_symbols": None,
            "predictable_candidates": None,
            "patch_offset": None,
            "optimus_labels": None,
        }

        if "roberta" in self.transformer_model:
            del fields_batchers["token_type_ids"]

        return fields_batchers

    def _build_input_ids(
        self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
    ) -> List[int]:
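        # Final sequence layout: [CLS] sentence [SEP] candidate_1 ... candidate_n [SEP]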
        return (
            [self.tokenizer.cls_token_id]
            + sentence_input_ids
            + [self.tokenizer.sep_token_id]
            + flatten(candidates_input_ids)
            + [self.tokenizer.sep_token_id]
        )

    def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
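        # The candidate-marker symbols are the last tokens added to the vocabulary,
        # so they occupy the highest ids; CLS (position 0) is also marked as special.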
        special_symbols_mask = input_ids >= (
            len(self.tokenizer) - len(self.special_symbols)
        )
        special_symbols_mask[0] = True
        return special_symbols_mask

    def _build_tokenizer_essentials(
        self, input_ids, original_sequence, sample
    ) -> TokenizationOutput:
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

        total_sequence_len = len(input_ids)
        predictable_sentence_len = len(original_sequence)
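
        # token_type_ids: 0 for the sentence part (CLS + sentence + SEP),
        # 1 for the candidate encodings (and final SEP) that follow.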
        token_type_ids = torch.cat(
            [
                input_ids.new_zeros(predictable_sentence_len + 2),
                input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2),
            ]
        )
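
        # prediction_mask: 1 marks positions the reader must NOT predict on
        # (CLS, everything after the sentence, and later the topic tokens);
        # 0 marks the predictable sentence subwords.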
        prediction_mask = torch.tensor(
            [1]
            + ([0] * predictable_sentence_len)
            + ([1] * (total_sequence_len - predictable_sentence_len - 1))
        )

        topic_tokens = getattr(sample, "topic_tokens", None)
        if topic_tokens is not None:
            prediction_mask[1 : 1 + topic_tokens] = 1
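
        # When masking by instances, only the subwords covered by annotated
        # instance spans remain predictable; everything else is masked out.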
        if self.mask_by_instances:
            char_start2token = {
                cs: int(tok) for tok, cs in sample.token2char_start.items()
            }
            char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()}
            instances_mask = torch.ones_like(prediction_mask)
            for _, span_info in sample.instance_id2span_data.items():
                span_info = span_info[0]
                token_start = char_start2token[span_info[0]] + 1
                token_end = char_end2token[span_info[1]] + 1
                instances_mask[token_start : token_end + 1] = 0

            prediction_mask += instances_mask
            prediction_mask[prediction_mask > 1] = 1

        assert len(prediction_mask) == len(input_ids)

        special_symbols_mask = self._get_special_symbols_mask(input_ids)

        return TokenizationOutput(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

    def _build_labels(
        self,
        sample,
        tokenization_output: TokenizationOutput,
        predictable_candidates: List[str],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
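        # Start/end label vectors over subwords: 0 means "no entity starts/ends here",
        # i + 1 points to the i-th predictable candidate; positions excluded by the
        # prediction mask are set to -100 so the loss ignores them.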
        start_labels = [0] * len(tokenization_output.input_ids)
        end_labels = [0] * len(tokenization_output.input_ids)

        char_start2token = {v: int(k) for k, v in sample.token2char_start.items()}
        char_end2token = {v: int(k) for k, v in sample.token2char_end.items()}
        for cs, ce, gold_candidate_title in sample.window_labels:
            if gold_candidate_title not in predictable_candidates:
                if self.use_nme:
                    gold_candidate_title = NME_SYMBOL
                else:
                    continue

            start_bpe = char_start2token[cs] + 1
            end_bpe = char_end2token[ce] + 1
            class_index = predictable_candidates.index(gold_candidate_title)
            if start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0:
                start_labels[start_bpe] = class_index + 1
                end_labels[end_bpe] = class_index + 1
            else:
                print(
                    "Found entity with the same last subword; it will not be included."
                )
                print(
                    cs,
                    ce,
                    gold_candidate_title,
                    start_labels,
                    end_labels,
                    sample.doc_id,
                )

        ignored_labels_indices = tokenization_output.prediction_mask == 1

        start_labels = torch.tensor(start_labels, dtype=torch.long)
        start_labels[ignored_labels_indices] = -100

        end_labels = torch.tensor(end_labels, dtype=torch.long)
        end_labels[ignored_labels_indices] = -100

        return start_labels, end_labels

    def produce_sample_bag(
        self, sample, predictable_candidates: List[str], candidates_starting_offset: int
    ) -> Optional[Tuple[dict, list, int]]:
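        """Encode `sample` against (a prefix of) `predictable_candidates` and return
        the model inputs, the candidates that did not fit into the sequence, and the
        total number of candidate symbols consumed so far."""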
        input_subwords = sample.tokens[1:-1]
        candidates_symbols = self.special_symbols[candidates_starting_offset:]

        predictable_candidates = list(predictable_candidates)
        original_predictable_candidates = list(predictable_candidates)

        if self.use_nme:
            predictable_candidates.insert(0, NME_SYMBOL)

        candidates_symbols = candidates_symbols[: len(predictable_candidates)]
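        # Each candidate is encoded as "<special symbol> <candidate title>";
        # the NME candidate is encoded as the bare NME symbol.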
        candidates_encoding_result = self.tokenizer.batch_encode_plus(
            [
                "{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL
                for cs, ct in zip(candidates_symbols, predictable_candidates)
            ],
            add_special_tokens=False,
        ).input_ids

        if (
            self.max_subwords_per_candidate is not None
            and self.max_subwords_per_candidate > 0
        ):
            candidates_encoding_result = [
                cer[: self.max_subwords_per_candidate]
                for cer in candidates_encoding_result
            ]
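
        # If the sentence plus all candidate encodings would exceed the model's
        # maximum length (keeping a 20-subword margin), keep only the leading
        # candidates that fit; the dropped ones are reported back to the caller.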
        if (
            sum(map(len, candidates_encoding_result))
            + len(input_subwords)
            + 20
            > self.model_max_length
        ):
            acceptable_tokens_from_candidates = (
                self.model_max_length - 20 - len(input_subwords)
            )
            i = 0
            cum_len = 0
            while (
                cum_len + len(candidates_encoding_result[i])
                < acceptable_tokens_from_candidates
            ):
                cum_len += len(candidates_encoding_result[i])
                i += 1

            candidates_encoding_result = candidates_encoding_result[:i]
            candidates_symbols = candidates_symbols[:i]
            predictable_candidates = predictable_candidates[:i]

        input_ids = self._build_input_ids(
            sentence_input_ids=input_subwords,
            candidates_input_ids=candidates_encoding_result,
        )

        tokenization_output = self._build_tokenizer_essentials(
            input_ids, input_subwords, sample
        )

        output_dict = {
            "input_ids": tokenization_output.input_ids,
            "attention_mask": tokenization_output.attention_mask,
            "token_type_ids": tokenization_output.token_type_ids,
            "prediction_mask": tokenization_output.prediction_mask,
            "special_symbols_mask": tokenization_output.special_symbols_mask,
            "sample": sample,
            "predictable_candidates_symbols": candidates_symbols,
            "predictable_candidates": predictable_candidates,
        }

        if sample.window_labels is not None:
            start_labels, end_labels = self._build_labels(
                sample,
                tokenization_output,
                predictable_candidates,
            )
            output_dict.update(start_labels=start_labels, end_labels=end_labels)

        if (
            "roberta" in self.transformer_model
            or "longformer" in self.transformer_model
        ):
            del output_dict["token_type_ids"]

        predictable_candidates_set = set(predictable_candidates)
        remaining_candidates = [
            candidate
            for candidate in original_predictable_candidates
            if candidate not in predictable_candidates_set
        ]
        total_used_candidates = (
            candidates_starting_offset
            + len(predictable_candidates)
            - (1 if self.use_nme else 0)
        )

        if self.use_nme:
            assert predictable_candidates[0] == NME_SYMBOL

        return output_dict, remaining_candidates, total_used_candidates

    def __iter__(self):
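        # Consume the sample iterator in sections of `section_size` elements, so that
        # prebatching and sorting only ever see one section at a time.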
        dataset_iterator = self.dataset_iterator_func()

        current_dataset_elements = []

        i = None
        for i, dataset_elem in enumerate(dataset_iterator, start=1):
            if (
                self.section_size is not None
                and len(current_dataset_elements) == self.section_size
            ):
                for batch in self.materialize_batches(current_dataset_elements):
                    yield batch
                current_dataset_elements = []

            current_dataset_elements.append(dataset_elem)

            if i % 50_000 == 0:
                logger.info(f"Processed {i} elements")

        if len(current_dataset_elements) != 0:
            for batch in self.materialize_batches(current_dataset_elements):
                yield batch

        if i is not None:
            logger.info(f"Dataset finished: {i} elements processed")
        else:
            logger.warning("Dataset empty")

    def dataset_iterator_func(self):
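        # Stream samples from disk (or from `self.samples`), preprocess each one, and
        # emit one or more "sample bags": if the candidates do not fit into a single
        # sequence, the sample is split into several patches over the candidate list
        # (unless split_on_cand_overload is disabled).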
        skipped_instances = 0
        data_samples = (
            load_relik_reader_samples(self.dataset_path)
            if self.samples is None
            else self.samples
        )
        for sample in data_samples:
            preprocess_sample(
                sample, self.tokenizer, lowercase_policy=self.lowercase_policy
            )
            current_patch = 0
            sample_bag, used_candidates = None, None
            remaining_candidates = list(sample.window_candidates)

            if not self.for_inference:
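                # With probability `random_drop_gold_candidates`, drop a random subset
                # of gold candidates: their spans are relabeled as NME and they are
                # removed from the candidate list.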
                if (
                    self.random_drop_gold_candidates > 0.0
                    and np.random.uniform() < self.random_drop_gold_candidates
                    and len(set(ct for _, _, ct in sample.window_labels)) > 1
                ):
                    np.random.shuffle(sample.window_labels)
                    n_dropped_candidates = np.random.randint(
                        0, len(sample.window_labels) - 1
                    )
                    dropped_candidates = [
                        label_elem[-1]
                        for label_elem in sample.window_labels[:n_dropped_candidates]
                    ]
                    dropped_candidates = set(dropped_candidates)

                    if NME_SYMBOL in dropped_candidates:
                        dropped_candidates.remove(NME_SYMBOL)

                    sample.window_labels = [
                        (s, e, _l)
                        if _l not in dropped_candidates
                        else (s, e, NME_SYMBOL)
                        for s, e, _l in sample.window_labels
                    ]
                    remaining_candidates = [
                        wc
                        for wc in remaining_candidates
                        if wc not in dropped_candidates
                    ]
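
                # Optionally shuffle the candidate order: always if shuffle_candidates
                # is True, or with the given probability if it is a float.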
                if (
                    isinstance(self.shuffle_candidates, bool)
                    and self.shuffle_candidates
                ) or (
                    isinstance(self.shuffle_candidates, float)
                    and np.random.uniform() < self.shuffle_candidates
                ):
                    np.random.shuffle(remaining_candidates)

            while len(remaining_candidates) != 0:
                sample_bag = self.produce_sample_bag(
                    sample,
                    predictable_candidates=remaining_candidates,
                    candidates_starting_offset=used_candidates
                    if used_candidates is not None
                    else 0,
                )
                if sample_bag is not None:
                    sample_bag, remaining_candidates, used_candidates = sample_bag
                    if (
                        self.for_inference
                        or not self.skip_empty_training_samples
                        or (
                            (
                                sample_bag.get("start_labels") is not None
                                and torch.any(sample_bag["start_labels"] > 1).item()
                            )
                            or (
                                sample_bag.get("optimus_labels") is not None
                                and len(sample_bag["optimus_labels"]) > 0
                            )
                        )
                    ):
                        sample_bag["patch_offset"] = current_patch
                        current_patch += 1
                        yield sample_bag
                    else:
                        skipped_instances += 1
                        if skipped_instances % 1000 == 0 and skipped_instances != 0:
                            logger.info(
                                f"Skipped {skipped_instances} instances since they did not have any gold labels..."
                            )

                if not self.split_on_cand_overload:
                    break

    def preshuffle_elements(self, dataset_elements: List):
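        # For training: shuffle, then sort by (noised) length so that elements of
        # similar size end up in the same batch, then shuffle the order of 64-element
        # chunks so batches are not emitted in strict length order.
        # For inference: a plain length sort.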
        if not self.for_inference:
            dataset_elements = np.random.permutation(dataset_elements)

        sorting_fn = (
            lambda elem: add_noise_to_value(
                sum(len(elem[k]) for k in self.sorting_fields),
                noise_param=self.noise_param,
            )
            if not self.for_inference
            else sum(len(elem[k]) for k in self.sorting_fields)
        )

        dataset_elements = sorted(dataset_elements, key=sorting_fn)

        if self.for_inference:
            return dataset_elements

        ds = list(chunks(dataset_elements, 64))
        np.random.shuffle(ds)
        return flatten(ds)

    def materialize_batches(
        self, dataset_elements: List[Dict[str, Any]]
    ) -> Generator[Dict[str, Any], None, None]:
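        # Group the (optionally preshuffled) elements into batches, discarding
        # elements outside [min_length, max_length], and collate each batch with
        # `fields_batcher`.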
        if self.prebatch:
            dataset_elements = self.preshuffle_elements(dataset_elements)

        current_batch = []

        def output_batch() -> Dict[str, Any]:
            assert (
                len(
                    set([len(elem["predictable_candidates"]) for elem in current_batch])
                )
                == 1
            ), " ".join(
                map(
                    str, [len(elem["predictable_candidates"]) for elem in current_batch]
                )
            )

            batch_dict = dict()

            de_values_by_field = {
                fn: [de[fn] for de in current_batch if fn in de]
                for fn in self.fields_batcher
            }

            de_values_by_field = {
                fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
            }

            assert len(set([len(v) for v in de_values_by_field.values()]))

            de_values_by_field = {
                fn: fvs
                for fn, fvs in de_values_by_field.items()
                if all([fv is not None for fv in fvs])
            }

            for field_name, field_values in de_values_by_field.items():
                field_batch = (
                    self.fields_batcher[field_name](field_values)
                    if self.fields_batcher[field_name] is not None
                    else field_values
                )

                batch_dict[field_name] = field_batch

            return batch_dict

        max_len_discards, min_len_discards = 0, 0
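
        # If no fixed batch_size was given, batch by a token budget (tokens_per_batch),
        # capping batches at max_batch_size elements when it is not -1; otherwise
        # batch by element count.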
        should_token_batch = self.batch_size is None

        curr_pred_elements = -1
        for de in dataset_elements:
            if (
                should_token_batch
                and self.max_batch_size != -1
                and len(current_batch) == self.max_batch_size
            ) or (not should_token_batch and len(current_batch) == self.batch_size):
                yield output_batch()
                current_batch = []
                curr_pred_elements = -1

            too_long_fields = [
                k
                for k in de
                if self.max_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) > self.max_length
            ]
            if len(too_long_fields) > 0:
                max_len_discards += 1
                continue

            too_short_fields = [
                k
                for k in de
                if self.min_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) < self.min_length
            ]
            if len(too_short_fields) > 0:
                min_len_discards += 1
                continue

            if should_token_batch:
                de_len = sum(len(de[k]) for k in self.batching_fields)

                future_max_len = max(
                    de_len,
                    max(
                        [
                            sum(len(bde[k]) for k in self.batching_fields)
                            for bde in current_batch
                        ],
                        default=0,
                    ),
                )

                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

                num_predictable_candidates = len(de["predictable_candidates"])

                if len(current_batch) > 0 and (
                    future_tokens_per_batch >= self.tokens_per_batch
                    or (
                        num_predictable_candidates != curr_pred_elements
                        and curr_pred_elements != -1
                    )
                ):
                    yield output_batch()
                    current_batch = []

            current_batch.append(de)
            curr_pred_elements = len(de["predictable_candidates"])

        if len(current_batch) != 0 and not self.drop_last:
            yield output_batch()

        if max_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were "
                    f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
                    f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the "
                    f"sample length exceeds the maximum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {max_len_discards} elements were "
                    f"discarded since longer than max length {self.max_length}"
                )

        if min_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were "
                    f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
                    f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the "
                    f"sample length is shorter than the minimum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {min_len_discards} elements were "
                    f"discarded since shorter than min length {self.min_length}"
                )

    @staticmethod
    def convert_tokens_to_char_annotations(
        sample: RelikReaderSample,
        remove_nmes: bool = True,
    ) -> RelikReaderSample:
        """
        Converts the token annotations to char annotations.

        Args:
            sample (:obj:`RelikReaderSample`):
                The sample to convert.
            remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to remove the NMEs from the annotations.

        Returns:
            :obj:`RelikReaderSample`: The converted sample.
        """
        char_annotations = set()
        for (
            predicted_entity,
            predicted_spans,
        ) in sample.predicted_window_labels.items():
            if predicted_entity == NME_SYMBOL and remove_nmes:
                continue

            for span_start, span_end in predicted_spans:
                span_start = sample.token2char_start[str(span_start)]
                span_end = sample.token2char_end[str(span_end)]

                char_annotations.add((span_start, span_end, predicted_entity))

        char_probs_annotations = dict()
        for (
            span_start,
            span_end,
        ), candidates_probs in sample.span_title_probabilities.items():
            span_start = sample.token2char_start[str(span_start)]
            span_end = sample.token2char_end[str(span_end)]
            char_probs_annotations[(span_start, span_end)] = {
                title for title, _ in candidates_probs
            }

        sample.predicted_window_labels_chars = char_annotations
        sample.probs_window_labels_chars = char_probs_annotations

        return sample

    @staticmethod
    def merge_patches_predictions(sample) -> None:
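        # Merge per-patch predictions (one patch per candidate split) back into a
        # single prediction for the sample: for each span, the first non-NME title
        # across patches wins; span-title probabilities keep the first occurrence.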
        sample._d["predicted_window_labels"] = dict()
        predicted_window_labels = sample._d["predicted_window_labels"]

        sample._d["span_title_probabilities"] = dict()
        span_title_probabilities = sample._d["span_title_probabilities"]

        span2title = dict()
        for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
            for predicted_title, predicted_spans in patch_info[
                "predicted_window_labels"
            ].items():
                for pred_span in predicted_spans:
                    pred_span = tuple(pred_span)
                    curr_title = span2title.get(pred_span)
                    if curr_title is None or curr_title == NME_SYMBOL:
                        span2title[pred_span] = predicted_title

            for predicted_span, titles_probabilities in patch_info[
                "span_title_probabilities"
            ].items():
                if predicted_span not in span_title_probabilities:
                    span_title_probabilities[predicted_span] = titles_probabilities

        for span, title in span2title.items():
            if title not in predicted_window_labels:
                predicted_window_labels[title] = list()
            predicted_window_labels[title].append(span)