# Geneformer pretraining script: pretrains a BERT-style masked language model
# from scratch on a tokenized single-cell transcriptome corpus.
import datetime

# imports
import os

# environment settings: verbose NCCL logging, CUDA-aware OpenMPI,
# and a glibc version override for conda
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["OMPI_MCA_opal_cuda_support"] = "true"
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"

import pickle
import random

import numpy as np
import pytz
import torch
from datasets import load_from_disk
from transformers import BertConfig, BertForMaskedLM, TrainingArguments

from geneformer import GeneformerPretrainer

# set seeds for reproducibility
seed_num = 0
random.seed(seed_num)
np.random.seed(seed_num)
seed_val = 42
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# set local time and output parent directory
timezone = pytz.timezone("US/Eastern")
rootdir = "/parent_output_directory"

# model parameters
model_type = "bert"
# max input size
max_input_size = 2**11  # 2048
# number of layers
num_layers = 6
# number of attention heads
num_attn_heads = 4
# number of embedding dimensions
num_embed_dim = 256
# intermediate size
intermed_size = num_embed_dim * 2
# activation function
activ_fn = "relu"
# initializer range, layer norm epsilon, and dropout probabilities
initializer_range = 0.02
layer_norm_eps = 1e-12
attention_probs_dropout_prob = 0.02
hidden_dropout_prob = 0.02
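
# added sanity check: BERT attention layers require the hidden size to divide
# evenly across attention heads
assert num_embed_dim % num_attn_heads == 0, "hidden size must be divisible by number of heads"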

# training parameters
# total number of examples in the pretraining corpus
num_examples = 27_406_208
# number of GPUs
num_gpus = 12
# per-device batch size for training
geneformer_batch_size = 12
# max learning rate
max_lr = 1e-3
# learning rate schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 10_000
# number of epochs
epochs = 3
# optimizer
optimizer = "adamw"
# weight decay
weight_decay = 0.001
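
# derived for reference (added): effective global batch size, assuming one
# training process per GPU and no gradient accumulation
effective_batch_size = geneformer_batch_size * num_gpus  # 144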

# output directories: timestamped run name encoding the main hyperparameters
current_date = datetime.datetime.now(tz=timezone)
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
run_name = f"{datestamp}_geneformer_30M_L{num_layers}_emb{num_embed_dim}_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
training_output_dir = f"{rootdir}/models/{run_name}/"
logging_dir = f"{rootdir}/runs/{run_name}/"
model_output_dir = os.path.join(training_output_dir, "models/")

# ensure a previously saved model is not overwritten
model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
if os.path.isfile(model_output_file):
    raise Exception("Model already saved to this directory.")

# make training and model output directories
os.makedirs(training_output_dir, exist_ok=True)
os.makedirs(model_output_dir, exist_ok=True)

# load token dictionary mapping gene identifiers to token ids
with open("token_dictionary.pkl", "rb") as fp:
    token_dictionary = pickle.load(fp)
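
# added sanity check: the model config below reads the "<pad>" token id with
# .get(), which would silently yield None if the token were missing
if "<pad>" not in token_dictionary:
    raise KeyError("token dictionary is missing the <pad> token")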

# model configuration
config = {
    "hidden_size": num_embed_dim,
    "num_hidden_layers": num_layers,
    "initializer_range": initializer_range,
    "layer_norm_eps": layer_norm_eps,
    "attention_probs_dropout_prob": attention_probs_dropout_prob,
    "hidden_dropout_prob": hidden_dropout_prob,
    "intermediate_size": intermed_size,
    "hidden_act": activ_fn,
    "max_position_embeddings": max_input_size,
    "model_type": model_type,
    "num_attention_heads": num_attn_heads,
    "pad_token_id": token_dictionary.get("<pad>"),
    "vocab_size": len(token_dictionary),
}

config = BertConfig(**config)
model = BertForMaskedLM(config)
model = model.train()
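
# optional report (added): parameter count as a quick architecture check
num_params = sum(p.numel() for p in model.parameters())
print(f"Initialized BERT MLM with {num_params / 1e6:.1f}M parameters.")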

# training arguments
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": False,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": weight_decay,
    "per_device_train_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "save_strategy": "steps",
    # cast to int: TrainingArguments expects an integer step count
    "save_steps": int(np.floor(num_examples / geneformer_batch_size / 8)),
    "logging_steps": 1000,
    "output_dir": training_output_dir,
    "logging_dir": logging_dir,
}
training_args = TrainingArguments(**training_args)
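
# rough planning figures (added), assuming one process per GPU and no
# gradient accumulation
steps_per_epoch = num_examples // (geneformer_batch_size * num_gpus)
print(f"~{steps_per_epoch:,} optimization steps per epoch; {warmup_steps:,} warmup steps.")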

print("Starting training.")

# define the trainer
trainer = GeneformerPretrainer(
    model=model,
    args=training_args,
    # pretraining corpus, saved to disk in Hugging Face datasets format
    train_dataset=load_from_disk("genecorpus_30M_2048.dataset"),
    # pickled per-example lengths, used to group examples of similar length
    example_lengths_file="genecorpus_30M_2048_lengths.pkl",
    token_dictionary=token_dictionary,
)

# train
trainer.train()

# save final model
trainer.save_model(model_output_dir)
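
# the saved model can later be reloaded for downstream use, e.g. (added usage
# sketch, assuming the directory now contains the saved weights and config):
# from transformers import BertForMaskedLM
# reloaded = BertForMaskedLM.from_pretrained(model_output_dir)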
|