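"""Training/evaluation entry point for the GenIE model ('gen' / 'constrained-gen').

Parses command-line arguments, builds the data module for the chosen dataset
(RAMS, ACE, or KAIROS), and trains or evaluates a GenIEModel with PyTorch
Lightning.

Example invocation (script name and data paths below are illustrative only):

    python train.py --model constrained-gen --dataset RAMS --train_file data/train.jsonlines --val_file data/dev.jsonlines --train_batch_size 4 --num_train_epochs 3 --mark_trigger

"""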
import argparse
import logging
import os
import random
import timeit
from datetime import datetime

import torch
# import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything

from src.genie.data_module4 import RAMSDataModule
from src.genie.ACE_data_module import ACEDataModule
from src.genie.KAIROS_data_module import KAIROSDataModule
from src.genie.model import GenIEModel

logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=['gen', 'constrained-gen'],
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=['RAMS', 'ACE', 'KAIROS'],
    )
    parser.add_argument('--tmp_dir', type=str)
    parser.add_argument(
        "--ckpt_name",
        default=None,
        type=str,
        help="Name of the directory under ./checkpoints/ where model checkpoints will be written.",
    )
    parser.add_argument(
        "--load_ckpt",
        default=None,
        type=str,
        help="Path to a checkpoint whose model weights are loaded before training/evaluation.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file.",
    )
    parser.add_argument(
        "--val_file",
        default=None,
        type=str,
        help="The input evaluation file.",
    )
    parser.add_argument(
        '--test_file',
        type=str,
        default=None,
    )
    parser.add_argument('--input_dir', type=str, default=None)
    parser.add_argument('--coref_dir', type=str, default='data/kairos/coref_outputs')
    parser.add_argument('--use_info', action='store_true', default=False,
                        help='Use informative mentions instead of the nearest mention.')
    parser.add_argument('--mark_trigger', action='store_true')
    parser.add_argument('--sample-gen', action='store_true', help='Use sampling during generation.')
    parser.add_argument("--train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--eval_only", action="store_true",
    )
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") | |
parser.add_argument( | |
"--accumulate_grad_batches", | |
type=int, | |
default=1, | |
help="Number of updates steps to accumulate before performing a backward/update pass.", | |
) | |
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") | |
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") | |
parser.add_argument("--gradient_clip_val", default=1.0, type=float, help="Max gradient norm.") | |
parser.add_argument( | |
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform." | |
) | |
parser.add_argument( | |
"--max_steps", | |
default=-1, | |
type=int, | |
help="If > 0: set total number of training steps to perform. Override num_train_epochs.", | |
) | |
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") | |
parser.add_argument("--gpus", default=-1, help='-1 means train on all gpus') | |
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") | |
parser.add_argument( | |
"--fp16", | |
action="store_true", | |
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", | |
) | |
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features") | |
args = parser.parse_args() | |

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Set seed
    seed_everything(args.seed)

    logger.info("Training/evaluation parameters %s", args)

    if not args.ckpt_name:
        d = datetime.now()
        time_str = d.strftime('%m-%dT%H%M')
        args.ckpt_name = '{}_{}lr{}_{}'.format(args.model, args.train_batch_size * args.accumulate_grad_batches,
                                               args.learning_rate, time_str)

    args.ckpt_dir = os.path.join(f'./checkpoints/{args.ckpt_name}')
    # os.makedirs(args.ckpt_dir)
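
    # Keep the two best checkpoints by validation loss ('val/loss', which the
    # model is expected to log) as weights-only files under args.ckpt_dir.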
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.ckpt_dir,
        save_top_k=2,
        monitor='val/loss',
        mode='min',
        save_weights_only=True,
        filename='{epoch}',  # this cannot contain slashes
    )

    lr_logger = LearningRateMonitor()
    tb_logger = TensorBoardLogger('logs/')

    model = GenIEModel(args)

    if args.dataset == 'RAMS':
        dm = RAMSDataModule(args)
    elif args.dataset == 'ACE':
        dm = ACEDataModule(args)
    elif args.dataset == 'KAIROS':
        dm = KAIROSDataModule(args)

    if args.max_steps < 0:
        args.max_epochs = args.min_epochs = args.num_train_epochs
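
    # Note: max_steps is parsed but not forwarded to the Trainer below, so the
    # run length is governed by num_train_epochs.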
    trainer = Trainer(
        logger=tb_logger,
        min_epochs=args.num_train_epochs,
        max_epochs=args.num_train_epochs,
        gpus=args.gpus,
        checkpoint_callback=checkpoint_callback,
        accumulate_grad_batches=args.accumulate_grad_batches,
        gradient_clip_val=args.gradient_clip_val,
        num_sanity_val_steps=0,
        val_check_interval=0.5,  # a float is a fraction of an epoch: 0.5 validates twice per epoch
        precision=16 if args.fp16 else 32,
        callbacks=[lr_logger],
    )
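
    # Optionally warm-start from an existing checkpoint; only the model weights
    # ('state_dict') are restored, not the optimizer state.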
    if args.load_ckpt:
        model.load_state_dict(torch.load(args.load_ckpt, map_location=model.device)['state_dict'])

    if args.eval_only:
        print(args.eval_only)
        dm.setup('test')
        trainer.test(model, datamodule=dm)  # also loads training dataloader
    else:
        print(args.eval_only)
        dm.setup('fit')
        trainer.fit(model, dm)


if __name__ == "__main__":
    main()