import argparse
import json
import os

from datasets import load_dataset
from tokenizers import SentencePieceBPETokenizer
from transformers import AutoTokenizer, LlamaTokenizerFast


def main(args):
    # Load the dataset for training: `--dataset_name` is treated as a local directory
    # of JSON files, which are read with the `datasets` "json" loader
    if args.dataset_name is not None:
        data_files = os.listdir(args.dataset_name)
        data_files = [os.path.join(args.dataset_name, f) for f in data_files]
        print(len(data_files))
        dataset = load_dataset(
            "json",
            data_files=data_files,
            split=args.dataset_split,
            token=args.hub_token if args.hub_token else None,
        )
        print(dataset)
    else:
        raise ValueError("No dataset name provided or dataset is already tokenized")

    # Remove non-text columns
    dataset = dataset.remove_columns(
        [col for col in dataset.column_names if col != "text"]
    )

    # Select `num_samples` examples from the shuffled dataset
    dataset = dataset.shuffle(seed=args.seed).select(range(args.num_samples))
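    # Note: `--num_samples` must be set to an integer here, since `range(None)` raises
    # a TypeError; `--vocab_size` below is likewise expected to be an integer.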
    # Create a SentencePieceBPETokenizer
    tokenizer = SentencePieceBPETokenizer()

    # Train the SentencePieceBPETokenizer on the dataset
    tokenizer.train_from_iterator(
        iterator=dataset["text"],
        vocab_size=args.vocab_size,
        show_progress=True,
        special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
    )
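    # `dataset["text"]` materializes the whole column in memory. For very large corpora,
    # a batched generator is a common alternative (a sketch, not part of the original
    # script):
    #
    #     def batch_iterator(batch_size=1000):
    #         for i in range(0, len(dataset), batch_size):
    #             yield dataset[i : i + batch_size]["text"]
    #
    #     tokenizer.train_from_iterator(
    #         batch_iterator(),
    #         vocab_size=args.vocab_size,
    #         show_progress=True,
    #         special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
    #     )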
    # Save the trained tokenizer to a tokenizers-format JSON file
    tokenizer.save("new-sentencepiece-tokenizer.json", pretty=True)
    # Load the reference tokenizer (e.g. the original Llama 2 tokenizer)
    if args.reference_tokenizer is not None and args.hub_token is not None:
        reference_tokenizer = AutoTokenizer.from_pretrained(
            args.reference_tokenizer, token=args.hub_token
        )
        # Save it locally so its tokenizer.json can be read back below
        reference_tokenizer.save_pretrained("reference-tokenizer")
    else:
        raise ValueError(
            "No reference tokenizer name or hub token provided. "
            "Try using `--reference_tokenizer 'meta-llama/Llama-2-7b-hf'`"
        )
    # Read the JSON configs of the new tokenizer and the reference tokenizer
    with open("new-sentencepiece-tokenizer.json") as f:
        new_llama_tokenizer_json = json.load(f)
    with open("reference-tokenizer/tokenizer.json") as f:
        reference_tokenizer_json = json.load(f)

    # Copy the reference tokenizer's normalizer, pre-tokenizer, post-processor and
    # decoder into the new tokenizer's config, along with the matching model flags,
    # so the new tokenizer behaves like a Llama tokenizer
    new_llama_tokenizer_json["normalizer"] = reference_tokenizer_json["normalizer"]
    new_llama_tokenizer_json["pre_tokenizer"] = reference_tokenizer_json["pre_tokenizer"]
    new_llama_tokenizer_json["post_processor"] = reference_tokenizer_json["post_processor"]
    new_llama_tokenizer_json["decoder"] = reference_tokenizer_json["decoder"]
    new_llama_tokenizer_json["model"]["fuse_unk"] = reference_tokenizer_json["model"]["fuse_unk"]
    new_llama_tokenizer_json["model"]["byte_fallback"] = reference_tokenizer_json["model"]["byte_fallback"]
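    # `byte_fallback` lets characters outside the vocabulary decompose into raw byte
    # tokens instead of mapping to <unk>, and `fuse_unk` merges consecutive unknown
    # tokens; both flags mirror the reference Llama tokenizer's behavior.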
    # Dump the new tokenizer's config back to disk
    with open("new-sentencepiece-tokenizer.json", "w") as f:
        json.dump(new_llama_tokenizer_json, f, indent=2, ensure_ascii=False)
    # Load the new tokenizer as a LlamaTokenizerFast with Llama-style special tokens
    new_llama_tokenizer = LlamaTokenizerFast(
        tokenizer_file="new-sentencepiece-tokenizer.json",
        name_or_path=args.reference_tokenizer + "-tokenizer",
        unk_token="<unk>",
        unk_token_id=0,
        bos_token="<s>",
        bos_token_id=1,
        eos_token="</s>",
        eos_token_id=2,
        pad_token="<pad>",
        pad_token_id=3,
        padding_side="right",
    )

    # Save the new tokenizer in the Hugging Face transformers format
    new_llama_tokenizer.save_pretrained("new-llama-tokenizer")
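    # save_pretrained() typically writes tokenizer.json, tokenizer_config.json and
    # special_tokens_map.json to the "new-llama-tokenizer" directory.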


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a new Llama tokenizer")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="Path to a directory of JSON files to train the tokenizer on",
    )
    parser.add_argument(
        "--dataset_split",
        type=str,
        default=None,
        help="The split of the dataset to be tokenized",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="Hugging Face Hub token, used to access the dataset and the reference tokenizer",
    )
    parser.add_argument(
        "--reference_tokenizer",
        type=str,
        default=None,
        help="Name of the reference tokenizer whose processing pipeline is copied",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed used when shuffling the dataset",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=None,
        help="Number of samples to use from the dataset",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=None,
        help="Vocabulary size to use for the tokenizer",
    )
    args = parser.parse_args()
    main(args)
# How to run:
# python tokenizer_train.py \
#     --dataset_name /mimir/dataset/delivery/mimir_base/data/ \
#     --dataset_split train \
#     --reference_tokenizer meta-llama/Llama-2-7b-hf \
#     --vocab_size 32768 \
#     --hub_token hf_IIbKlx.... \
#     --num_samples 6000000
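# Once trained, the merged tokenizer can be loaded like any other Hugging Face
# tokenizer (a sketch; the path assumes the default output directory used above):
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("new-llama-tokenizer")
#     print(tok.tokenize("This is a test sentence."))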