# Originally uploaded by fclong in commit 8ebda9e ("Upload 396 files").
"""BART Style dataset. Modified from fairseq."""
import numpy as np
import torch
import math
import re
from fengshen.data.megatron_dataloader.dataset_utils import (
get_samples_mapping
)
class BartDataset(torch.utils.data.Dataset):
    """BART-style denoising dataset over an indexed dataset of sentences.

    Each item concatenates the sentences selected by ``samples_mapping`` into
    ``[CLS] ... [SEP]``, inserting a separator after sentence-ending
    punctuation and after every sentence. The encoder input is then noised
    with sentence permutation and whole-word masking (word boundaries come
    from ``zh_tokenizer``); ``labels`` are the un-noised tokens shifted left by
    one position.
    """

    # Matches Chinese sub-word tokens written with a '##' prefix
    # (e.g. '##中'); such tokens are mapped back to the bare character.
    _WWM_SUBWORD_RE = re.compile('##[\u4E00-\u9FA5]')

    # Code-point ranges treated as "Chinese characters" by _is_chinese_char:
    # the CJK Unified Ideographs block, its extensions and the compatibility
    # ideographs. The block is NOT all Japanese/Korean characters despite its
    # name: Hangul, Hiragana and Katakana live elsewhere and are not included.
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    _CJK_RANGES = (
        (0x4E00, 0x9FFF),
        (0x3400, 0x4DBF),
        (0x20000, 0x2A6DF),
        (0x2A700, 0x2B73F),
        (0x2B740, 0x2B81F),
        (0x2B820, 0x2CEAF),
        (0xF900, 0xFAFF),
        (0x2F800, 0x2FA1F),
    )

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed, tokenizer, zh_tokenizer):
        """Build the samples mapping and cache tokenizer/noising settings.

        Arguments:
            name, data_prefix, num_epochs, max_num_samples, short_seq_prob,
            seed: forwarded to ``get_samples_mapping``.
            indexed_dataset: sentence store indexed by the ranges in the
                samples mapping.
            masked_lm_prob: base whole-word masking ratio (doubled when the
                decoder is used, see ``build_training_sample``).
            max_seq_length: padded length of every returned tensor.
            tokenizer: must expose ``vocab``, ``vocab_size``, the
                cls/sep/mask/pad token ids and ``convert_tokens_to_ids``.
            zh_tokenizer: word segmenter called as ``zh_tokenizer(text,
                HMM=True)``; must return a list of strings whose concatenation
                is ``text``.
        """
        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                   data_prefix,
                                                   num_epochs,
                                                   max_num_samples,
                                                   self.max_seq_length - 3,  # account for added tokens
                                                   short_seq_prob,
                                                   self.seed,
                                                   self.name,
                                                   False)

        # Vocab stuff.
        self.vocab_size = tokenizer.vocab_size
        inv_vocab = {v: k for k, v in tokenizer.vocab.items()}
        self.vocab_id_list = list(inv_vocab.keys())
        self.vocab_id_to_token_dict = inv_vocab
        self.cls_id = tokenizer.cls_token_id
        self.sep_id = tokenizer.sep_token_id
        self.mask_id = tokenizer.mask_token_id
        self.pad_id = tokenizer.pad_token_id
        self.tokenizer = tokenizer

        # Sentence-ending punctuation after which a separator is inserted:
        # '。' plus the ASCII and full-width (U+FF1B/U+FF01/U+FF1F) forms of
        # ';', '!' and '?'. Escapes are used so that Unicode normalisation of
        # this file cannot collapse the full-width forms into duplicates.
        seg_tokens = ['\u3002', ';', '\uff1b', '!', '\uff01', '?', '\uff1f']
        seg_token_ids = []
        for t in seg_tokens:
            if t in tokenizer.vocab:
                seg_token_ids.append(tokenizer.vocab[t])
            else:
                print('seg_token "{}" not in vocab'.format(t))
        self.seg_token_ids = set(seg_token_ids)

        self.zh_tokenizer = zh_tokenizer

        # Denoising ratios
        self.permute_sentence_ratio = 1.0
        self.mask_ratio = masked_lm_prob  # e.g. 0.15
        self.random_ratio = 0.1  # share of masked positions given a random id
        self.insert_ratio = 0.0
        self.rotate_ratio = 0.0
        self.mask_whole_word = 1  # not None -> use zh_tokenizer word starts
        self.item_transform_func = None

        # Span masking is disabled: each selected word start is handled by
        # the whole-word loop in add_whole_word_mask. Assigning
        # ``self._poisson_span_distribution()`` here would enable
        # Poisson(lambda=3) span lengths instead.
        # NOTE(review): that span path is not exercised with the word/char
        # split in add_whole_word_mask -- verify before enabling.
        self.mask_span_distribution = None

    @staticmethod
    def _poisson_span_distribution(_lambda=3, max_len=128, eps=0.0000001):
        """Return a Categorical over span lengths 0..k with Poisson(_lambda) weights.

        The support stops at the first length whose probability falls below
        ``eps`` (or at ``max_len``). Not used unless assigned to
        ``mask_span_distribution``.
        """
        lambda_to_the_k = 1
        e_to_the_minus_lambda = math.exp(-_lambda)
        k_factorial = 1
        ps = []
        for k in range(0, max_len):
            ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
            lambda_to_the_k *= _lambda
            k_factorial *= k + 1
            if ps[-1] < eps:
                break
        return torch.distributions.Categorical(torch.FloatTensor(ps))

    def __len__(self):
        """Number of entries in the samples mapping."""
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        """Build the training sample for entry ``idx`` of the samples mapping."""
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
        return self.build_training_sample(sample, self.max_seq_length, np_rng)

    def build_training_sample(self, sample, max_seq_length, np_rng):
        """Build one training sample.

        Arguments:
            sample: A list of sentences in which each sentence is a list of token ids.
            max_seq_length: Desired sequence length.
            np_rng: Random number generator. Note that this rng state should be
                numpy and not python since python randint is inclusive for
                the upper bound whereas the numpy one is exclusive.
                NOTE(review): currently unused -- the noising functions draw
                from the global torch/numpy RNGs, so a given ``idx`` is not
                reproducible.

        Returns:
            dict with ``input_ids`` (noised tokens padded with ``pad_id``),
            ``labels`` (un-noised tokens without ``[CLS]``, padded with -100,
            the default ``ignore_index`` of torch's CrossEntropyLoss) and
            ``attention_mask`` (1 where ``input_ids`` is not padding).
        """
        tokens = [self.cls_id]
        for sent in sample:
            for t in sent:
                token = self.vocab_id_to_token_dict[t]
                if self._WWM_SUBWORD_RE.search(token) is not None:
                    # Compatibility with Erlangshen's '##' whole-word-mask
                    # tokens: use the id of the bare character instead, so
                    # word boundaries are decided by zh_tokenizer.
                    t = self.tokenizer.convert_tokens_to_ids(token[2:])
                tokens.append(t)
                if t in self.seg_token_ids:
                    tokens.append(self.sep_id)
            # Every sentence ends with a separator.
            if tokens[-1] != self.sep_id:
                tokens.append(self.sep_id)

        if len(tokens) > max_seq_length:
            tokens = tokens[:max_seq_length]
            tokens[-1] = self.sep_id
        tokens = torch.LongTensor(tokens)
        full_stops = (tokens == self.sep_id).long()
        assert (max_seq_length - tokens.shape[0]) >= 0, (tokens.size(), tokens[-1], max_seq_length)

        source, target = tokens, tokens[1:].clone()
        # The encoder-only variant (use_decoder == 0) is disabled.
        use_decoder = 1

        if self.permute_sentence_ratio > 0.0 and use_decoder == 1:
            source = self.permute_sentences(source, full_stops, self.permute_sentence_ratio)

        if self.mask_ratio > 0.0:
            replace_length = 1 if use_decoder else -1
            mask_ratio = self.mask_ratio * 2 if use_decoder else self.mask_ratio
            source = self.add_whole_word_mask(source, mask_ratio, replace_length)

        # Insertion and rotation noise are not supported in this pipeline.
        if self.insert_ratio > 0.0:
            raise NotImplementedError
        if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
            raise NotImplementedError

        # there can additional changes to make:
        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        assert (source >= 0).all()
        assert (source <= self.vocab_size).all()
        assert source[0] == self.cls_id
        assert source[-1] == self.sep_id

        source_ = torch.full((max_seq_length,), self.pad_id, dtype=torch.long)
        source_[:source.shape[0]] = source
        # decoder not need bos in the front
        target_ = torch.full((max_seq_length,), -100, dtype=torch.long)
        target_[:target.shape[0]] = target
        # NOTE(review): no decoder_input_ids are returned; presumably the model
        # derives them from ``labels`` -- confirm against the training code.
        return {
            "input_ids": source_,
            "labels": target_,
            "attention_mask": (source_ != self.pad_id).long()
        }

    def permute_sentences(self, source, full_stops, p=1.0):
        """Shuffle a fraction ``p`` of the sentences in ``source``.

        Arguments:
            source: 1-D token tensor whose position 0 (``[CLS]``) is kept.
            full_stops: tensor of the same length, non-zero where a token ends
                a sentence (bool or integer).
            p: fraction of sentences that take part in the shuffle.

        Returns:
            A new tensor; ``source`` is not modified.
        """
        # Work on booleans: on an integer tensor ``~`` is a bitwise NOT
        # (~0 == -1, ~1 == -2), never zero, which would make every full stop --
        # not only the first of a run -- a sentence end.
        full_stops = full_stops.bool()
        # Full stops whose previous token is not a full stop. +1 undoes the
        # [1:] offset and +1 turns the index into an exclusive slice end.
        sentence_ends = (full_stops[1:] & ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Ignore <bos> at start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1): sentence_ends[i]]
            result[index: index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    def word_starts_en(self, source):
        """Word-start flags from a per-token lookup tensor; not called in this class.

        NOTE(review): ``self.mask_whole_word`` is the int ``1`` after
        ``__init__`` and has no ``.gather``; this only works if a tensor
        indexed by token id is assigned there instead.
        """
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    @classmethod
    def _is_chinese_char(cls, c):
        """Return True if every character of ``c`` lies in ``_CJK_RANGES``."""
        if len(c) > 1:
            return all(cls._is_chinese_char(c_i) for c_i in c)
        cp = ord(c)
        return any(lo <= cp <= hi for lo, hi in cls._CJK_RANGES)

    @staticmethod
    def _align_linear(atokens, btokens):
        """Map each token of ``atokens`` to the ``btokens`` indices it overlaps.

        Both sequences are treated as segmentations of the same character
        string; an IndexError is raised if ``btokens`` covers fewer characters.
        Each entry of the result is ``list(set(...))`` of ``btokens`` indices.
        """
        a2c = []
        length = 0
        for tok in atokens:
            a2c.append([length + i for i in range(len(tok))])
            length += len(tok)
        c2b = []
        for i, tok in enumerate(btokens):
            c2b.extend([i] * len(tok))
        a2b = []
        for amap in a2c:
            bmap = [c2b[ci] for ci in amap]
            a2b.append(list(set(bmap)))
        return a2b

    def word_starts(self, source):
        """Find word starts in ``source`` using ``zh_tokenizer`` segmentation.

        Returns:
            ``(is_word_start, word_starts)``. ``is_word_start`` is a float
            tensor shaped like ``source``, 1 at every word start (any script),
            0 at the first and last positions. ``word_starts`` is a
            ``(k, 1)`` long tensor of candidate positions: the starts of words
            made only of Chinese characters. When ``mask_whole_word`` is None
            every inner position is a start and a candidate.
        """
        if self.mask_whole_word is None:
            is_word_start = torch.ones(source.size())
            is_word_start[0] = 0
            is_word_start[-1] = 0
            # Same (flags, candidates) shape as the segmentation path below,
            # which is what add_whole_word_mask unpacks.
            return is_word_start, is_word_start.nonzero(as_tuple=False)

        raw_tokens = [self.vocab_id_to_token_dict[i] for i in source.tolist()]
        # Segment everything between the first and last special tokens.
        words = [raw_tokens[0]] + \
            self.zh_tokenizer(''.join(raw_tokens[1:-1]), HMM=True) + [raw_tokens[-1]]
        raw_to_word_align = self._align_linear(raw_tokens, words)

        is_word_start = torch.zeros(source.size())
        word_starts = []
        # True while the current word is not (or no longer) a candidate.
        skip_cur_word = True
        for i in range(1, len(raw_to_word_align)):
            if raw_to_word_align[i - 1] == raw_to_word_align[i]:
                # Not a word start: both tokens align to the same word(s). A
                # non-Chinese continuation disqualifies the current word.
                if not skip_cur_word and not self._is_chinese_char(raw_tokens[i]):
                    word_starts.pop(-1)
                    skip_cur_word = True
                continue
            is_word_start[i] = 1
            if self._is_chinese_char(raw_tokens[i]):
                word_starts.append(i)
                skip_cur_word = False
        is_word_start[0] = 0
        is_word_start[-1] = 0
        word_starts = torch.tensor(word_starts).long().view(-1, 1)
        return is_word_start, word_starts

    def add_whole_word_mask(self, source, p, replace_length=1):
        """Apply whole-word masking to ``source``.

        With ``n`` candidate word starts, ``ceil(n * p)`` words are masked as
        whole words and ``ceil(n * p * 0.1)`` more word starts have only their
        first token masked. If that exceeds ``n``, the candidates widen to every
        flagged word start, and the total is capped at the pool size.

        Arguments:
            source: 1-D token tensor ``[CLS] ... [SEP]``; modified in place.
            p: masking ratio.
            replace_length: 0 deletes the chosen word-start tokens; otherwise
                they become ``mask_id`` (a random id with probability
                ``random_ratio``). For the rest of a whole-masked word, -1
                masks each token and any other value deletes them, so the word
                collapses into a single mask.

        Returns:
            The noised token tensor (may be shorter than ``source``).
        """
        is_word_start, word_starts = self.word_starts(source)
        num_to_mask_word = int(math.ceil(word_starts.size(0) * p))
        num_to_mask_char = int(math.ceil(word_starts.size(0) * p * 0.1))
        num_to_mask = num_to_mask_word + num_to_mask_char
        if num_to_mask > word_starts.size(0):
            word_starts = is_word_start.nonzero(as_tuple=False)
            # The widened pool can still be smaller than the budget (e.g. a
            # single-word sample); clamp it so ``indices`` and ``mask_random``
            # below have the same length.
            num_to_mask = min(num_to_mask, word_starts.size(0))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[
            -1
        ] = 255  # acts as a long length, so spans don't go over the end of doc
        if replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.mask_id
            source[indices[mask_random]] = torch.randint(
                1, self.vocab_size, size=(mask_random.sum(),)
            )

        # Char-level indices are done (only their first token is masked); the
        # loops below extend the first num_to_mask_word indices to whole words.
        indices = indices[:num_to_mask_word]
        mask_random = mask_random[:num_to_mask_word]

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_id
                    source[indices[mask_random]] = torch.randint(
                        1, self.vocab_size, size=(mask_random.sum(),)
                    )
        else:
            # A bit faster when all lengths are 1: walk forward until the next
            # word start (or the 255 sentinel at the end).
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_id
                    source[indices[mask_random]] = torch.randint(
                        1, self.vocab_size, size=(mask_random.sum(),)
                    )

                assert source_length - 1 not in indices

        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_permuted_noise(self, tokens, p):
        """Shuffle ``ceil(len * p)`` inner tokens in place, keeping both ends.

        Not called within this class.
        """
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

    def add_rolling_noise(self, tokens):
        """Rotate the inner tokens by a random offset, keeping both ends in place."""
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

    def add_insertion_noise(self, tokens, p):
        """Insert ``ceil(len * p)`` noise tokens at random inner positions.

        A ``random_ratio`` share of the inserted tokens are random ids in
        ``[1, vocab_size)``; the rest are ``mask_id``.
        """
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_id
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=self.vocab_size, size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result