# LongReward-llama3.1-8b-SFT / tiktoken_tokenizer.py
# NeoZ123's picture
# Upload 13 files
# dd80640 verified
import regex as re
import base64
import tiktoken
import os
import json
from transformers import PreTrainedTokenizer
class BaseTokenizer(PreTrainedTokenizer):
    """Abstract base class for the tokenizers in this file.

    Subclasses must override ``vocab_size``, ``tokenize``, ``detokenize``,
    ``eos_id`` and ``get_command``. The chat helpers below additionally rely
    on ``self.tokenizer`` exposing an ``encode`` method; this base class does
    not set that attribute itself.
    """

    def __init__(self, **kwargs):
        # NOTE(review): ``kwargs`` is accepted but not forwarded to the parent
        # initializer. Kept as in the original; confirm this is intended.
        super().__init__()

    @property
    def add_prefix_space(self):
        """Always ``False``."""
        return False

    @property
    def vocab_size(self):
        """Vocabulary size. Must be provided by subclasses."""
        raise NotImplementedError

    def tokenize(self, text):
        """Tokenize ``text``. Must be provided by subclasses."""
        raise NotImplementedError

    def detokenize(self, token_ids, ignore_special_tokens=True):
        """Turn ``token_ids`` back into text. Must be provided by subclasses."""
        raise NotImplementedError

    def build_single_message(self, role, metadata, message):
        """Encode one chat message as a flat list of token ids.

        The result is the role command token, followed by the encoding of
        ``f"{metadata}\\n"``, followed by the encoding of ``message`` (encoded
        with ``disallowed_special=()``).

        Raises:
            AssertionError: if ``role`` is not one of ``"system"``, ``"user"``,
                ``"assistant"`` or ``"observation"``.
        """
        assert role in ["system", "user", "assistant", "observation"], role
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        message_tokens = self.tokenizer.encode(message, disallowed_special=())
        return role_tokens + message_tokens

    def build_chat_input(self, query, history=None, role="user", metadata=""):
        """Build model input for ``query`` preceded by ``history``.

        Args:
            query: Text of the new message.
            history: Optional list of dicts. Each item must have ``"role"`` and
                ``"content"`` keys and may have ``"metadata"``. A ``"system"``
                item that has a ``"tools"`` key gets the JSON dump of those
                tools appended to its content on a new line.
            role: Role used for ``query``.
            metadata: Metadata string used for ``query``.

        Returns:
            Whatever ``self.batch_encode_plus`` returns for the single
            concatenated id sequence (called with ``return_tensors="pt"`` and
            ``is_split_into_words=True``).
        """
        if history is None:
            history = []

        input_ids = []
        for item in history:
            content = item["content"]
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))

        input_ids.extend(self.build_single_message(role, metadata, query))
        # Finish with the assistant command token after the new message.
        input_ids.append(self.get_command("<|assistant|>"))
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    @property
    def eos_id(self):
        """End-of-sequence id. Must be provided by subclasses."""
        raise NotImplementedError

    def get_command(self, token):
        """Return the id of the command token ``token``.

        Must be provided by subclasses. This raises rather than returning the
        ``NotImplemented`` sentinel, which would otherwise end up inside the
        id lists built by :meth:`build_single_message`.
        """
        raise NotImplementedError
