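"""Train a new SentencePiece BPE tokenizer on a local JSON dataset and package it
as a LlamaTokenizerFast, copying the normalizer, pre-tokenizer, post-processor,
and decoder settings from a reference Llama tokenizer."""
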
import argparse
import json
import os

from datasets import load_dataset
from tokenizers import SentencePieceBPETokenizer
from transformers import AutoTokenizer, LlamaTokenizerFast


def main(args):
    if args.dataset_name is not None:
        # --dataset_name points at a local directory of JSON files.
        data_files = os.listdir(args.dataset_name)
        data_files = [os.path.join(args.dataset_name, f) for f in data_files]
        print(f"Found {len(data_files)} data files")
        dataset = load_dataset(
            "json",
            data_files=data_files,
            split=args.dataset_split,
            token=args.hub_token if args.hub_token else None,
        )
        print(dataset)
    else:
        raise ValueError("No dataset path provided. Pass a directory of JSON files via --dataset_name.")

    # Keep only the raw text column used for tokenizer training.
    dataset = dataset.remove_columns([col for col in dataset.column_names if col != "text"])

    # Optionally shuffle and subsample the dataset for faster training.
    if args.num_samples is not None:
        dataset = dataset.shuffle(seed=args.seed).select(range(args.num_samples))

    tokenizer = SentencePieceBPETokenizer()

    # Train a SentencePiece-style BPE tokenizer. The special-token order fixes
    # their IDs: <unk>=0, <s>=1, </s>=2, <pad>=3.
    tokenizer.train_from_iterator(
        iterator=dataset["text"],
        vocab_size=args.vocab_size,
        show_progress=True,
        special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
    )

    tokenizer.save("new-sentencepiece-tokenizer.json", pretty=True)

    if args.reference_tokenizer is not None and args.hub_token is not None:
        reference_tokenizer = AutoTokenizer.from_pretrained(
            args.reference_tokenizer, token=args.hub_token
        )
        reference_tokenizer.save_pretrained("reference-tokenizer")
    else:
        raise ValueError(
            "No reference tokenizer or hub token provided. "
            "Try `--reference_tokenizer 'meta-llama/Llama-2-7b-hf'` together with a valid `--hub_token`."
        )

    with open("new-sentencepiece-tokenizer.json") as f:
        new_llama_tokenizer_json = json.load(f)

    with open("reference-tokenizer/tokenizer.json") as f:
        reference_tokenizer_json = json.load(f)

    # Copy the normalizer, pre-tokenizer, post-processor, and decoder settings from
    # the reference Llama tokenizer so the new tokenizer behaves the same way
    # outside of its vocabulary and merges.
    new_llama_tokenizer_json["normalizer"] = reference_tokenizer_json["normalizer"]
    new_llama_tokenizer_json["pre_tokenizer"] = reference_tokenizer_json["pre_tokenizer"]
    new_llama_tokenizer_json["post_processor"] = reference_tokenizer_json["post_processor"]
    new_llama_tokenizer_json["decoder"] = reference_tokenizer_json["decoder"]
    new_llama_tokenizer_json["model"]["fuse_unk"] = reference_tokenizer_json["model"]["fuse_unk"]
    new_llama_tokenizer_json["model"]["byte_fallback"] = reference_tokenizer_json["model"]["byte_fallback"]

    with open("new-sentencepiece-tokenizer.json", "w") as f:
        json.dump(new_llama_tokenizer_json, f, indent=2, ensure_ascii=False)

    # Wrap the trained tokenizer as a LlamaTokenizerFast. The special-token IDs are
    # already fixed by the training order above (<unk>=0, <s>=1, </s>=2, <pad>=3),
    # so only the token strings need to be passed here.
    new_llama_tokenizer = LlamaTokenizerFast(
        tokenizer_file="new-sentencepiece-tokenizer.json",
        name_or_path=args.reference_tokenizer + "-tokenizer",
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        padding_side="right",
    )

    new_llama_tokenizer.save_pretrained("new-llama-tokenizer")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a new Llama tokenizer")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="Path to a directory of JSON files to train the tokenizer on",
    )
    parser.add_argument(
        "--dataset_split",
        type=str,
        default=None,
        help="The dataset split to load (e.g. 'train')",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="Hugging Face Hub token used to download the reference tokenizer",
    )
    parser.add_argument(
        "--reference_tokenizer",
        type=str,
        default=None,
        help="Name of the reference tokenizer whose normalizer and pre-tokenizer settings are copied",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed used when shuffling the dataset",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=None,
        help="Number of samples to draw from the dataset (defaults to using all samples)",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=32000,
        help="Vocabulary size of the new tokenizer (32000 matches Llama 2)",
    )
    args = parser.parse_args()
    main(args)
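
# Example invocation (script name, paths, and values below are illustrative):
#   python train_llama_tokenizer.py \
#       --dataset_name ./data/json_shards/ \
#       --dataset_split train \
#       --reference_tokenizer meta-llama/Llama-2-7b-hf \
#       --hub_token <your_hf_token> \
#       --num_samples 1000000 \
#       --vocab_size 32000
#
# The resulting tokenizer can then be loaded with
#   AutoTokenizer.from_pretrained("new-llama-tokenizer")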