"""Train a SentencePiece-BPE tokenizer and package it as a ``LlamaTokenizerFast``.

The vocabulary is learned from a local text/json/csv corpus or a Hugging Face
Hub dataset. The normalizer, pre-tokenizer, post-processor and decoder sections,
plus the BPE model's ``fuse_unk``/``byte_fallback`` flags, are then copied from a
reference tokenizer's ``tokenizer.json`` into the newly trained one.
"""
import argparse
import json
import os
import tempfile
from pathlib import Path

from tqdm import tqdm
from datasets import IterableDataset, load_dataset
from tokenizers import SentencePieceBPETokenizer
from transformers import LlamaTokenizerFast, TrainingArguments, AutoTokenizer

# Special tokens passed to the trainer. They are listed in the order matching
# the ids (0..3) declared when building the LlamaTokenizerFast in `main`.
# NOTE(review): the post_processor copied from a Llama/Mistral-style reference
# usually hard-codes the id of "<s>"; keeping "<s>" at id 1 presumably keeps
# it consistent -- confirm when using a different reference tokenizer.
SPECIAL_TOKENS = ("<unk>", "<s>", "</s>", "<pad>")

# Top-level tokenizer.json sections copied verbatim from the reference.
_REFERENCE_SECTIONS = ("normalizer", "pre_tokenizer", "post_processor", "decoder")
# Keys copied from the reference's "model" section.
_REFERENCE_MODEL_KEYS = ("fuse_unk", "byte_fallback")

# Number of texts handed to the trainer per batch.
_TEXT_BATCH_SIZE = 1000


def _resolve_token(hub_token):
    """Translate the ``--hub_token`` CLI value into a ``token=`` argument.

    Returns ``None`` for a missing/empty value, ``True`` for the string
    ``"true"`` (any case), which asks the Hub client to reuse the locally
    cached login token, and the raw string otherwise.
    """
    if not hub_token:
        return None
    if hub_token.lower() == "true":
        return True
    return hub_token


def _load_dataset(args, token):
    """Load the training corpus described by ``args``.

    With ``--dataset_type`` set, ``--dataset_name`` is a local file or a
    directory whose regular (non-hidden) files are loaded in sorted order.
    Otherwise ``--dataset_name`` is a Hub dataset, loaded in streaming mode.

    Raises:
        ValueError: if a local directory contains no usable files.
    """
    if args.dataset_type:
        source = Path(args.dataset_name)
        if source.is_file():
            data_files = [str(source)]
        else:
            data_files = sorted(
                str(path)
                for path in source.iterdir()
                if path.is_file() and not path.name.startswith(".")
            )
            if not data_files:
                raise ValueError(f"No data files found in {source}")
        print(f"Training on {len(data_files)} files")
        return load_dataset(
            args.dataset_type,
            data_files=data_files,
            split=args.dataset_split,
            token=token,
        )

    dataset = load_dataset(
        args.dataset_name,
        split=args.dataset_split,
        streaming=True,
        token=token,
    )
    print(dataset)
    return dataset


def _prepare_dataset(dataset, seed, num_samples):
    """Keep only the ``text`` column, shuffle, and optionally subsample.

    Works for both map-style ``Dataset`` and streaming ``IterableDataset``.
    A streaming dataset may not know its columns up front (``column_names``
    is ``None``); in that case column filtering is skipped and the ``text``
    key is accessed per example during training.

    Raises:
        ValueError: if the known columns do not include ``text``.
    """
    columns = dataset.column_names
    if columns is not None:
        if "text" not in columns:
            raise ValueError(f"Dataset has no 'text' column (columns: {columns})")
        extra_columns = [col for col in columns if col != "text"]
        if extra_columns:
            dataset = dataset.remove_columns(extra_columns)

    # Randomize docs
    dataset = dataset.shuffle(seed=seed)

    if num_samples:
        if isinstance(dataset, IterableDataset):
            dataset = dataset.take(num_samples)
        else:
            # Clamp so asking for more samples than exist is not an IndexError.
            dataset = dataset.select(range(min(num_samples, len(dataset))))
    return dataset


def _iter_text_batches(dataset, batch_size=_TEXT_BATCH_SIZE):
    """Yield lists of ``text`` values, skipping ``None`` entries.

    Streams the corpus in batches instead of materializing the whole column,
    and works for both streaming and map-style datasets.
    """
    if isinstance(dataset, IterableDataset):
        batch = []
        for example in dataset:
            text = example["text"]
            if text is None:
                continue
            batch.append(text)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch
        return

    for start in range(0, len(dataset), batch_size):
        texts = dataset[start:start + batch_size]["text"]
        batch = [text for text in texts if text is not None]
        if batch:
            yield batch


