# Web file-view header captured together with this source (not code):
# CarlosMalaga's picture | Upload 201 files | 2f044c1 verified | raw | history blame | 47.7 kB
import logging
from typing import (
Any,
Callable,
Dict,
Generator,
Iterator,
List,
NamedTuple,
Optional,
Tuple,
Union,
)
import numpy as np
import torch
import tqdm
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer, PreTrainedTokenizer
from relik.reader.data.relik_reader_data_utils import (
add_noise_to_value,
batchify,
batchify_matrices,
batchify_tensor,
chunks,
flatten,
)
from relik.reader.data.relik_reader_sample import (
RelikReaderSample,
load_relik_reader_samples,
)
from relik.reader.utils.special_symbols import NME_SYMBOL
# Module-level logger: the dataset reports progress and skipped/discarded samples through it.
logger: logging.Logger = logging.getLogger(__name__)
# Tensors built for one sample by ``RelikREDataset._build_tokenizer_essentials``:
#   input_ids / attention_mask      - token ids and an all-ones attention mask
#   token_type_ids                  - 0 before the first special symbol, 1 from it onwards
#   prediction_mask                 - 1 where no label is predicted
#   special_symbols_mask(_entities) - positions of the added special symbols
TokenizationOutput = NamedTuple(
    "TokenizationOutput",
    [
        ("input_ids", torch.Tensor),
        ("attention_mask", torch.Tensor),
        ("token_type_ids", torch.Tensor),
        ("prediction_mask", torch.Tensor),
        ("special_symbols_mask", torch.Tensor),
        ("special_symbols_mask_entities", torch.Tensor),
    ],
)
class RelikREDataset(IterableDataset):
    """Iterable dataset producing batches for the ReLiK relation-extraction reader.

    For every sample, the words are concatenated with the (optional) entity-type
    candidates and with the relation (triplet) candidates, each candidate prefixed by
    its special symbol. The sequence is tokenized and shortened to ``model_max_length``
    by dropping candidates. It is labelled when gold entities are present, and the
    resulting elements are grouped into batches by :meth:`materialize_batches`.
    """

    def __init__(
        self,
        dataset_path: str,
        materialize_samples: bool,
        transformer_model: Union[str, PreTrainedTokenizer],
        special_symbols: List[str],
        shuffle_candidates: Optional[Union[bool, float]] = False,
        flip_candidates: Optional[Union[bool, float]] = False,
        for_inference: bool = False,
        special_symbols_types: Optional[List[str]] = None,
        noise_param: float = 0.1,
        sorting_fields: Optional[Union[str, List[str]]] = None,
        tokens_per_batch: int = 2048,
        batch_size: Optional[int] = None,
        max_batch_size: int = 128,
        section_size: int = 500_000,
        prebatch: bool = True,
        add_gold_candidates: bool = True,
        use_nme: bool = False,
        min_length: int = -1,
        max_length: int = 2048,
        max_triplets: int = 50,
        max_spans: int = 100,
        model_max_length: int = 2048,
        skip_empty_training_samples: bool = True,
        drop_last: bool = False,
        samples: Optional[Iterator[RelikReaderSample]] = None,
        **kwargs,
    ):
        """
        Args:
            dataset_path: file read with ``load_relik_reader_samples`` when no
                ``samples`` are provided.
            materialize_samples: cache each processed element on its sample
                (``sample.materialize``) so later iterations skip tokenization.
            transformer_model: model name (a tokenizer registering all special symbols
                is built from it) or an already-built tokenizer.
            special_symbols: symbols prefixed to relation (triplet) candidates.
            shuffle_candidates: ``True`` or a probability of shuffling candidates
                (training only).
            flip_candidates: ``True`` or a probability of randomly swapping adjacent
                candidates (training only; used when shuffling is disabled).
            for_inference: disables training-time filtering/augmentation and the
                random parts of batching.
            special_symbols_types: symbols prefixed to entity-type (span) candidates;
                empty/None disables typed entities.
            noise_param: noise given to ``add_noise_to_value`` when sorting for batching.
            sorting_fields: field name(s) whose lengths drive the pre-batching sort;
                defaults to ``["input_ids"]``. A single string is accepted.
            tokens_per_batch: token budget of a batch when ``batch_size`` is None.
            batch_size: fixed number of elements per batch; None enables token batching.
            max_batch_size: cap on elements per token-based batch (-1 disables it).
            section_size: number of elements buffered before they are batched.
            prebatch: sort and shuffle each section before batching.
            add_gold_candidates: at training time, make sure every gold relation is
                among the candidates.
            use_nme: prepend ``NME_SYMBOL`` to the span candidates (typed mode only).
            min_length / max_length: elements having a tensor field shorter / longer
                than these are discarded (-1 disables the check).
            max_triplets / max_spans: caps on relation / entity-type candidates.
            model_max_length: maximum input length, capped by the tokenizer's limit.
            skip_empty_training_samples: at training time drop samples without triplets.
            drop_last: drop the last incomplete batch of each section.
            samples: optional in-memory samples used instead of ``dataset_path``.
            **kwargs: forwarded to ``IterableDataset``.
        """
        super().__init__(**kwargs)

        # avoid a shared mutable default argument
        if special_symbols_types is None:
            special_symbols_types = []
        # iterating a bare string would yield characters instead of field names
        if isinstance(sorting_fields, str):
            sorting_fields = [sorting_fields]

        self.dataset_path = dataset_path
        self.materialize_samples = materialize_samples
        self.samples: Optional[List[RelikReaderSample]] = samples
        if self.materialize_samples and self.samples is None:
            self.samples = list()

        if isinstance(transformer_model, str):
            self.tokenizer = self._build_tokenizer(
                transformer_model, special_symbols + special_symbols_types
            )
        else:
            self.tokenizer = transformer_model

        self.special_symbols = special_symbols
        self.special_symbols_types = special_symbols_types
        self.shuffle_candidates = shuffle_candidates
        self.flip_candidates = flip_candidates
        self.for_inference = for_inference
        self.noise_param = noise_param
        self.batching_fields = ["input_ids"]
        self.sorting_fields = (
            sorting_fields if sorting_fields is not None else self.batching_fields
        )
        self.add_gold_candidates = add_gold_candidates
        self.use_nme = use_nme
        self.min_length = min_length
        self.max_length = max_length
        self.model_max_length = min(model_max_length, self.tokenizer.model_max_length)
        self.transformer_model = transformer_model
        self.skip_empty_training_samples = skip_empty_training_samples
        self.drop_last = drop_last
        self.tokens_per_batch = tokens_per_batch
        self.batch_size = batch_size
        self.max_batch_size = max_batch_size
        self.max_triplets = max_triplets
        self.max_spans = max_spans
        self.section_size = section_size
        self.prebatch = prebatch

    def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
        """Load the tokenizer of ``transformer_model`` registering ``special_symbols``."""
        return AutoTokenizer.from_pretrained(
            transformer_model,
            additional_special_tokens=list(special_symbols),
            add_prefix_space=True,
        )

    @staticmethod
    def get_special_symbols_re(num_entities: int, use_nme: bool = False) -> List[str]:
        """Return ``num_entities`` relation symbols ``[R-i]``, optionally led by NME."""
        relation_symbols = [f"[R-{i}]" for i in range(num_entities)]
        return [NME_SYMBOL] + relation_symbols if use_nme else relation_symbols

    @staticmethod
    def get_special_symbols(num_entities: int) -> List[str]:
        """Return ``NME_SYMBOL`` followed by ``num_entities`` entity symbols ``[E-i]``."""
        return [NME_SYMBOL] + [f"[E-{i}]" for i in range(num_entities)]

    def _is_roberta_model(self) -> bool:
        """Tell whether the configured model/tokenizer name refers to a RoBERTa model."""
        if isinstance(self.transformer_model, str):
            model_name = self.transformer_model
        else:
            # Tokenizer objects expose ``name_or_path``; they have no ``config``
            # attribute, which the previous ``.config.model_type`` lookup relied on.
            model_name = getattr(self.transformer_model, "name_or_path", None) or ""
        return "roberta" in model_name

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        """Collation function for every element field (None keeps a plain list).

        A new dict is built on every access.
        """
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
            "start_labels": lambda x: batchify(x, padding_value=-100),
            "end_labels": lambda x: batchify_matrices(x, padding_value=-100),
            "disambiguation_labels": lambda x: batchify(x, padding_value=-100),
            "relation_labels": lambda x: batchify_tensor(x, padding_value=-100),
            "predictable_candidates": None,
        }
        # token_type_ids are not fed to RoBERTa models (presumably because they do not
        # accept non-zero token types)
        if self._is_roberta_model():
            del fields_batchers["token_type_ids"]
        return fields_batchers

    def _build_input_ids(
        self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
    ) -> List[int]:
        """Return ``[CLS] sentence [SEP] candidates [SEP]`` as token ids."""
        return (
            [self.tokenizer.cls_token_id]
            + sentence_input_ids
            + [self.tokenizer.sep_token_id]
            + flatten(candidates_input_ids)
            + [self.tokenizer.sep_token_id]
        )

    def _build_input(self, text: List[str], candidates: List[List[str]]) -> List[str]:
        """Return the pre-tokenization pieces ``text [SEP] candidates [SEP]``."""
        return (
            text
            + [self.tokenizer.sep_token]
            + flatten(candidates)
            + [self.tokenizer.sep_token]
        )

    def _build_tokenizer_essentials(
        self, input_ids, original_sequence, ents=0
    ) -> TokenizationOutput:
        """Build the tensors accompanying ``input_ids``.

        Special symbols are detected as ids ``>= tokenizer.vocab_size`` (the added
        tokens, presumably appended after the base vocabulary). ``token_type_ids`` is 0
        before the first special symbol and 1 from it on. ``prediction_mask`` is
        ``token_type_ids`` shifted left by one, with the first and last positions forced
        to 1; ``_build_labels`` ignores positions where it is 1.

        Args:
            input_ids: full sequence of token ids.
            original_sequence: unused; kept for interface compatibility.
            ents: number of entity-type symbols. With entity types enabled, the first
                ``ents`` special symbols go to ``special_symbols_mask_entities`` and are
                removed from ``special_symbols_mask``.
        """
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        special_symbols_mask = input_ids >= self.tokenizer.vocab_size
        special_symbols_mask_entities = special_symbols_mask.clone()
        token_type_ids = (torch.cumsum(special_symbols_mask, dim=0) > 0).long()
        if len(self.special_symbols_types) > 0:
            # keep only the first ``ents`` special symbols as entity-type symbols
            special_symbols_mask_entities[
                special_symbols_mask_entities.cumsum(0) > ents
            ] = False
            special_symbols_mask = special_symbols_mask ^ special_symbols_mask_entities
        prediction_mask = token_type_ids.roll(shifts=-1, dims=0)
        prediction_mask[-1] = 1
        prediction_mask[0] = 1
        assert len(prediction_mask) == len(input_ids)
        return TokenizationOutput(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
        )

    @staticmethod
    def _subindex(lst, target_values, dims):
        """Index of the first item of ``lst`` equal to ``target_values`` on all ``dims``.

        Returns None when no item matches.
        """
        for i, sublist in enumerate(lst):
            if all(sublist[dim] == target_values[dim] for dim in dims):
                return i
        return None

    def _build_labels(
        self,
        sample,
        tokenization_output: TokenizationOutput,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create the training targets of ``sample``.

        Side effect: ``sample.entities`` is sorted by (start, end) and rewritten as
        ``(start, end, type)`` tuples (type ``""`` when entity types are disabled).

        Returns:
            start_labels: per token, ``k + 1`` at the first sub-token of the k-th
                entity (0 elsewhere), -100 where ``prediction_mask`` is 1.
            end_labels: one row per distinct start token, holding the same class ids at
                the entity ends; -100 before the start and where ``prediction_mask`` is 1.
            disambiguation_labels: multi-hot (untyped entities x candidates) matrix of
                the entity types (typed mode) and of the relations of every entity.
            relation_labels: multi-hot (entities x entities x relation candidates)
                tensor, 1 at ``[subject, object, relation]``.
        """
        start_labels = [0] * len(tokenization_output.input_ids)
        end_labels = []
        end_labels_tensor = [0] * len(tokenization_output.input_ids)

        sample.entities.sort(key=lambda x: (x[0], x[1]))
        prev_start_bpe = -1
        entities_untyped = list(set([(ce[0], ce[1]) for ce in sample.entities]))
        entities_untyped.sort(key=lambda x: (x[0], x[1]))
        typed = len(self.special_symbols_types) > 0
        if typed:
            sample.entities = [(ce[0], ce[1], ce[2]) for ce in sample.entities]
            disambiguation_labels = torch.zeros(
                len(entities_untyped),
                len(sample.span_candidates) + len(sample.triplet_candidates),
            )
        else:
            sample.entities = [(ce[0], ce[1], "") for ce in sample.entities]
            disambiguation_labels = torch.zeros(
                len(entities_untyped), len(sample.triplet_candidates)
            )

        ignored_labels_indices = tokenization_output.prediction_mask == 1
        # number of skipped duplicate entities, used to realign rows of
        # disambiguation_labels
        offset = 0
        for idx, c_ent in enumerate(sample.entities):
            # move the start forward past words that produced no sub-tokens
            while len(sample.word2token[c_ent[0]]) == 0:
                c_ent = (c_ent[0] + 1, c_ent[1], c_ent[2])
                if len(sample.word2token) == c_ent[0]:
                    c_ent = None
                    break
            if c_ent is None:
                continue
            # same for the (exclusive) end
            while len(sample.word2token[c_ent[1] - 1]) == 0:
                c_ent = (c_ent[0], c_ent[1] + 1, c_ent[2])
                if len(sample.word2token) == c_ent[1]:
                    c_ent = None
                    break
            if c_ent is None:
                continue

            # +1 accounts for the CLS token prepended by _build_input_ids
            start_bpe = sample.word2token[c_ent[0]][0] + 1
            end_bpe = sample.word2token[c_ent[1] - 1][-1] + 1
            class_index = idx
            start_labels[start_bpe] = class_index + 1  # +1 for the NONE class
            if start_bpe != prev_start_bpe:
                end_labels.append(end_labels_tensor.copy())
                end_labels[-1][:start_bpe] = [-100] * start_bpe
                end_labels[-1][end_bpe] = class_index + 1
            elif end_labels[-1][end_bpe] == 0:
                end_labels[-1][end_bpe] = class_index + 1
            else:
                # same span already labelled (e.g. with another type)
                offset += 1
                prev_start_bpe = start_bpe
                continue
            if typed:
                if c_ent[2] in sample.span_candidates:
                    entity_type_idx = sample.span_candidates.index(c_ent[2])
                else:
                    entity_type_idx = 0
                disambiguation_labels[idx - offset, entity_type_idx] = 1
            prev_start_bpe = start_bpe

        start_labels = torch.tensor(start_labels, dtype=torch.long)
        start_labels[ignored_labels_indices] = -100
        end_labels = torch.tensor(end_labels, dtype=torch.long)
        end_labels[ignored_labels_indices.repeat(len(end_labels), 1)] = -100

        relation_labels = torch.zeros(
            len(entities_untyped), len(entities_untyped), len(sample.triplet_candidates)
        )
        for re in sample.triplets:
            if re["relation"]["name"] not in sample.triplet_candidates:
                re_class_index = len(sample.triplet_candidates) - 1
            else:
                re_class_index = sample.triplet_candidates.index(re["relation"]["name"])

            subject_class_index = self._subindex(
                entities_untyped, (re["subject"]["start"], re["subject"]["end"]), (0, 1)
            )
            object_class_index = self._subindex(
                entities_untyped, (re["object"]["start"], re["object"]["end"]), (0, 1)
            )
            if subject_class_index is None or object_class_index is None:
                # Indexing a tensor with None inserts a new axis, which would silently
                # write 1s into an unrelated slice; skip the inconsistent triplet.
                logger.warning(
                    "Sample {} has a triplet whose subject or object is not among its "
                    "entities, skipping it".format(sample.id)
                )
                continue

            relation_labels[subject_class_index, object_class_index, re_class_index] = 1
            relation_column = re_class_index + (
                len(sample.span_candidates) if typed else 0
            )
            disambiguation_labels[subject_class_index, relation_column] = 1
            disambiguation_labels[object_class_index, relation_column] = 1
        return start_labels, end_labels, disambiguation_labels, relation_labels

    def __iter__(self):
        """Yield batches, processing the dataset in sections of ``section_size`` elements."""
        dataset_iterator = self.dataset_iterator_func()
        current_dataset_elements = []
        i = None
        for i, dataset_elem in enumerate(dataset_iterator, start=1):
            if (
                self.section_size is not None
                and len(current_dataset_elements) == self.section_size
            ):
                yield from self.materialize_batches(current_dataset_elements)
                current_dataset_elements = []
            current_dataset_elements.append(dataset_elem)
            if i % 50_000 == 0:
                logger.info(f"Processed: {i} number of elements")
        if len(current_dataset_elements) != 0:
            yield from self.materialize_batches(current_dataset_elements)
        if i is not None:
            logger.debug(f"Dataset finished: {i} number of elements processed")
        else:
            logger.warning("Dataset empty")

    @staticmethod
    def _make_element(
        tokenization_output: TokenizationOutput,
        sample,
        start_labels,
        end_labels,
        disambiguation_labels,
        relation_labels,
        predictable_candidates,
    ) -> Dict[str, Any]:
        """Pack one processed sample into the dict consumed by ``materialize_batches``."""
        return {
            "input_ids": tokenization_output.input_ids,
            "attention_mask": tokenization_output.attention_mask,
            "token_type_ids": tokenization_output.token_type_ids,
            "prediction_mask": tokenization_output.prediction_mask,
            "special_symbols_mask": tokenization_output.special_symbols_mask,
            "special_symbols_mask_entities": tokenization_output.special_symbols_mask_entities,
            "sample": sample,
            "start_labels": start_labels,
            "end_labels": end_labels,
            "disambiguation_labels": disambiguation_labels,
            "relation_labels": relation_labels,
            "predictable_candidates": predictable_candidates,
        }

    @staticmethod
    def _add_gold_triplet_candidates(sample, max_candidates: int) -> None:
        """Make every gold relation of ``sample`` one of its triplet candidates (in place).

        Missing gold relations first overwrite non-gold, non-NME candidates, starting
        from the end of the list so the order of the others is preserved. Any relations
        still missing are appended while the list holds fewer than ``max_candidates``.
        """
        gold_names = [triplet["relation"]["name"] for triplet in sample.triplets]
        present = set(sample.triplet_candidates)
        missing = set()
        for name in gold_names:
            if name not in present:
                missing.add(name)
        if not missing:
            return
        missing = list(missing)
        gold_names_set = set(gold_names)
        added = 0
        for position in reversed(range(len(sample.triplet_candidates))):
            current = sample.triplet_candidates[position]
            if current not in gold_names_set and current != NME_SYMBOL:
                sample.triplet_candidates[position] = missing[added]
                added += 1
                if added == len(missing):
                    break
        # ``<`` (not ``<=``): with ``<=`` one candidate more than the available symbols
        # could be appended, and it would then never be encoded in the input.
        while added < len(missing) and len(sample.triplet_candidates) < max_candidates:
            sample.triplet_candidates.append(missing[added])
            added += 1

    @staticmethod
    def _shuffle_candidates_list(shuffle_candidates, candidates):
        """Shuffle ``candidates`` in place if enabled (bool) or with probability (float).

        After shuffling, ``NME_SYMBOL`` (if present) is moved back to the front.
        """
        if (isinstance(shuffle_candidates, bool) and shuffle_candidates) or (
            isinstance(shuffle_candidates, float)
            and np.random.uniform() < shuffle_candidates
        ):
            np.random.shuffle(candidates)
            if NME_SYMBOL in candidates:
                candidates.remove(NME_SYMBOL)
                candidates.insert(0, NME_SYMBOL)
        return candidates

    @staticmethod
    def _flip_candidates_list(flip_candidates, candidates):
        """Swap adjacent candidates with probability 0.5 each, if flipping is enabled.

        ``flip_candidates`` is a bool switch or a float probability; after flipping,
        ``NME_SYMBOL`` (if present) is moved back to the front.
        """
        if (isinstance(flip_candidates, bool) and flip_candidates) or (
            isinstance(flip_candidates, float)
            and np.random.uniform() < flip_candidates
        ):
            for i in range(len(candidates) - 1):
                if np.random.uniform() < 0.5:
                    candidates[i], candidates[i + 1] = candidates[i + 1], candidates[i]
            if NME_SYMBOL in candidates:
                candidates.remove(NME_SYMBOL)
                candidates.insert(0, NME_SYMBOL)
        return candidates

    def dataset_iterator_func(self):
        """Yield one model-ready element (see ``_make_element``) per usable sample.

        Samples come from ``self.samples`` when it is non-empty, otherwise from
        ``dataset_path``. Samples may be mutated (candidates, token maps, labels). With
        ``materialize_samples``, each element is cached on its sample and
        ``self.samples`` is replaced by the processed samples once iteration ends.
        """
        # budget kept free for special tokens when fitting candidates to the model
        reserved_tokens = 20
        data_samples = (
            load_relik_reader_samples(self.dataset_path)
            if self.samples is None
            or (isinstance(self.samples, list) and len(self.samples) == 0)
            else self.samples
        )
        if self.materialize_samples:
            data_acc = []

        for sample in data_samples:
            if self.materialize_samples and sample.materialize is not None:
                # reuse the cached element; the cache is detached while yielding and
                # re-attached when the generator resumes
                materialized = sample.materialize
                del sample.materialize
                yield self._make_element(
                    materialized["tokenization_output"],
                    sample,
                    materialized["start_labels"],
                    materialized["end_labels"],
                    materialized["disambiguation_labels"],
                    materialized["relation_labels"],
                    materialized["candidates_symbols"],
                )
                sample.materialize = materialized
                data_acc.append(sample)
                continue

            candidates_symbols = self.special_symbols
            candidates_entities_symbols = self.special_symbols_types
            typed = len(self.special_symbols_types) > 0

            if typed:
                assert sample.span_candidates is not None
                if self.use_nme:
                    # add NME as a possible entity type
                    sample.span_candidates.insert(0, NME_SYMBOL)

            max_triplet_candidates = min(len(candidates_symbols), self.max_triplets)
            sample.triplet_candidates = sample.triplet_candidates[:max_triplet_candidates]
            if typed:
                sample.span_candidates = sample.span_candidates[
                    : min(len(candidates_entities_symbols), self.max_spans)
                ]

            # training time sample mods
            if not self.for_inference:
                if (
                    sample.triplets is None or len(sample.triplets) == 0
                ) and self.skip_empty_training_samples:
                    logger.warning(
                        "Sample {} has no labels, skipping".format(sample.id)
                    )
                    continue
                if self.add_gold_candidates and sample.triplets:
                    self._add_gold_triplet_candidates(sample, max_triplet_candidates)
                if self.shuffle_candidates:
                    sample.triplet_candidates = self._shuffle_candidates_list(
                        self.shuffle_candidates, sample.triplet_candidates
                    )
                    if typed:
                        sample.span_candidates = self._shuffle_candidates_list(
                            self.shuffle_candidates, sample.span_candidates
                        )
                elif self.flip_candidates:
                    sample.triplet_candidates = self._flip_candidates_list(
                        self.flip_candidates, sample.triplet_candidates
                    )
                    if typed:
                        sample.span_candidates = self._flip_candidates_list(
                            self.flip_candidates, sample.span_candidates
                        )

            # candidates encoding: "<symbol> <candidate>" (NME is encoded alone)
            candidates_symbols = candidates_symbols[: len(sample.triplet_candidates)]
            candidates_encoding = [
                ["{} {}".format(cs, ct)] if ct != NME_SYMBOL else [NME_SYMBOL]
                for cs, ct in zip(candidates_symbols, sample.triplet_candidates)
            ]
            if typed:
                candidates_entities_symbols = candidates_entities_symbols[
                    : len(sample.span_candidates)
                ]
                candidates_types_encoding = [
                    ["{} {}".format(cs, ct)] if ct != NME_SYMBOL else [NME_SYMBOL]
                    for cs, ct in zip(
                        candidates_entities_symbols, sample.span_candidates
                    )
                ]
                candidates_encoding = (
                    candidates_types_encoding
                    + [[self.tokenizer.sep_token]]
                    + candidates_encoding
                )

            pretoken_input = self._build_input(sample.words, candidates_encoding)
            input_tokenized = self.tokenizer(
                pretoken_input,
                return_offsets_mapping=True,
                add_special_tokens=False,
            )
            num_words = len(sample.words)
            window_tokens = flatten(input_tokenized.input_ids)

            # Every word is tokenized on its own, so its offsets are word-relative;
            # adding the word's start character (sample.token2char_start, read before
            # it is overwritten below) makes them sample-relative.
            offsets_mapping = flatten(
                [
                    [
                        (
                            ss + sample.token2char_start[str(i)],
                            se + sample.token2char_start[str(i)],
                        )
                        for ss, se in input_tokenized.offset_mapping[i]
                    ]
                    for i in range(num_words)
                ]
            )
            char2token_start = sample._d["char2token_start"]
            char2token_end = sample._d["char2token_end"]
            token2char_start = {str(i): s for i, (s, _) in enumerate(offsets_mapping)}
            token2char_end = {str(i): e for i, (_, e) in enumerate(offsets_mapping)}
            token2word_start = {
                str(i): int(char2token_start[str(s)])
                for i, (s, _) in enumerate(offsets_mapping)
                if str(s) in char2token_start
            }
            token2word_end = {
                str(i): int(char2token_end[str(e)])
                for i, (_, e) in enumerate(offsets_mapping)
                if str(e) in char2token_end
            }
            # invert token2word_start and token2word_end
            word2token_start = {str(v): int(k) for k, v in token2word_start.items()}
            word2token_end = {str(v): int(k) for k, v in token2word_end.items()}
            sample._d.update(
                dict(
                    tokens=window_tokens,
                    token2char_start=token2char_start,
                    token2char_end=token2char_end,
                    token2word_start=token2word_start,
                    token2word_end=token2word_end,
                    word2token_start=word2token_start,
                    word2token_end=word2token_end,
                )
            )

            input_subwords = flatten(input_tokenized["input_ids"][:num_words])
            token2word = []
            word2token = {}
            count = 0
            for word_index, word_offsets in enumerate(
                input_tokenized["offset_mapping"][:num_words]
            ):
                word2token[word_index] = []
                for _ in word_offsets:
                    token2word.append(word_index)
                    word2token[word_index].append(count)
                    count += 1
            sample.token2word = token2word
            sample.word2token = word2token

            # One entry per encoded candidate. In typed mode the entries are:
            # entity types, the [SEP] separator (index len(span_candidates)), relations.
            candidates_encoding_result = input_tokenized["input_ids"][
                num_words + 1 : -1
            ]
            relations_offset = len(sample.span_candidates) + 1 if typed else 0

            # drop candidates if the number of input tokens is too long for the model
            if (
                sum(map(len, candidates_encoding_result))
                + len(input_subwords)
                + reserved_tokens
                > self.model_max_length
            ):
                if self.for_inference:
                    # keep the longest prefix of candidates that fits
                    acceptable_tokens_from_candidates = (
                        self.model_max_length - reserved_tokens - len(input_subwords)
                    )
                    i = 0
                    cum_len = 0
                    while (
                        cum_len + len(candidates_encoding_result[i])
                        < acceptable_tokens_from_candidates
                    ):
                        cum_len += len(candidates_encoding_result[i])
                        i += 1
                    assert i > 0
                    candidates_encoding_result = candidates_encoding_result[:i]
                    kept_relations = max(i - relations_offset, 0)
                    candidates_symbols = candidates_symbols[:kept_relations]
                    sample.triplet_candidates = sample.triplet_candidates[
                        :kept_relations
                    ]
                else:
                    # always keep gold relations (and, in typed mode, every entity
                    # type plus the separator); fill the rest with a prefix
                    gold_candidates_set = set(
                        wl["relation"]["name"] for wl in (sample.triplets or [])
                    )
                    gold_candidates_indices = [
                        idx + relations_offset
                        for idx, wc in enumerate(sample.triplet_candidates)
                        if wc in gold_candidates_set
                    ]
                    if typed:
                        gold_candidates_indices = gold_candidates_indices + list(
                            range(relations_offset)
                        )
                    gold_indices_set = set(gold_candidates_indices)
                    necessary_taken_tokens = sum(
                        len(candidates_encoding_result[idx])
                        for idx in gold_candidates_indices
                    )
                    acceptable_tokens_from_candidates = (
                        self.model_max_length
                        - reserved_tokens
                        - len(input_subwords)
                        - necessary_taken_tokens
                    )
                    if acceptable_tokens_from_candidates <= 0:
                        logger.warning(
                            "Sample {} has no candidates after truncation due to max length".format(
                                sample.id
                            )
                        )
                        continue
                    i = 0
                    cum_len = 0
                    while (
                        cum_len + len(candidates_encoding_result[i])
                        < acceptable_tokens_from_candidates
                    ):
                        if i not in gold_indices_set:
                            cum_len += len(candidates_encoding_result[i])
                        i += 1
                    new_indices = sorted(set(range(i)) | gold_indices_set)
                    candidates_encoding_result = [
                        candidates_encoding_result[idx] for idx in new_indices
                    ]
                    # keep symbols and candidates of exactly the relations left in
                    # the input
                    kept_relation_indices = [
                        idx - relations_offset
                        for idx in new_indices
                        if idx >= relations_offset
                    ]
                    candidates_symbols = [
                        candidates_symbols[idx] for idx in kept_relation_indices
                    ]
                    sample.triplet_candidates = [
                        sample.triplet_candidates[idx] for idx in kept_relation_indices
                    ]
                # NOTE(review): the original indentation was lost; this check is
                # placed after truncation, as its message suggests.
                if len(sample.triplet_candidates) == 0:
                    logger.warning(
                        "Sample {} has no candidates after truncation due to max length".format(
                            sample.id
                        )
                    )
                    continue

            # final input_ids build
            input_ids = self._build_input_ids(
                sentence_input_ids=input_subwords,
                candidates_input_ids=candidates_encoding_result,
            )
            # complete input building (e.g. attention / prediction mask)
            tokenization_output = self._build_tokenizer_essentials(
                input_ids,
                input_subwords,
                min(len(sample.span_candidates), len(self.special_symbols_types))
                if sample.span_candidates is not None
                else 0,
            )

            # labels creation (only when gold entities are available)
            start_labels = end_labels = disambiguation_labels = relation_labels = None
            if sample.entities is not None and len(sample.entities) > 0:
                (
                    start_labels,
                    end_labels,
                    disambiguation_labels,
                    relation_labels,
                ) = self._build_labels(sample, tokenization_output)

            if self.materialize_samples:
                sample.materialize = {
                    "tokenization_output": tokenization_output,
                    "start_labels": start_labels,
                    "end_labels": end_labels,
                    "disambiguation_labels": disambiguation_labels,
                    "relation_labels": relation_labels,
                    "candidates_symbols": candidates_symbols,
                }
                data_acc.append(sample)

            yield self._make_element(
                tokenization_output,
                sample,
                start_labels,
                end_labels,
                disambiguation_labels,
                relation_labels,
                candidates_symbols,
            )

        if self.materialize_samples:
            self.samples = data_acc

    def preshuffle_elements(self, dataset_elements: List):
        """Order elements by length before batching.

        At inference, elements are sorted by the total length of ``sorting_fields``. At
        training time, they are randomly permuted, sorted by that length plus noise, cut
        into chunks of 64, and the chunks are shuffled. This avoids building the same
        batches every time while keeping similar lengths together.
        """

        def element_length(elem) -> int:
            return sum(len(elem[k]) for k in self.sorting_fields)

        if self.for_inference:
            return sorted(dataset_elements, key=element_length)

        dataset_elements = np.random.permutation(dataset_elements)
        dataset_elements = sorted(
            dataset_elements,
            key=lambda elem: add_noise_to_value(
                element_length(elem), noise_param=self.noise_param
            ),
        )
        ds = list(chunks(dataset_elements, 64))  # todo: modified
        np.random.shuffle(ds)
        return flatten(ds)

    def materialize_batches(
        self, dataset_elements: List[Dict[str, Any]]
    ) -> Generator[Dict[str, Any], None, None]:
        """Group ``dataset_elements`` into collated batches.

        A batch is closed when it reaches ``batch_size`` or ``max_batch_size``, when
        adding an element would reach ``tokens_per_batch`` (token batching), or when the
        number of predictable candidates changes. Every batch must share one candidate
        count. Elements violating ``min_length``/``max_length`` are discarded with a
        warning.
        """
        if self.prebatch:
            dataset_elements = self.preshuffle_elements(dataset_elements)

        # the property builds a new dict on every access: fetch it once
        fields_batcher = self.fields_batcher
        current_batch = []

        # function that creates a batch from the 'current_batch' list
        def output_batch() -> Dict[str, Any]:
            candidates_sizes = [
                len(elem["predictable_candidates"]) for elem in current_batch
            ]
            assert len(set(candidates_sizes)) == 1, " ".join(
                map(str, candidates_sizes)
            )
            de_values_by_field = {
                fn: [de[fn] for de in current_batch if fn in de]
                for fn in fields_batcher
            }
            # in case you provide fields batchers but in the batch
            # there are no elements for that field
            de_values_by_field = {
                fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
            }
            # every remaining field must have a value for every element
            assert len({len(v) for v in de_values_by_field.values()}) == 1
            # todo: maybe we should report the user about possible
            #  fields filtering due to "None" instances
            de_values_by_field = {
                fn: fvs
                for fn, fvs in de_values_by_field.items()
                if all(fv is not None for fv in fvs)
            }
            batch_dict = dict()
            for field_name, field_values in de_values_by_field.items():
                batcher = fields_batcher[field_name]
                batch_dict[field_name] = (
                    batcher(field_values) if batcher is not None else field_values
                )
            return batch_dict

        max_len_discards, min_len_discards = 0, 0
        should_token_batch = self.batch_size is None
        curr_pred_elements = -1
        # longest element (in batching-field tokens) currently in current_batch
        batch_max_len = 0

        for de in dataset_elements:
            if (
                should_token_batch
                and self.max_batch_size != -1
                and len(current_batch) == self.max_batch_size
            ) or (not should_token_batch and len(current_batch) == self.batch_size):
                yield output_batch()
                current_batch = []
                curr_pred_elements = -1
                batch_max_len = 0

            # todo support max length (and min length) as dicts
            too_long_fields = [
                k
                for k in de
                if self.max_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) > self.max_length
            ]
            if len(too_long_fields) > 0:
                max_len_discards += 1
                continue
            too_short_fields = [
                k
                for k in de
                if self.min_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) < self.min_length
            ]
            if len(too_short_fields) > 0:
                min_len_discards += 1
                continue

            num_predictable_candidates = len(de["predictable_candidates"])
            # output_batch requires one candidate count per batch, in both modes
            candidates_changed = (
                curr_pred_elements != -1
                and num_predictable_candidates != curr_pred_elements
            )
            if should_token_batch:
                de_len = sum(len(de[k]) for k in self.batching_fields)
                future_max_len = max(de_len, batch_max_len)
                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
                close_batch = (
                    future_tokens_per_batch >= self.tokens_per_batch
                    or candidates_changed
                )
            else:
                close_batch = candidates_changed
            if len(current_batch) > 0 and close_batch:
                yield output_batch()
                current_batch = []
                batch_max_len = 0

            if should_token_batch:
                batch_max_len = max(batch_max_len, de_len)
            current_batch.append(de)
            curr_pred_elements = num_predictable_candidates

        if len(current_batch) != 0 and not self.drop_last:
            yield output_batch()

        if max_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were "
                    f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
                    f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the "
                    f"sample length exceeds the maximum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {max_len_discards} elements were "
                    f"discarded since longer than max length {self.max_length}"
                )
        if min_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were "
                    f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
                    f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the "
                    f"sample length is shorter than the minimum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {min_len_discards} elements were "
                    f"discarded since shorter than min length {self.min_length}"
                )

    @staticmethod
    def _new_output_format(sample: RelikReaderSample) -> RelikReaderSample:
        """Attach index-based views of the predictions to ``sample``.

        ``sample.predicted_spans`` is set to the distinct ``(start, end, type)``
        entities, sorted by start. ``sample.predicted_triples`` is set to
        ``(subject_span_index, relation_name, object_span_index, probability)`` tuples.
        Raises IndexError if a relation endpoint matches no span.
        """
        predicted_spans = sorted(
            {(p[0], p[1], p[2]) for p in sample.predicted_entities},
            key=lambda span: span[0],
        )

        def span_index(start, end) -> int:
            # first span with these boundaries
            return [
                i for i, span in enumerate(predicted_spans) if span[:2] == (start, end)
            ][0]

        predicted_triples = [
            (
                span_index(p["subject"]["start"], p["subject"]["end"]),
                p["relation"]["name"],
                span_index(p["object"]["start"], p["object"]["end"]),
                p["relation"]["probability"],
            )
            for p in sample.predicted_relations
        ]
        sample.predicted_spans = predicted_spans
        sample.predicted_triples = predicted_triples
        return sample

    @staticmethod
    def _convert_annotations(sample: RelikReaderSample) -> RelikReaderSample:
        """Convert token-level predictions of ``sample`` to word-level annotations.

        ``sample.predicted_entities`` becomes ``(word_start, word_end_exclusive, type)``
        tuples. Predicted token offsets are shifted by one (the CLS token) and widened to
        word boundaries; entities that cannot be aligned are dropped, together with the
        relations using them. ``sample.predicted_relations`` becomes a list of
        subject/relation/object dicts, and ``predicted_relations_probabilities`` is
        removed from the sample.
        """
        triplets = []
        entities = []
        # Entities may be dropped, so positions in ``entities`` can differ from the
        # indices used by ``predicted_relations``; this map keeps them aligned.
        entity_by_prediction_index = {}
        for entity_index, entity in enumerate(sample.predicted_entities):
            span_start = entity[0] - 1
            span_end = entity[1] - 1
            # span_start in the middle of a word: move back to the word's first token
            while str(span_start) not in sample.token2word_start:
                span_start -= 1
                if span_start < 0:
                    break
            # span_end in the middle of a word: move on to the word's last token
            while str(span_end) not in sample.token2word_end:
                span_end += 1
                if span_end >= len(sample.tokens):
                    break
            if span_start < 0 or span_end >= len(sample.tokens):
                continue
            word_entity = (
                sample.token2word_start[str(span_start)],
                sample.token2word_end[str(span_end)] + 1,
                sample.span_candidates[entity[2]]
                if sample.span_candidates and len(entity) > 2
                else "NME",
            )
            entities.append(word_entity)
            entity_by_prediction_index[entity_index] = word_entity

        for predicted_triplet, predicted_triplet_probabilities in zip(
            sample.predicted_relations, sample.predicted_relations_probabilities
        ):
            subject_index, object_index, relation = predicted_triplet
            # int() also accepts numpy / 0-d tensor indices
            subject = entity_by_prediction_index.get(int(subject_index))
            object_ = entity_by_prediction_index.get(int(object_index))
            if subject is None or object_ is None:
                # one of the endpoints could not be aligned to words
                continue
            relation = sample.triplet_candidates[relation]
            triplets.append(
                {
                    "subject": {
                        "start": subject[0],
                        "end": subject[1],
                        "type": subject[2],
                    },
                    "relation": {
                        "name": relation,
                        "probability": float(predicted_triplet_probabilities.round(2)),
                    },
                    "object": {
                        "start": object_[0],
                        "end": object_[1],
                        "type": object_[2],
                    },
                }
            )
        # convert to list since we need to modify the sample down the road
        sample.predicted_entities = entities
        sample.predicted_relations = triplets
        del sample._d["predicted_relations_probabilities"]
        return sample

    @staticmethod
    def convert_to_word_annotations(sample: RelikReaderSample) -> RelikReaderSample:
        """Convert predictions to word offsets and attach the index-based output views."""
        sample = RelikREDataset._convert_annotations(sample)
        return RelikREDataset._new_output_format(sample)

    @staticmethod
    def convert_to_char_annotations(
        sample: RelikReaderSample,
        remove_nmes: bool = True,
    ) -> RelikReaderSample:
        """Convert predictions to character offsets (when the sample has char maps).

        Entity and relation offsets are first converted to words, then mapped to
        characters through ``word2token_*`` and ``token2char_*``. Finally the
        index-based output views are attached. ``remove_nmes`` is currently unused.
        """
        RelikREDataset._convert_annotations(sample)
        if "token2char_start" in sample._d:

            def char_start(word_start):
                return sample.token2char_start[
                    str(sample.word2token_start[str(word_start)])
                ]

            def char_end(word_end):
                # word_end is exclusive: use the last word of the span
                return sample.token2char_end[
                    str(sample.word2token_end[str(word_end - 1)])
                ]

            entities = []
            for entity in sample.predicted_entities:
                entity = list(entity)
                entity[0] = char_start(entity[0])
                entity[1] = char_end(entity[1])
                entities.append(entity)
            sample.predicted_entities = entities

            for triplet in sample.predicted_relations:
                for role in ("subject", "object"):
                    triplet[role]["start"] = char_start(triplet[role]["start"])
                    triplet[role]["end"] = char_end(triplet[role]["end"])

        sample = RelikREDataset._new_output_format(sample)
        return sample

    @staticmethod
    def merge_patches_predictions(sample) -> None:
        """Placeholder for merging predictions of split samples; currently a no-op."""
        pass
def main(
    dataset_path: str = "/home/huguetcabot/alby-re/alby/data/nyt-alby+/valid.jsonl",
    transformer_model: str = "microsoft/deberta-v3-base",
    num_relation_symbols: int = 50,
) -> None:
    """Debug entry point: build an inference dataset and print its first batch.

    Args:
        dataset_path: samples file read by ``RelikREDataset``.
        transformer_model: model name whose tokenizer ``RelikREDataset`` builds.
        num_relation_symbols: number of relation symbols (``NME_SYMBOL`` is added).
    """
    # Use the dataset's own symbol factory so the symbols have the ``[R-i]`` format
    # produced by get_special_symbols_re (a hand-written list produced ``R-i``).
    special_symbols = RelikREDataset.get_special_symbols_re(
        num_relation_symbols, use_nme=True
    )
    relik_dataset = RelikREDataset(
        dataset_path,
        materialize_samples=False,
        transformer_model=transformer_model,
        special_symbols=special_symbols,
        shuffle_candidates=False,
        flip_candidates=False,
        for_inference=True,
    )
    # Print only the first batch. ``break`` replaces ``exit(0)`` so that calling
    # main() does not terminate the interpreter.
    for batch in relik_dataset:
        print(batch)
        break


if __name__ == "__main__":
    main()