from datasets import Dataset, load_dataset, concatenate_datasets
import datasets
from transformers import GPT2TokenizerFast
from tokenizers.processors import TemplateProcessing

input_dir = "dataset_location"
tokenizer_file = "path/to/file"
output_dir = "output/dir"

tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_file)

# Add EOS tokens to the tokenization pipeline, as they are not added otherwise.
tokenizer._tokenizer.post_processor = TemplateProcessing(
    single="$0 " + tokenizer.eos_token,
    pair="$A " + tokenizer.eos_token + " $B:1 " + tokenizer.eos_token,
    # Use the tokenizer's actual EOS id rather than a hard-coded value.
    special_tokens=[(tokenizer.eos_token, tokenizer.eos_token_id)],
)


def tokenize_function(examples):
    return tokenizer(examples["text"])


def group_texts(examples):
    # Concatenate all tokenized texts and cut them into blocks of the attention
    # size. Based on the Hugging Face causal LM example.
    block_size = 1024
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_len = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder at the end so every block is exactly block_size long.
    total_len = (total_len // block_size) * block_size
    result = {
        k: [t[i:i + block_size] for i in range(0, total_len, block_size)]
        for k, t in concatenated_examples.items()
    }
    # For causal LM the labels are the inputs; the model shifts them internally.
    result["labels"] = result["input_ids"].copy()
    return result


def main():
    num_proc = 12  # set to something appropriate for your machine

    # Load a saved Dataset object from disk. You could also create a dataset from
    # an iterable, or load one from the Hugging Face Hub, e.g.:
    # dataset = load_dataset("Finnish-NLP/mc4_fi_cleaned", split="train").remove_columns(["timestamp", "url"])
    dataset = datasets.load_from_disk(input_dir)

    # Tokenize, filter out very short texts, group texts into blocks of the
    # attention size, and split off a small test set.
    dataset = dataset\
        .shuffle(seed=42, load_from_cache_file=False, writer_batch_size=100000)\
        .map(tokenize_function, batched=True, num_proc=num_proc,
             remove_columns=dataset.column_names,
             load_from_cache_file=False, writer_batch_size=100000)\
        .filter(lambda e: len(e["input_ids"]) > 20, num_proc=num_proc,
                load_from_cache_file=False, writer_batch_size=100000)\
        .map(group_texts, batched=True, num_proc=num_proc,
             load_from_cache_file=False, writer_batch_size=100000)\
        .train_test_split(test_size=0.05, load_from_cache_file=False,
                          writer_batch_size=100000)

    dataset.save_to_disk(output_dir)
    print(dataset)


if __name__ == "__main__":
    main()
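
# Illustrative sketch, not part of the original script: a quick sanity check of the
# processed output. It assumes the script above has already been run so that
# output_dir contains the saved "train"/"test" splits; the function name and the
# checks are my own additions. It verifies that every block has exactly block_size
# tokens and that labels mirror input_ids, which is what a causal LM expects.
def inspect_output(block_size=1024):
    processed = datasets.load_from_disk(output_dir)  # DatasetDict with "train"/"test"
    print(processed)
    sample = processed["train"][0]
    assert len(sample["input_ids"]) == block_size
    assert sample["labels"] == sample["input_ids"]
    # Decode the start of the first block to eyeball that documents are joined by EOS.
    print(tokenizer.decode(sample["input_ids"][:50]))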