def _graft_reference_config(new_tokenizer_json, reference_tokenizer_json):
    """Copy the reference's pipeline sections and BPE flags into the new config."""
    for key in _REFERENCE_SECTIONS:
        new_tokenizer_json[key] = reference_tokenizer_json[key]
    for key in _REFERENCE_MODEL_KEYS:
        new_tokenizer_json["model"][key] = reference_tokenizer_json["model"][key]
    # NOTE(review): if the reference enables byte_fallback, byte tokens such as
    # "<0x00>" are not added to the trained vocab here -- confirm the desired
    # fallback behaviour for characters outside the learned vocabulary.
    return new_tokenizer_json


def main(args):
    """Train the tokenizer, graft the reference config, and save to ``args.output``.

    Raises:
        ValueError: if no dataset or no reference tokenizer is given (checked
            before any expensive work), or on dataset problems reported by
            the helpers above.
    """
    if args.dataset_name is None:
        raise ValueError("No dataset name provided or dataset is already tokenized")
    if args.reference_tokenizer is None:
        raise ValueError(
            "No reference tokenizer provided. "
            "Try using `--reference_tokenizer mistralai/Mistral-7B-Instruct-v0.2`"
        )
    token = _resolve_token(args.hub_token)

    # Load the reference first so a bad name/credentials fails before training.
    reference_tokenizer = AutoTokenizer.from_pretrained(
        args.reference_tokenizer, token=token
    )

    dataset = _prepare_dataset(
        _load_dataset(args, token), args.seed, args.num_samples
    )

    # Train the SentencePieceBPETokenizer on the dataset. vocab_size is only
    # passed when given, so the library default applies otherwise.
    tokenizer = SentencePieceBPETokenizer()
    train_kwargs = {} if args.vocab_size is None else {"vocab_size": args.vocab_size}
    tokenizer.train_from_iterator(
        iterator=_iter_text_batches(dataset),
        show_progress=True,
        special_tokens=list(SPECIAL_TOKENS),
        **train_kwargs,
    )

    # The temp directory is owned by the `with`, so both files stay alive until
    # LlamaTokenizerFast has read them and are removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        reference_dir = Path(tmp_dir) / "reference"
        reference_tokenizer.save_pretrained(str(reference_dir))
        with open(reference_dir / "tokenizer.json", encoding="utf-8") as f:
            reference_tokenizer_json = json.load(f)

        new_tokenizer_file = str(Path(tmp_dir) / "tokenizer.json")
        tokenizer.save(new_tokenizer_file, pretty=True)
        with open(new_tokenizer_file, encoding="utf-8") as f:
            new_tokenizer_json = json.load(f)

        new_tokenizer_json = _graft_reference_config(
            new_tokenizer_json, reference_tokenizer_json
        )
        with open(new_tokenizer_file, "w", encoding="utf-8") as f:
            json.dump(new_tokenizer_json, f, indent=2, ensure_ascii=False)

        unk_token, bos_token, eos_token, pad_token = SPECIAL_TOKENS
        new_llama_tokenizer = LlamaTokenizerFast(
            tokenizer_file=new_tokenizer_file,
            name_or_path=args.reference_tokenizer + "-tokenizer",
            unk_token=unk_token,
            unk_token_id=0,
            bos_token=bos_token,
            bos_token_id=1,
            eos_token=eos_token,
            eos_token_id=2,
            pad_token=pad_token,
            pad_token_id=3,
            padding_side="right",
        )

    # Save the new tokenizer
    new_llama_tokenizer.save_pretrained(args.output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a new Llama tokenizer")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="Hub dataset name, or a local file/directory when --dataset_type is set",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default=None,
        help="The type, 'text', 'json', or 'csv'. Leave blank for regular HF datasets",
    )
    parser.add_argument(
        "--dataset_split",
        type=str,
        default="train",
        help="The split of the dataset to be tokenized",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="Hub access token, or 'True' to reuse the cached `huggingface-cli login` token",
    )
    parser.add_argument(
        "--reference_tokenizer",
        type=str,
        default=None,
        help="The name of the reference tokenizer to use",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="set random seed",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=None,
        help="Number of samples to use from the dataset",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=None,
        help="Vocabulary size to use for the tokenizer (library default if omitted)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./",
        help="Output path for the new tokenizer",
    )
    args = parser.parse_args()
    main(args)

# How to run:
# python train_tokenizer.py --dataset_name texts/all.txt --dataset_type text --dataset_split train --reference_tokenizer mistralai/Mistral-7B-Instruct-v0.2 --vocab_size 32768 --hub_token True
# (--hub_token is optional; "True" reuses the cached `huggingface-cli login` token.)