cognitivess / cognitivess_model /tokenization_cognitivess.py
cognitivess's picture
Update cognitivess_model/tokenization_cognitivess.py
cab70b7 verified
from transformers import PreTrainedTokenizer
import json
import os
class CognitivessTokenizer(PreTrainedTokenizer):
    """Whitespace tokenizer backed by a JSON vocabulary and a BPE merges file.

    Text is split on whitespace (``_tokenize``) and each piece is looked up in
    the vocabulary, falling back to the unknown token's id. The merges file is
    parsed into ``bpe_ranks`` and written back by ``save_vocabulary``, but no
    BPE merging is applied during tokenization.
    """

    # Maps ``__init__`` argument names to on-disk file names; the base class's
    # ``from_pretrained`` uses this to locate the files for non-local checkpoints.
    vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}

    def __init__(self, vocab_file, merges_file, **kwargs):
        """Load the vocabulary and merges, then initialise the base tokenizer.

        Args:
            vocab_file: Path to a JSON file mapping token strings to ids.
            merges_file: Path to a UTF-8 text file with one space-separated
                merge pair per line, optionally preceded by a ``#version``
                header line.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``
                (special tokens and other tokenizer options).
        """
        self.vocab_file = vocab_file
        self.merges_file = merges_file
        # The vocabulary must exist before the base initialiser runs: recent
        # transformers versions query get_vocab()/vocab_size while registering
        # special tokens inside PreTrainedTokenizer.__init__.
        self.encoder = self.load_vocab(vocab_file)
        self.decoder = {v: k for k, v in self.encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_merges = self._parse_merges(merges_handle.read())
        # Rank = position in the merges file (lower rank = earlier merge).
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        super().__init__(**kwargs)

    @staticmethod
    def _parse_merges(bpe_data):
        """Return the merge pairs in *bpe_data* as tuples, in file order.

        A leading ``#version`` header line and blank lines (including the one
        produced by a trailing newline) are ignored.
        """
        lines = bpe_data.split("\n")
        if lines and lines[0].startswith("#version"):
            lines = lines[1:]
        return [tuple(line.split()) for line in lines if line.strip()]

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        """Instantiate the tokenizer from a local directory or a model id.

        A local directory is read directly from its ``vocab.json`` and
        ``merges.txt``; any other value is delegated to the base
        implementation, which resolves the files named in
        ``vocab_files_names``.

        Args:
            pretrained_model_name_or_path: Directory path or model identifier.
            *init_inputs: Forwarded to the base implementation only.
            **kwargs: Passed to ``__init__`` (local directory) or to the base
                implementation.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
            merges_file = os.path.join(pretrained_model_name_or_path, "merges.txt")
            return cls(vocab_file, merges_file, **kwargs)
        return super().from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary loaded from ``vocab_file``."""
        return len(self.encoder)

    def get_vocab(self):
        """Return a copy of the token-to-id mapping loaded from ``vocab_file``."""
        return dict(self.encoder)

    def _tokenize(self, text):
        """Split *text* on runs of whitespace."""
        return text.split()

    def _convert_token_to_id(self, token):
        """Map *token* to its id, falling back to the unknown token's id.

        Returns None if neither *token* nor ``unk_token`` is in the vocabulary.
        """
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Map *index* back to its token string, or ``unk_token`` if unknown."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Join *tokens* with single spaces."""
        return " ".join(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the vocabulary and merges files into *save_directory*.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional prefix joined to each file name with
                ``-``. ``PreTrainedTokenizerBase.save_pretrained`` passes this
                keyword.

        Returns:
            Tuple ``(vocab_file, merges_file)`` with the written paths.
        """
        os.makedirs(save_directory, exist_ok=True)
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        merges_file = os.path.join(save_directory, prefix + "merges.txt")
        with open(vocab_file, "w", encoding="utf-8") as vocab_handle:
            json.dump(self.encoder, vocab_handle, ensure_ascii=False)
        # Header line plus one newline-terminated pair per line, so the file
        # parses back to the same merges here and in GPT-2-style readers that
        # drop the first and last line.
        with open(merges_file, "w", encoding="utf-8") as merges_handle:
            merges_handle.write("#version: 0.2\n")
            for pair in sorted(self.bpe_ranks, key=self.bpe_ranks.get):
                merges_handle.write(" ".join(pair) + "\n")
        return (vocab_file, merges_file)

    def _bos_eos_ids(self):
        """Return ``([bos_id], [eos_id])``, using ``[]`` for an unset token."""
        bos = [] if self.bos_token_id is None else [self.bos_token_id]
        eos = [] if self.eos_token_id is None else [self.eos_token_id]
        return bos, eos

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Add BOS/EOS around each sequence.

        Produces ``bos A eos`` for a single sequence and
        ``bos A eos bos B eos`` for a pair. A special token whose id is None
        (not configured) is omitted.
        """
        bos, eos = self._bos_eos_ids()
        output = bos + token_ids_0 + eos
        if token_ids_1 is not None:
            output = output + bos + token_ids_1 + eos
        return output

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """Return 1 for special-token positions and 0 elsewhere.

        When special tokens have not been added yet, the mask is laid out
        exactly like ``build_inputs_with_special_tokens`` so the lengths agree.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        bos, eos = self._bos_eos_ids()
        bos_mask, eos_mask = [1] * len(bos), [1] * len(eos)
        mask = bos_mask + [0] * len(token_ids_0) + eos_mask
        if token_ids_1 is not None:
            mask = mask + bos_mask + [0] * len(token_ids_1) + eos_mask
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """Return segment ids: 0 for the first sequence, 1 for the second.

        Each segment includes its own BOS/EOS, matching the layout of
        ``build_inputs_with_special_tokens``.
        """
        bos, eos = self._bos_eos_ids()
        type_ids = [0] * (len(bos) + len(token_ids_0) + len(eos))
        if token_ids_1 is not None:
            type_ids = type_ids + [1] * (len(bos) + len(token_ids_1) + len(eos))
        return type_ids

    def load_vocab(self, vocab_file):
        """Read *vocab_file* as UTF-8 JSON and return the parsed mapping."""
        with open(vocab_file, "r", encoding="utf-8") as f:
            return json.load(f)