# roberta_zinc_decoder / train_script.py
# (Hugging Face Hub file-viewer header removed; commit 20d0936, 3.39 kB)
import pandas as pd
import os
import torch
import torch.nn as nn
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, AutoModelForCausalLM
from transformers import DataCollatorWithPadding, GPT2Config, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, RobertaTokenizerFast
import datasets
from datasets import disable_caching
disable_caching()
from datasets import IterableDataset
from conditional_gpt2_model import ConditionalGPT2LMHeadModel
# ---- configuration ---------------------------------------------------------
ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m" # encoder model name
TOKENIZER_MAX_LEN = 256 # max_length param on tokenizer
DATA_SUBSHARDS = 10 # number of shards to break each data chunk into
DATA_DIR = None # directory with saved data shards
TRAINER_SAVE_DIR = None # directory to save model checkpoints

# Validate required settings with explicit raises rather than `assert`:
# asserts are stripped when Python runs with -O, silently skipping the check.
if DATA_DIR is None:
    raise ValueError("data directory must be specified")
if TRAINER_SAVE_DIR is None:
    raise ValueError("trainer save directory must be specified")
def gen_dataset():
    """Yield training examples from every saved dataset shard in DATA_DIR.

    Walks the ``.hf`` dataset directories in sorted order, keeps only the
    ``input_ids`` and ``encoder_hidden_states`` columns, splits each dataset
    into ``DATA_SUBSHARDS`` contiguous shards for faster loading, and yields
    one example dict at a time (suitable for ``IterableDataset.from_generator``).
    """
    data_filenames = sorted([i for i in os.listdir(DATA_DIR) if '.hf' in i])
    for filename in data_filenames:
        # Bug fix: interpolate the actual shard directory name. The loop
        # variable was previously unused and a literal placeholder path
        # was passed to load_from_disk, so no shard could ever be loaded.
        dataset = datasets.Dataset.load_from_disk(f'{DATA_DIR}/{filename}')
        keep_cols = ['input_ids', 'encoder_hidden_states']
        dataset = dataset.remove_columns([i for i in dataset.column_names
                                          if i not in keep_cols]).with_format("torch")
        # contiguous shards for faster loading
        shards = [dataset.shard(num_shards=DATA_SUBSHARDS, index=index, contiguous=True)
                  for index in range(DATA_SUBSHARDS)]
        for shard in shards:
            for example in shard:
                # add a unit axis so hidden states match the (1, hidden_dim)
                # layout the decoder's cross-attention expects
                example['encoder_hidden_states'] = example['encoder_hidden_states'][None, :]
                yield example
# Stream examples lazily so the full training set never has to fit in memory.
dataset = IterableDataset.from_generator(gen_dataset)
dataset = dataset.with_format("torch")
# The decoder reuses the encoder's tokenizer so both share one vocabulary.
tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME, max_len=TOKENIZER_MAX_LEN)
# mlm=False -> plain causal language modeling (labels are the shifted input_ids).
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# Train a small decoder from scratch, sized to the encoder's tokenizer.
config = GPT2Config(
vocab_size=len(tokenizer),
n_positions=TOKENIZER_MAX_LEN,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
n_layer=6,
n_head=8,
add_cross_attention=True,  # lets the decoder attend to encoder_hidden_states
)
# Project-local GPT-2 subclass (see conditional_gpt2_model.py).
model = ConditionalGPT2LMHeadModel(config)
# alternatively, load a pre-trained model
# commit_hash = '0ba58478f467056fe33003d7d91644ecede695a7'
# model = AutoModelForCausalLM.from_pretrained("entropy/roberta_zinc_decoder",
# trust_remote_code=True, revision=commit_hash)
# change trainer args as needed
# change trainer args as needed
# NOTE: effective batch size is 192 * 8 (per-device batch * grad accumulation)
# per device; training stops at max_steps=50000 regardless of num_train_epochs.
args = TrainingArguments(
output_dir=TRAINER_SAVE_DIR,
per_device_train_batch_size=192,
logging_steps=25,
gradient_accumulation_steps=8,
num_train_epochs=1,
weight_decay=0.1,
warmup_steps=1000,
lr_scheduler_type="cosine",
learning_rate=1e-5,
save_steps=200,
save_total_limit=30,  # keep at most 30 checkpoints on disk
fp16=True,
push_to_hub=False,
max_steps=50000,
)
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=args,
data_collator=collator,
train_dataset=dataset,
)
# Kick off training; checkpoints land in TRAINER_SAVE_DIR.
trainer.train()