|
from transformers import PreTrainedTokenizerFast |
|
import numpy |
|
import torch |
|
|
|
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that post-processes batch encodings for decoder-style use.

    On top of the standard ``PreTrainedTokenizerFast`` batch encoding this
    subclass:

    * drops ``token_type_ids`` (not used by the target model), and
    * removes a trailing EOS token from each sequence — either by trimming the
      last column when *every* sequence ends with EOS, or, in a mixed batch,
      by shifting the affected rows right with a leading zero so all rows keep
      the same length.

    Handles ``input_ids`` delivered as ``torch.Tensor``, ``numpy.ndarray`` or
    plain Python lists.
    """

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch, then strip trailing EOS tokens and token_type_ids.

        Parameters are forwarded unchanged to the parent implementation.
        Returns the parent's ``BatchEncoding`` with ``token_type_ids`` removed
        and ``input_ids`` / ``attention_mask`` rewritten as described on the
        class docstring.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)

        # token_type_ids are meaningless for this model; drop them if present.
        # A plain `del` would raise KeyError when the caller requested the
        # encoding without them (return_token_type_ids=False).
        outputs.pop("token_type_ids", None)

        input_ids = outputs['input_ids']

        # Only rewrite keys that are actually in the encoding; attention_mask
        # is absent when return_attention_mask=False.
        keys_to_fix = [k for k in ('input_ids', 'attention_mask') if k in outputs]

        def ends_with_eos(sequence):
            # An empty sequence trivially does not end with EOS.
            if len(sequence) == 0:
                return False
            return sequence[-1] == self.eos_token_id

        if isinstance(input_ids, torch.Tensor):
            # bool(...) collapses the 0-dim tensor comparison to a Python bool
            # (building torch.tensor from 0-dim tensors is deprecated).
            last_token_is_eos = torch.tensor([
                bool(ends_with_eos(seq)) for seq in input_ids
            ], dtype=torch.bool)

            if last_token_is_eos.all():
                # Every row ends with EOS: trim the last column wholesale.
                for key in keys_to_fix:
                    outputs[key] = outputs[key][..., :-1]
            elif last_token_is_eos.any():
                # Mixed batch: keep the rectangular shape by dropping the EOS
                # and left-padding the affected rows with a zero.
                # NOTE(review): assumes 0 is an acceptable pad value for both
                # input_ids and attention_mask — confirm against the model.
                batch_size = input_ids.shape[0]
                for i in range(batch_size):
                    if last_token_is_eos[i]:
                        for key in keys_to_fix:
                            truncated = outputs[key][i, :-1]
                            outputs[key][i] = torch.cat([
                                torch.zeros_like(truncated[:1]),
                                truncated
                            ])

        elif isinstance(input_ids, numpy.ndarray):
            last_token_is_eos = numpy.array([
                bool(ends_with_eos(seq)) for seq in input_ids
            ], dtype=bool)

            if last_token_is_eos.all():
                # Every row ends with EOS: trim the last column wholesale.
                for key in keys_to_fix:
                    outputs[key] = outputs[key][..., :-1]
            elif last_token_is_eos.any():
                # Mixed batch: same right-shift-with-zero strategy as above.
                batch_size = input_ids.shape[0]
                for i in range(batch_size):
                    if last_token_is_eos[i]:
                        for key in keys_to_fix:
                            truncated = outputs[key][i, :-1]
                            outputs[key][i] = numpy.concatenate([
                                numpy.zeros_like(truncated[:1]),
                                truncated
                            ])

        elif isinstance(input_ids, list):
            last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]

            if all(last_token_is_eos):
                # Every row ends with EOS: drop the last element of each row.
                for key in keys_to_fix:
                    outputs[key] = [sequence[:-1] for sequence in outputs[key]]
            elif any(last_token_is_eos):
                # Mixed batch: prepend 0 to the rows whose EOS was dropped so
                # all rows keep their original length.
                for key in keys_to_fix:
                    outputs[key] = [
                        [0] + sequence[:-1] if is_eos else sequence
                        for sequence, is_eos in zip(outputs[key], last_token_is_eos)
                    ]

        return outputs
|
|
|
|
|
|
|
# Register the custom tokenizer class so it can be resolved via
# AutoTokenizer at load time.
from transformers import AutoTokenizer

# NOTE(review): AutoTokenizer.register's first positional parameter is
# documented as the *config* class; passing the tokenizer class itself here
# is unusual — confirm this matches the transformers version in use.
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)