import pandas as pd
import os

import torch
import torch.nn as nn

from transformers import GPT2TokenizerFast, GPT2LMHeadModel, AutoModelForCausalLM
from transformers import DataCollatorWithPadding, GPT2Config, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, RobertaTokenizerFast

import datasets
from datasets import disable_caching
disable_caching()
from datasets import IterableDataset

from conditional_gpt2_model import ConditionalGPT2LMHeadModel

ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m"  # encoder model name
TOKENIZER_MAX_LEN = 256   # max_length param on tokenizer
DATA_SUBSHARDS = 10       # number of shards to break each data chunk into
DATA_DIR = None           # directory with saved data shards
TRAINER_SAVE_DIR = None   # directory to save model checkpoints

assert DATA_DIR is not None, "data directory must be specified"
assert TRAINER_SAVE_DIR is not None, "trainer save directory must be specified"


def gen_dataset():
    """Stream training examples from the saved dataset shards in DATA_DIR."""
    data_filenames = sorted([i for i in os.listdir(DATA_DIR) if '.hf' in i])

    for filename in data_filenames:
        dataset = datasets.Dataset.load_from_disk(f'{DATA_DIR}/{filename}')

        # keep only the columns the decoder needs
        keep_cols = ['input_ids', 'encoder_hidden_states']
        dataset = dataset.remove_columns(
            [i for i in dataset.column_names if i not in keep_cols]).with_format("torch")

        # contiguous shards for faster loading
        shards = [dataset.shard(num_shards=DATA_SUBSHARDS, index=index, contiguous=True)
                  for index in range(DATA_SUBSHARDS)]

        for shard in shards:
            for example in shard:
                # need to add unit axis to hidden states: (hidden,) -> (1, hidden)
                example['encoder_hidden_states'] = example['encoder_hidden_states'][None, :]
                yield example


dataset = IterableDataset.from_generator(gen_dataset)
dataset = dataset.with_format("torch")

tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME, max_len=TOKENIZER_MAX_LEN)

# causal-LM collator: pads input_ids and builds labels from them (mlm=False)
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# train from scratch; add_cross_attention=True adds cross-attention blocks so the
# decoder can condition on encoder_hidden_states (feature size must match n_embd)
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=TOKENIZER_MAX_LEN,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    n_layer=6,
    n_head=8,
    add_cross_attention=True,
)

model = ConditionalGPT2LMHeadModel(config)

# alternatively, load a pre-trained model
# commit_hash = '0ba58478f467056fe33003d7d91644ecede695a7'
# model = AutoModelForCausalLM.from_pretrained("entropy/roberta_zinc_decoder",
#                                              trust_remote_code=True, revision=commit_hash)

# change trainer args as needed
args = TrainingArguments(
    output_dir=TRAINER_SAVE_DIR,
    per_device_train_batch_size=192,
    logging_steps=25,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1000,
    lr_scheduler_type="cosine",
    learning_rate=1e-5,
    save_steps=200,
    save_total_limit=30,
    fp16=True,
    push_to_hub=False,
    max_steps=50000,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collator,
    train_dataset=dataset,
)

trainer.train()
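
# For reference: a minimal, hypothetical sketch of how a shard in DATA_DIR with the
# 'input_ids' and 'encoder_hidden_states' columns consumed by gen_dataset() might be
# built. The SMILES source, batch size, and mean-pooling of the encoder's last hidden
# state are assumptions for illustration, not the exact pipeline used for training.
#
# from transformers import AutoModel
#
# def build_data_shard(smiles_list, shard_path, batch_size=256):
#     # hypothetical helper: encode SMILES and save a dataset shard to `shard_path`
#     # (use a name containing '.hf' so gen_dataset() picks it up)
#     encoder = AutoModel.from_pretrained(ENCODER_MODEL_NAME).eval()
#     records = {'input_ids': [], 'encoder_hidden_states': []}
#
#     for start in range(0, len(smiles_list), batch_size):
#         batch = smiles_list[start:start + batch_size]
#         tokens = tokenizer(batch, padding=True, truncation=True,
#                            max_length=TOKENIZER_MAX_LEN, return_tensors='pt')
#
#         with torch.no_grad():
#             hidden = encoder(**tokens).last_hidden_state   # (batch, seq, hidden)
#
#         # mean-pool over non-padding positions -> one vector per molecule
#         mask = tokens['attention_mask'].unsqueeze(-1)
#         pooled = (hidden * mask).sum(1) / mask.sum(1)      # (batch, hidden)
#
#         records['input_ids'] += tokens['input_ids'].tolist()
#         records['encoder_hidden_states'] += pooled.tolist()
#
#     datasets.Dataset.from_dict(records).save_to_disk(shard_path)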