# titu_lm_55gb_72k / tokenization_bn.py
# Page-header residue kept for provenance: "muntasir2000's picture",
# "Update tokenization_bn.py", 10f8197
"""
Forked from the file src/transformers/models/bert_generation/tokenization_bert_generation.py from the HuggingFace Transformers library.
Permalink: https://github.com/huggingface/transformers/blob/04ab5605fbb4ef207b10bf2772d88c53fc242e83/src/transformers/models/bert_generation/tokenization_bert_generation.py
"""
import os
import sentencepiece as spm
from shutil import copyfile
from transformers import PreTrainedTokenizer
from typing import Any, Dict, List, Optional, Tuple
# Maps the tokenizer's vocabulary-file argument name to the file name that
# `save_vocabulary` writes inside the target directory.
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
class BNTokenizer(PreTrainedTokenizer):
    """
    Construct a BNTokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods.

    Args:
        vocab_file (`str`):
            Path to the [SentencePiece](https://github.com/google/sentencepiece) model file that contains the
            vocabulary necessary to instantiate a tokenizer. It is saved under the name `tokenizer.model`.
        bos_token (`str`, *optional*, defaults to `None`):
            The begin of sequence token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        pad_token (`str`, *optional*, defaults to `"<|reserved001|>"`):
            The token used for padding, for example when batching sequences of different lengths.
        sep_token (`str`, *optional*, defaults to `None`):
            The separator token. Forwarded unchanged to the base class.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other
            things, to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis
                (lattice) using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    prefix_tokens: List[int] = []
    model_input_names = ['input_ids', 'attention_mask']

    def __init__(self, vocab_file, bos_token=None, eos_token='</s>', unk_token='<unk>', pad_token='<|reserved001|>', sep_token=None, sp_model_kwargs: Optional[Dict[str, Any]]=None, **kwargs) -> None:
        """Load the SentencePiece model from `vocab_file` and initialise the base tokenizer."""
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.vocab_file = vocab_file
        # NOTE(review): the processor is loaded before the base constructor runs,
        # presumably because the base class may query the vocabulary while
        # initialising -- confirm against the installed transformers version.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, pad_token=pad_token, sep_token=sep_token, sp_model_kwargs=self.sp_model_kwargs, **kwargs)

    @property
    def vocab_size(self) -> int:
        """Number of pieces in the loaded SentencePiece model."""
        return self.sp_model.get_piece_size()

    def get_vocab(self) -> Dict[str, int]:
        """Return a token-to-id mapping for every SentencePiece id, updated with `added_tokens_encoder`."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self) -> Dict[str, Any]:
        """Return a copy of the instance state with the SentencePiece processor replaced by `None`.

        The processor is rebuilt from `vocab_file` in `__setstate__`.
        """
        state = self.__dict__.copy()
        state['sp_model'] = None
        return state

    def __setstate__(self, d: Dict[str, Any]) -> None:
        """Restore the instance state and reload the SentencePiece processor from `self.vocab_file`."""
        self.__dict__ = d
        # Tolerate state that was produced without `sp_model_kwargs`.
        if not hasattr(self, 'sp_model_kwargs'):
            self.sp_model_kwargs = {}
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        # Same loader call as in `__init__`, for consistency.
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.id_to_piece(index)
        return token

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens (string) in a single string."""
        return self.sp_model.decode(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str]=None) -> Tuple[str]:
        """Write the SentencePiece model file into `save_directory`.

        Args:
            save_directory (`str`):
                Existing directory that receives the vocabulary file.
            filename_prefix (`str`, *optional*):
                If given, the output file is named `<filename_prefix>-tokenizer.model`.

        Returns:
            `Tuple[str]`: A one-element tuple holding the path of the vocabulary file.

        Raises:
            ValueError: If `save_directory` is not a directory.
        """
        if not os.path.isdir(save_directory):
            raise ValueError(f'Vocabulary path ({save_directory}) should be a directory')
        out_vocab_file = os.path.join(save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file'])
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            # The source model file exists at a different path: copy it as-is.
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # The source file is gone: dump the in-memory model instead.
            with open(out_vocab_file, 'wb') as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        # When source and target are the same existing file, nothing is written.
        return (out_vocab_file,)