import math
import time
import collections
import re

import numpy as np

from fengshen.data.megatron_dataloader.utils import (
    print_rank_0
)
from fengshen.data.megatron_dataloader.blendable_dataset import BlendableDataset
from fengshen.data.megatron_dataloader.indexed_dataset import make_dataset as make_indexed_dataset

DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPE_BERT_CN_WWM = 'bert_cn_wwm'
DSET_TYPE_BART = 'bart'
DSET_TYPE_COCOLM = 'coco_lm'

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT,
              DSET_TYPE_T5, DSET_TYPE_BERT_CN_WWM,
              DSET_TYPE_BART, DSET_TYPE_COCOLM]
|
|
def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):
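    # data_prefix is a flat list of alternating (weight, prefix) entries, for
    # example ['0.3', '/path/corpusA_text_sentence', '0.7', '/path/corpusB_text_sentence']
    # (paths illustrative): weights sit at even indices, dataset prefixes at odd ones.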
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()

    # Normalize the weights.
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so that in case the blending dataset does
    # not uniformly distribute the number of samples, we still have
    # samples left to feed to the network.
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005))
             for val in train_valid_test_num_samples])

    return prefixes, weights, datasets_train_valid_test_num_samples
|
|
def compile_helper():
    """Compile helper functions at runtime. Make sure this
    is invoked on a single process."""
    import os
    import subprocess
    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        print("Making C++ dataset helpers module failed, exiting.")
        import sys
        sys.exit(1)
|
|
def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""

    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into the `A`.
    a_end = 1
    if n_sentences >= 3:
        # Note that randint in numpy is exclusive on the upper bound.
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = []
    for j in range(a_end):
        tokens_a.extend(sample[j])

    # Second part:
    tokens_b = []
    for j in range(a_end, n_sentences):
        tokens_b.extend(sample[j])

    # Random next: with 50% probability swap A and B so the
    # next-sentence-prediction label is negative.
    is_next_random = False
    if np_rng.random() < 0.5:
        is_next_random = True
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random
|
|
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    assert len_a > 0
    if len_a + len_b <= max_num_tokens:
        return False
    # Trim one token at a time from the longer segment, removing it from the
    # front or the back at random, until the pair fits.
    while len_a + len_b > max_num_tokens:
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True
|
|
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""

    tokens = []
    tokentypes = []
    # [CLS].
    tokens.append(cls_id)
    tokentypes.append(0)
    # Segment A.
    for token in tokens_a:
        tokens.append(token)
        tokentypes.append(0)
    # [SEP].
    tokens.append(sep_id)
    tokentypes.append(0)
    # Segment B.
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
    if tokens_b:
        # Trailing [SEP] only if segment B is non-empty.
        tokens.append(sep_id)
        tokentypes.append(1)

    return tokens, tokentypes
|
|
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])
|
|
def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into WordPieces, the first piece carries no
    # marker and every subsequent piece is prefixed with "##".
    return not piece.startswith("##")
|
|
def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 tokenizer,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert",
                                 zh_tokenizer=None):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""

    cand_indexes = []
    # Record whether each piece is the starting piece of a word (1 means yes),
    # so that on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)
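    # Two whole-word-masking paths: without a Chinese word segmenter, word
    # boundaries come from the WordPiece "##" continuation markers; with one
    # (assumed to be a jieba-style callable that accepts HMM=True and returns
    # an iterable of segmented words), boundaries come from segmenting the
    # detokenized text.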
    if zh_tokenizer is None:
        for (i, token) in enumerate(tokens):
            if token == cls_id or token == sep_id:
                token_boundary[i] = 1
                continue
            # Whole word masking means we mask all of the wordpieces
            # corresponding to an original word.
            #
            # Note that whole word masking does *not* change the training code
            # at all -- we still predict each WordPiece independently, softmaxed
            # over the entire vocabulary.
            if (do_whole_word_mask and len(cand_indexes) >= 1 and
                    not is_start_piece(vocab_id_to_token_dict[token])):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
                if is_start_piece(vocab_id_to_token_dict[token]):
                    token_boundary[i] = 1
    else:
        # Detokenize the sentence (dropping [CLS]/[SEP]) so it can be
        # segmented into Chinese words.
        raw_tokens = []
        for t in tokens:
            if t != cls_id and t != sep_id:
                raw_tokens.append(t)
        raw_tokens = [vocab_id_to_token_dict[i] for i in raw_tokens]
        # Segment the text and record, for every leading character, the length
        # of the longest word that starts with it.
        word_list = set(zh_tokenizer(''.join(raw_tokens), HMM=True))
        word_length_dict = {}
        for w in word_list:
            if len(w) < 1:
                continue
            if w[0] not in word_length_dict:
                word_length_dict[w[0]] = len(w)
            elif word_length_dict[w[0]] < len(w):
                word_length_dict[w[0]] = len(w)
        i = 0
        # Greedily group consecutive pieces into whole-word candidates.
        while i < len(tokens):
            token_id = tokens[i]
            token = vocab_id_to_token_dict[token_id]
            if len(token) == 0 or token_id == cls_id or token_id == sep_id:
                token_boundary[i] = 1
                i += 1
                continue
            word_max_length = 1
            if token[0] in word_length_dict:
                word_max_length = word_length_dict[token[0]]
            j = 0
            word = ''
            word_end = i + 1
            # Backwards compatibility: older models may still emit "##" pieces;
            # in that case fall back to grouping by the "##" markers.
            old_style = False
            while word_end < len(tokens) and vocab_id_to_token_dict[tokens[word_end]].startswith('##'):
                old_style = True
                word_end += 1
            if not old_style:
                while j < word_max_length and i + j < len(tokens):
                    cur_token = tokens[i + j]
                    word += vocab_id_to_token_dict[cur_token]
                    j += 1
                    if word in word_list:
                        word_end = i + j
            cand_indexes.append([p for p in range(i, word_end)])
            token_boundary[i] = 1
            i = word_end

    output_tokens = list(tokens)
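    # For Chinese whole word masking, the "##" continuation prefixes that were
    # only needed for grouping are stripped off again, so the model sees the
    # plain character ids as inputs.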
    if masking_style == 'bert-cn-wwm':
        new_token_ids = []
        for token_id in output_tokens:
            token = tokenizer.convert_ids_to_tokens([token_id])[0]
            if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                token = token[2:]
            new_token_id = tokenizer.convert_tokens_to_ids([token])[0]
            new_token_ids.append(new_token_id)
        output_tokens = new_token_ids

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # By default, the probabilities favor shorter ngram sequences.
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]
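    # Build, for every candidate start position, the 1..max_ngrams spans of
    # whole words beginning there; a span length n is later sampled with
    # probability pvals[n - 1] (roughly [6/11, 3/11, 2/11] for max_ngrams=3),
    # unless geometric_dist is set, in which case n ~ Geometric(0.2) clipped
    # to max_ngrams.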
    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Skip this piece if any of its indexes are already covered by
        # previously chosen ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
        else:
            # Sampling "n" from the geometric distribution and clipping it to
            # max_ngrams. Using p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Repeatedly look for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            masked_token = None
            if masking_style == "bert":
                # 80% of the time, replace with [MASK].
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep the original token.
                    if np_rng.random() < 0.5:
                        masked_token = tokens[index]
                    # 10% of the time, replace with a random word.
                    else:
                        masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == 'bert-cn-wwm':
                # 80% of the time, replace with [MASK].
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep the original token (with any "##"
                    # prefix stripped for Chinese whole word masking).
                    if np_rng.random() < 0.5:
                        token_id = tokens[index]
                        token = tokenizer.convert_ids_to_tokens([token_id])[0]
                        if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                            token = token[2:]
                        new_token_id = tokenizer.convert_tokens_to_ids([token])[0]
                        masked_token = new_token_id
                    # 10% of the time, replace with a random word.
                    else:
                        masked_token = vocab_id_list[np_rng.randint(
                            0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(
                index=index, label=tokens[index]))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Skip this piece if any of its indexes are already covered by
            # masking or by previously chosen ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np.random.choice(ngrams[:len(cand_index_set)],
                                 p=pvals[:len(cand_index_set)] /
                                 pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1

            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        # Permute the selected tokens and record their original values as labels.
        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(
                index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of the first token in each span.
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)
|
|
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask: 1 for real tokens, 0 for padding.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask: -1 (ignored) everywhere except the masked
    # positions, which carry the original token id and a loss mask of 1.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
|
|
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length,
                                    masked_lm_prob, short_seq_prob, seed,
                                    tokenizer,
                                    skip_warmup, binary_head=False,
                                    max_seq_length_dec=None,
                                    dataset_type='standard_bert',
                                    zh_tokenizer=None,
                                    span=None):

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                max_seq_length_dec,
                                                tokenizer,
                                                dataset_type=dataset_type,
                                                zh_tokenizer=zh_tokenizer,
                                                span=span)

    # Blending dataset: parse the (weight, prefix) pairs and the per-dataset
    # sample counts.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build the individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, binary_head, max_seq_length_dec,
            tokenizer, dataset_type=dataset_type, zh_tokenizer=zh_tokenizer,
            span=span)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)
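    # Blend the per-corpus datasets according to the normalized weights
    # computed above.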
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
|
|
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length,
                                     masked_lm_prob, short_seq_prob, seed,
                                     skip_warmup, binary_head,
                                     max_seq_length_dec,
                                     tokenizer,
                                     dataset_type='standard_bert',
                                     zh_tokenizer=None,
                                     span=None):

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: {}".format(dataset_type))

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    # Get start and end indices of train/valid/test split into doc-idx.
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0('     sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)
|
|
    def build_dataset(index, name):
        from fengshen.data.megatron_dataloader.bert_dataset import BertDataset
        from fengshen.data.megatron_dataloader.bart_dataset import BartDataset
        from fengshen.data.megatron_dataloader.cocolm_dataset import COCOLMDataset
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx.
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Keyword arguments shared by all dataset types.
            kwargs = dict(
                name=name,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
            )

            if dataset_type == DSET_TYPE_BERT or dataset_type == DSET_TYPE_BERT_CN_WWM:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    binary_head=binary_head,
                    tokenizer=tokenizer,
                    masking_style='bert' if dataset_type == DSET_TYPE_BERT else 'bert-cn-wwm',
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_BART:
                dataset = BartDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    tokenizer=tokenizer,
                    zh_tokenizer=zh_tokenizer,
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_COCOLM:
                dataset = COCOLMDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    tokenizer=tokenizer,
                    masking_style='bert',
                    span=span,
                    **kwargs
                )
            else:
                raise NotImplementedError(
                    "Dataset type not fully implemented.")

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset
|
    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)
|
|
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))

    print_rank_0(' > indexed dataset stats:')
    print_rank_0('    number of documents: {}'.format(
        indexed_dataset.doc_idx.shape[0] - 1))
    print_rank_0('    number of sentences: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset
|
|
def get_train_valid_test_split_(splits_string, size):
    """Get dataset splits from a comma or '/' separated string list."""
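    # e.g. splits_string='949,50,1' with size=10000 documents yields the
    # boundaries [0, 9490, 9990, 10000].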
    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
|
|
def get_samples_mapping(indexed_dataset,
                        data_prefix,
                        num_epochs,
                        max_num_samples,
                        max_seq_length,
                        short_seq_prob,
                        seed,
                        name,
                        binary_head):
    """Get a list that maps a sample index to a starting
    sentence index, end sentence index, and length"""

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1
|
    # Filename of the cached index mapping.
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'
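    # The name encodes every parameter that affects the mapping, e.g. (values
    # illustrative) <data_prefix>_train_indexmap_1000000mns_512msl_0.10ssp_1234s.npy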
|
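    # The samples mapping is expected to be cached at indexmap_filename; in
    # upstream Megatron-LM it is built at this point on rank 0 (via the
    # compiled C++ helpers) when the cache is missing.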
    # Load the mapping.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(
        indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping