|
from transformers import PreTrainedTokenizer |
|
import json |
|
import os |
|
|
|
class CognitivessTokenizer(PreTrainedTokenizer):
    """Whitespace tokenizer backed by a JSON vocabulary and a merges file.

    Text is split on whitespace and each piece is looked up in the JSON
    vocabulary. The merges file is loaded into ``bpe_ranks`` and written back
    by ``save_vocabulary``, but ``_tokenize`` does not apply the merges.
    """

    # File names written by ``save_vocabulary``; declared so the base class's
    # loading machinery can map them onto the ``__init__`` parameters.
    vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}

    def __init__(self, vocab_file, merges_file, *args, **kwargs):
        """Load the vocabulary and merge table, then initialise the base class.

        Args:
            vocab_file: Path to a JSON file mapping token -> id.
            merges_file: Path to a text file with one space-separated merge
                per line, optionally preceded by a ``#version`` header line.
            *args: Forwarded unchanged to the base-class constructor.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        self.vocab_file = vocab_file
        self.merges_file = merges_file
        self.encoder = self.load_vocab(vocab_file)
        self.decoder = {index: token for token, index in self.encoder.items()}
        self.bpe_ranks = self._load_merges(merges_file)

        # The base constructor may query the vocabulary (e.g. through
        # ``get_vocab``) while registering special tokens, so every lookup
        # table has to exist before it runs.
        super().__init__(*args, **kwargs)

    @staticmethod
    def _load_merges(merges_file):
        """Read ``merges_file`` and return a ``{merge_tuple: rank}`` dict.

        Only a leading ``#version`` header is dropped; blank lines are
        skipped. This keeps the first and last merges of files that have no
        header or no trailing newline.
        """
        with open(merges_file, encoding="utf-8") as merges_handle:
            lines = merges_handle.read().split("\n")

        if lines and lines[0].startswith("#version"):
            lines = lines[1:]

        merges = [tuple(line.split()) for line in lines if line.strip()]
        return {pair: rank for rank, pair in enumerate(merges)}

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary (added tokens excluded)."""
        return len(self.encoder)

    def get_vocab(self):
        """Return a new token -> id dict: base vocabulary plus added tokens."""
        vocab = dict(self.encoder)
        # ``getattr`` keeps this usable even if the base class has not yet
        # set up its added-token table.
        vocab.update(getattr(self, "added_tokens_encoder", {}))
        return vocab

    def _tokenize(self, text):
        """Split ``text`` on runs of whitespace."""
        return text.split()

    def _convert_token_to_id(self, token):
        """Return the id of ``token``, falling back to the id of ``unk_token``.

        The result is ``None`` when neither ``token`` nor ``unk_token`` is in
        the vocabulary.
        """
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Return the token for ``index``, or ``unk_token`` if it is unknown."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Join ``tokens`` with single spaces."""
        return " ".join(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write ``vocab.json`` and ``merges.txt`` into ``save_directory``.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional prefix; when given, files are named
                ``<prefix>-vocab.json`` and ``<prefix>-merges.txt``.

        Returns:
            Tuple ``(vocab_file_path, merges_file_path)``.
        """
        os.makedirs(save_directory, exist_ok=True)

        prefix = (filename_prefix + "-") if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        merges_file = os.path.join(save_directory, prefix + "merges.txt")

        with open(vocab_file, "w", encoding="utf-8") as vocab_handle:
            json.dump(self.encoder, vocab_handle, ensure_ascii=False)

        # Emit merges in rank order with a header line and a trailing newline
        # so the file round-trips through the loader without losing entries.
        ranked_pairs = sorted(self.bpe_ranks.items(), key=lambda item: item[1])
        with open(merges_file, "w", encoding="utf-8") as merges_handle:
            merges_handle.write("#version: 0.2\n")
            merges_handle.writelines(" ".join(pair) + "\n" for pair, _ in ranked_pairs)

        return (vocab_file, merges_file)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap one or two id sequences with BOS/EOS markers.

        A single sequence becomes ``bos A eos``; a pair becomes
        ``bos A eos B eos``. A marker whose id is ``None`` is left out
        instead of being inserted as ``None``.

        Args:
            token_ids_0: List of ids for the first sequence.
            token_ids_1: Optional list of ids for the second sequence.

        Returns:
            A new list of ids.
        """
        bos = [self.bos_token_id] if self.bos_token_id is not None else []
        eos = [self.eos_token_id] if self.eos_token_id is not None else []

        output = bos + token_ids_0 + eos
        if token_ids_1 is not None:
            output = output + token_ids_1 + eos
        return output

    def load_vocab(self, vocab_file):
        """Parse ``vocab_file`` as UTF-8 JSON and return the resulting object."""
        with open(vocab_file, "r", encoding="utf-8") as f:
            return json.load(f)
|
|