# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Most of the code here has been copied from:
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.

import math
import time
import collections

import numpy as np
import re

from fengshen.data.megatron_dataloader.utils import (
    print_rank_0
)
from fengshen.data.megatron_dataloader.blendable_dataset import BlendableDataset
from fengshen.data.megatron_dataloader.indexed_dataset import make_dataset as make_indexed_dataset

DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPE_BERT_CN_WWM = 'bert_cn_wwm'
DSET_TYPE_BART = 'bart'
DSET_TYPE_COCOLM = 'coco_lm'

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5,
              DSET_TYPE_BERT_CN_WWM, DSET_TYPE_BART, DSET_TYPE_COCOLM]


def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):

    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()
    # Normalize weights
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
    # not uniformly distribute the number of samples, we still have
    # samples left to feed to the network.
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005))
             for val in train_valid_test_num_samples])

    return prefixes, weights, datasets_train_valid_test_num_samples


def compile_helper():
    """Compile helper function at runtime. Make sure this
    is invoked on a single process."""
    import os
    import subprocess
    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        print("Making C++ dataset helpers module failed, exiting.")
        import sys
        sys.exit(1)


def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""

    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into the `A`.
    a_end = 1
    if n_sentences >= 3:
        # Note that randint in numpy is exclusive.
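        # a_end is drawn uniformly from [1, n_sentences - 1], so segment A
        # always keeps at least one sentence and segment B gets the rest.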
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = []
    for j in range(a_end):
        tokens_a.extend(sample[j])

    # Second part:
    tokens_b = []
    for j in range(a_end, n_sentences):
        tokens_b.extend(sample[j])

    # Random next:
    is_next_random = False
    if np_rng.random() < 0.5:
        is_next_random = True
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random


def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    # print(len_a, len_b, max_num_tokens)
    assert len_a > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True


def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""

    tokens = []
    tokentypes = []
    # [CLS].
    tokens.append(cls_id)
    tokentypes.append(0)
    # Segment A.
    for token in tokens_a:
        tokens.append(token)
        tokentypes.append(0)
    # [SEP].
    tokens.append(sep_id)
    tokentypes.append(0)
    # Segment B.
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
    if tokens_b:
        # [SEP].
        tokens.append(sep_id)
        tokentypes.append(1)

    return tokens, tokentypes


MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into
    # WordPieces, the first token does not have any marker and any subsequent
    # tokens are prefixed with ##. So whenever we see the ## token, we
    # append it to the previous set of word indexes.
    return not piece.startswith("##")


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 tokenizer,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert",
                                 zh_tokenizer=None):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""
    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)
    # If no Chinese word segmenter is given, fall back to the plain
    # "##" WordPiece convention.
    if zh_tokenizer is None:
        for (i, token) in enumerate(tokens):
            if token == cls_id or token == sep_id:
                token_boundary[i] = 1
                continue
            # Whole Word Masking means that we mask all of the wordpieces
            # corresponding to an original word.
            #
            # Note that Whole Word Masking does *not* change the training code
            # at all -- we still predict each WordPiece independently, softmaxed
            # over the entire vocabulary.
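            # Continuation pieces (those starting with "##") are appended to
            # the previous candidate so that a whole word is masked as one
            # unit; a start piece opens a new candidate group.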
            if (do_whole_word_mask and len(cand_indexes) >= 1 and
                    not is_start_piece(vocab_id_to_token_dict[token])):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
                if is_start_piece(vocab_id_to_token_dict[token]):
                    token_boundary[i] = 1
    else:
        # If a Chinese word segmenter is specified, segment the text with it
        # first and then decide the word boundaries.
        # Recover the raw text with CLS and SEP removed.
        raw_tokens = []
        for t in tokens:
            if t != cls_id and t != sep_id:
                raw_tokens.append(t)
        raw_tokens = [vocab_id_to_token_dict[i] for i in raw_tokens]
        # Segment the text, then record, for each leading character, the
        # length of the longest word that starts with it.
        word_list = set(zh_tokenizer(''.join(raw_tokens), HMM=True))
        word_length_dict = {}
        for w in word_list:
            if len(w) < 1:
                continue
            if w[0] not in word_length_dict:
                word_length_dict[w[0]] = len(w)
            elif word_length_dict[w[0]] < len(w):
                word_length_dict[w[0]] = len(w)
        i = 0
        # Scan the tokens and look them up in the segmented word list.
        while i < len(tokens):
            token_id = tokens[i]
            token = vocab_id_to_token_dict[token_id]
            if len(token) == 0 or token_id == cls_id or token_id == sep_id:
                token_boundary[i] = 1
                i += 1
                continue
            word_max_length = 1
            if token[0] in word_length_dict:
                word_max_length = word_length_dict[token[0]]
            j = 0
            word = ''
            word_end = i + 1
            # Stay compatible with the old "##" style: if the following pieces
            # start with "##", merge them into the current token as one word.
            old_style = False
            while word_end < len(tokens) and vocab_id_to_token_dict[
                    tokens[word_end]].startswith('##'):
                old_style = True
                word_end += 1
            if not old_style:
                while j < word_max_length and i + j < len(tokens):
                    cur_token = tokens[i + j]
                    word += vocab_id_to_token_dict[cur_token]
                    j += 1
                    if word in word_list:
                        word_end = i + j
            cand_indexes.append([p for p in range(i, word_end)])
            token_boundary[i] = 1
            i = word_end

    output_tokens = list(tokens)
    # add by ganruyi
    if masking_style == 'bert-cn-wwm':
        # For Chinese pieces, strip the "##" prefix that was added earlier.
        new_token_ids = []
        for token_id in output_tokens:
            token = tokenizer.convert_ids_to_tokens([token_id])[0]
            if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                token = token[2:]
            new_token_id = tokenizer.convert_tokens_to_ids([token])[0]
            new_token_ids.append(new_token_id)
        output_tokens = new_token_ids

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # Note(mingdachen):
        # By default, we set the probabilities to favor shorter ngram sequences.
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]

    # Build the ngram indexes: for each word, record its ngram candidate
    # word groups.
    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip the current piece if it is covered by LM masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
        else:
            # Sampling "n" from the geometric distribution and clipping it to
            # the max_ngrams.
            # Using p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            masked_token = None
            if masking_style == "bert":
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep original
                    if np_rng.random() < 0.5:
                        masked_token = tokens[index]
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[
                            np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == 'bert-cn-wwm':
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep original
                    if np_rng.random() < 0.5:
                        # For Chinese whole word masking, strip the "##"
                        # prefix from the kept token.
                        token_id = tokens[index]
                        token = tokenizer.convert_ids_to_tokens([token_id])[0]
                        if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                            token = token[2:]
                        new_token_id = tokenizer.convert_tokens_to_ids(
                            [token])[0]
                        masked_token = new_token_id
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[np_rng.randint(
                            0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(
                index=index, label=tokens[index]))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip the current piece if it is covered by LM masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np.random.choice(ngrams[:len(cand_index_set)],
                                 p=pvals[:len(cand_index_set)] /
                                 pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(
                index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of the first span.
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels,
            token_boundary, masked_spans)


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length,
                                    masked_lm_prob, short_seq_prob, seed,
                                    tokenizer,
                                    skip_warmup, binary_head=False,
                                    max_seq_length_dec=None,
                                    dataset_type='standard_bert',
                                    zh_tokenizer=None,
                                    span=None):

    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                max_seq_length_dec,
                                                tokenizer,
                                                dataset_type=dataset_type,
                                                zh_tokenizer=zh_tokenizer,
                                                span=span)
    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, binary_head, max_seq_length_dec,
            tokenizer, dataset_type=dataset_type, zh_tokenizer=zh_tokenizer)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
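    # Each non-empty split below is wrapped in a BlendableDataset, which draws
    # samples from the per-prefix datasets in proportion to the normalized weights.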
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length,
                                     masked_lm_prob, short_seq_prob, seed,
                                     skip_warmup, binary_head,
                                     max_seq_length_dec,
                                     tokenizer,
                                     dataset_type='standard_bert',
                                     zh_tokenizer=None,
                                     span=None):

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: ", dataset_type)

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0('     sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        from fengshen.data.megatron_dataloader.bert_dataset import BertDataset
        from fengshen.data.megatron_dataloader.bart_dataset import BartDataset
        from fengshen.data.megatron_dataloader.cocolm_dataset import COCOLMDataset
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            kwargs = dict(
                name=name,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
            )

            if dataset_type == DSET_TYPE_BERT or dataset_type == DSET_TYPE_BERT_CN_WWM:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    binary_head=binary_head,
                    # Extra argument to distinguish bert from bert-cn-wwm.
                    tokenizer=tokenizer,
                    masking_style='bert' if dataset_type == DSET_TYPE_BERT else 'bert-cn-wwm',
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_BART:
                dataset = BartDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    tokenizer=tokenizer,
                    zh_tokenizer=zh_tokenizer,
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_COCOLM:
                dataset = COCOLMDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    tokenizer=tokenizer,
                    masking_style='bert',
                    span=span,
                    **kwargs
                )
            else:
                raise NotImplementedError(
                    "Dataset type not fully implemented.")

            # Set the original pointer so dataset remains the main dataset.
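            # Restoring the full doc-idx here lets the next build_dataset call
            # slice a different split from the same indexed_dataset object.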
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))

    print_rank_0(' > indexed dataset stats:')
    print_rank_0('    number of documents: {}'.format(
        indexed_dataset.doc_idx.shape[0] - 1))
    print_rank_0('    number of sentences: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """ Get dataset splits from comma or '/' separated string list."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index


def get_samples_mapping(indexed_dataset,
                        data_prefix,
                        num_epochs,
                        max_num_samples,
                        max_seq_length,
                        short_seq_prob,
                        seed,
                        name,
                        binary_head):
    """Get a list that maps a sample index to a starting
    sentence index, end sentence index, and length"""

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    # ganruyi comment
    # counts = torch.cuda.LongTensor([1])
    # torch.distributed.all_reduce(
    #     counts, group=mpu.get_data_parallel_group())
    # torch.distributed.all_reduce(
    #     counts, group=mpu.get_pipeline_model_parallel_group())
    # assert counts[0].item() == (
    #     torch.distributed.get_world_size() //
    #     torch.distributed.get_world_size(
    #         group=mpu.get_tensor_model_parallel_group()))

    # Load indexed dataset.
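    # The precomputed .npy mapping is memory-mapped (mmap_mode='r') rather
    # than loaded fully into RAM.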
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(
        indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
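

# A minimal usage sketch (not part of the original pipeline): it only
# exercises the pure helpers above with made-up weights, sample counts, and a
# split string, so the prefix names and the expected values in the comments
# are illustrative assumptions rather than outputs of a real training run.
if __name__ == '__main__':
    # Two hypothetical dataset prefixes with raw weights 1 and 3; after
    # normalization they become 0.25 and 0.75, and each per-dataset sample
    # count is inflated by the 1.005 safety factor and rounded up.
    prefixes, weights, per_ds_samples = get_datasets_weights_and_num_samples(
        ['1', 'corpus-A_text_sentence', '3', 'corpus-B_text_sentence'],
        [1000, 100, 10])
    print(prefixes)        # ['corpus-A_text_sentence', 'corpus-B_text_sentence']
    print(weights)         # [0.25, 0.75]
    print(per_ds_samples)  # [[252, 26, 3], [754, 76, 8]]

    # A "949,50,1" split of 1000 documents yields the boundary indices
    # [0, 949, 999, 1000]: train gets docs [0, 949), validation [949, 999),
    # and test [999, 1000).
    print(get_train_valid_test_split_('949,50,1', 1000))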