class TikTokenizer(BaseTokenizer):
    """Tokenizer backed by a ``tiktoken.Encoding`` built from a rank file.

    Each non-empty line of the vocab file holds a base64-encoded token and its
    integer rank, separated by whitespace. Ten special tokens are assigned the
    ids immediately following the number of loaded ranks.

    Token "strings" for ordinary vocabulary entries are the ``str()`` form of
    the raw bytes (for example ``"b'hello'"``); ``convert_token_to_id`` and
    ``convert_id_to_token`` both use that same representation.
    """

    vocab_files_names = {"vocab_file": "tokenizer.tiktoken"}

    def __init__(self, vocab_file, **kwargs):
        """Load ``vocab_file`` and build the underlying ``tiktoken.Encoding``.

        Args:
            vocab_file: Path to the rank file described in the class docstring.
            **kwargs: Accepted but not used.
                NOTE(review): these are not forwarded to the parent
                initializer. Kept as in the original; confirm this is intended.
        """
        pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
        self.pat_str = re.compile(pat_str)

        self.b64_vocab = {}
        mergeable_ranks = {}
        # Base64 text and decimal ranks are plain ASCII, so an explicit
        # encoding keeps loading independent of the platform default.
        with open(vocab_file, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline block).
                    continue
                token, rank = line.split()
                rank = int(rank)
                token = base64.b64decode(token)
                mergeable_ranks[token] = rank
                # Keyed by the str() form of the bytes, matching what
                # convert_id_to_token returns.
                self.b64_vocab[str(token)] = rank

        special_token_names = [
            "<|endoftext|>", "[MASK]", "[gMASK]", "[sMASK]", "<sop>", "<eop>",
            "<|system|>", "<|user|>", "<|assistant|>", "<|observation|>",
        ]
        # Special ids start right after the number of loaded ranks.
        self.special_tokens = {
            token: idx
            for idx, token in enumerate(special_token_names, start=len(mergeable_ranks))
        }
        self.special_token_ids = {idx: token for token, idx in self.special_tokens.items()}

        self.tokenizer = tiktoken.Encoding(
            name="my_tokenizer",
            pat_str=pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        self.decoder = {rank: token for token, rank in mergeable_ranks.items()}
        self.n_words = len(self.decoder) + len(self.special_tokens)

        # The parent initializer runs last, after every attribute above exists
        # (presumably because it may query the vocabulary; order kept as in
        # the original).
        super().__init__()

    @property
    def add_prefix_space(self):
        """Always ``False``."""
        return False

    def tokenize(self, text, add_special_tokens=True):
        """Encode ``text`` and return the token string for each resulting id."""
        ids = self.encode(text, add_special_tokens=add_special_tokens)
        return [self.convert_id_to_token(_id) for _id in ids]

    def detokenize(self, ids, ignore_special_tokens=True):
        """Decode ``ids`` to text, dropping special-token ids when requested."""
        if ignore_special_tokens:
            ids = [idx for idx in ids if idx not in self.special_token_ids]
        return self.tokenizer.decode(ids)

    def encode(self, text, add_special_tokens=True):
        """Encode ``text`` into a list of ids.

        Special-token strings found in ``text`` are encoded as their special
        ids (``allowed_special="all"``). When ``add_special_tokens`` is true,
        the ids of ``[gMASK]`` and ``<sop>`` are prepended.
        """
        ids = self.tokenizer.encode(text, disallowed_special=(), allowed_special="all")
        if add_special_tokens:
            ids = [self.special_tokens["[gMASK]"], self.special_tokens["<sop>"]] + ids
        return ids

    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
        """Decode a single id or a sequence of ids to text.

        ``clean_up_tokenization_spaces`` is accepted for signature
        compatibility and is not used.
        """
        if isinstance(ids, int):
            ids = [ids]
        return self.detokenize(ids, ignore_special_tokens=skip_special_tokens)

    def encode_pieces(self, text):
        """Encode ``text`` and return each token's bytes decoded as UTF-8.

        Bytes that are not valid UTF-8 on their own are replaced
        (``errors='replace'``).
        """
        ids = self.tokenizer.encode(text, disallowed_special=())
        # The decoder values are ``bytes``; ``bytes.decode`` is the method that
        # turns them into ``str``.
        return [self.decoder[x].decode('utf-8', errors='replace') for x in ids]

    @property
    def vocab_size(self):
        """Number of ordinary tokens plus special tokens."""
        return self.n_words

    @property
    def eos_token_id(self):
        """Id of the ``<|endoftext|>`` special token."""
        return self.special_tokens["<|endoftext|>"]

    def convert_token_to_id(self, token):
        """Convert a token string to its id.

        Special tokens are looked up first, then the ``str(bytes)`` vocabulary.

        Raises:
            RuntimeError: if ``token`` is in neither table.
        """
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.b64_vocab:
            return self.b64_vocab[token]
        raise RuntimeError(f"{token} is not a single token")

    def _convert_token_to_id(self, token):
        return self.convert_token_to_id(token)

    def convert_id_to_token(self, index):
        """Convert an id to its token string.

        Special ids map to their literal token text; any other id maps to the
        ``str()`` form of the token's raw bytes.
        """
        if index in self.special_token_ids:
            return self.special_token_ids[index]
        return str(self.decoder[index])

    def _convert_id_to_token(self, index):
        return self.convert_id_to_token(index)

    def get_command(self, token):
        """Return the id of the special token ``token`` (``KeyError`` if unknown)."""
        return self.special_tokens[token]

    def get_vocab(self):
        """Return a ``{token string: id}`` dict covering ``range(vocab_size)``."""
        return {self._convert_id_to_token(i): i for i in range(self.vocab_size)}