# Bart-gen-arg / train.py
# Commit 4bb803b: "Add application file"
import argparse
import logging
import os
import random
import timeit
from datetime import datetime
import torch
#import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
from src.genie.data_module4 import RAMSDataModule
from src.genie.ACE_data_module import ACEDataModule
from src.genie.KAIROS_data_module import KAIROSDataModule
from src.genie.model import GenIEModel
# Module-level logger named after this module; main() configures the root
# logging handler/format before the first message is emitted.
logger = logging.getLogger(__name__)
def _build_parser():
    """Create the command-line parser for a training / evaluation run.

    Returns:
        argparse.ArgumentParser: parser declaring every option read by
        ``main`` (and handed on, via the parsed namespace, to the model and
        data-module constructors).
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=['gen', 'constrained-gen'],
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=['RAMS', 'ACE', 'KAIROS'],
    )

    # Paths and checkpoints
    parser.add_argument('--tmp_dir', type=str)
    parser.add_argument(
        "--ckpt_name",
        default=None,
        type=str,
        help="The output directory where the model checkpoints and predictions will be written.",
    )
    parser.add_argument(
        "--load_ckpt",
        default=None,
        type=str,
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        # The two sentences used to be concatenated without a separator.
        help="The input training file. If a data dir is specified, will look for the file there. "
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--val_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there. "
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        '--test_file',
        type=str,
        default=None,
    )
    parser.add_argument('--input_dir', type=str, default=None)
    parser.add_argument('--coref_dir', type=str, default='data/kairos/coref_outputs')

    # Data / generation switches
    parser.add_argument('--use_info', action='store_true', default=False,
                        help='use informative mentions instead of the nearest mention.')
    parser.add_argument('--mark_trigger', action='store_true')
    # Stored on the namespace as ``sample_gen`` (argparse maps '-' to '_').
    parser.add_argument('--sample-gen', action='store_true', help='Do sampling when generation.')

    # Optimisation hyper-parameters
    parser.add_argument("--train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--eval_only", action="store_true")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument(
        "--accumulate_grad_batches",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--gradient_clip_val", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")

    # Runtime
    # No ``type=`` on purpose: a value given on the command line stays a string
    # and is forwarded to the Trainer unchanged.
    parser.add_argument("--gpus", default=-1, help='-1 means train on all gpus')
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument("--threads", type=int, default=1,
                        help="multiple threads for converting example to features")
    return parser


def _trainer_kwargs(args):
    """Derive the Trainer keyword arguments that depend on the parsed options.

    Args:
        args: parsed namespace; reads ``gpus``, ``accumulate_grad_batches``,
            ``gradient_clip_val``, ``fp16``, ``max_steps`` and
            ``num_train_epochs``.

    Returns:
        dict: keyword arguments for ``Trainer``. When ``args.max_steps`` is
        positive the dict carries ``max_steps`` and leaves the epoch limits to
        the Trainer defaults, so the step budget really overrides
        ``num_train_epochs`` as the ``--max_steps`` help text promises.
        Otherwise ``min_epochs`` and ``max_epochs`` are both pinned to
        ``num_train_epochs``.
    """
    kwargs = {
        'gpus': args.gpus,
        'accumulate_grad_batches': args.accumulate_grad_batches,
        'gradient_clip_val': args.gradient_clip_val,
        'num_sanity_val_steps': 0,
        # A float is a fraction of a training epoch: validate twice per epoch.
        'val_check_interval': 0.5,
        'precision': 16 if args.fp16 else 32,
    }
    if args.max_steps > 0:
        # Only forwarded when positive, so a non-positive value is never
        # handed to the Trainer as a step limit.
        kwargs['max_steps'] = args.max_steps
    else:
        kwargs['min_epochs'] = args.num_train_epochs
        kwargs['max_epochs'] = args.num_train_epochs
    return kwargs


def main():
    """Parse the command line, then train or evaluate a GenIE model.

    Side effects: configures root logging, seeds the random number generators
    via ``seed_everything``, adds ``ckpt_name`` (when not given), ``ckpt_dir``
    and, for epoch-limited runs, ``max_epochs`` / ``min_epochs`` to the parsed
    namespace, and runs ``trainer.test`` (with ``--eval_only``) or
    ``trainer.fit``.
    """
    args = _build_parser().parse_args()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Set seed
    seed_everything(args.seed)
    logger.info("Training/evaluation parameters %s", args)

    if not args.ckpt_name:
        # Default run name: model, effective batch size, learning rate, timestamp.
        time_str = datetime.now().strftime('%m-%dT%H%M')
        args.ckpt_name = '{}_{}lr{}_{}'.format(
            args.model,
            args.train_batch_size * args.accumulate_grad_batches,
            args.learning_rate,
            time_str,
        )
    args.ckpt_dir = f'./checkpoints/{args.ckpt_name}'

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.ckpt_dir,
        save_top_k=2,
        monitor='val/loss',
        mode='min',
        save_weights_only=True,
        filename='{epoch}',  # this cannot contain slashes
    )
    lr_logger = LearningRateMonitor()
    tb_logger = TensorBoardLogger('logs/')

    # The model is built before the epoch attributes below are attached to
    # ``args``; that ordering is kept from the original script.
    model = GenIEModel(args)

    # ``--dataset`` is restricted by ``choices``, so the lookup cannot miss.
    data_modules = {
        'RAMS': RAMSDataModule,
        'ACE': ACEDataModule,
        'KAIROS': KAIROSDataModule,
    }
    dm = data_modules[args.dataset](args)

    if args.max_steps <= 0:
        args.max_epochs = args.min_epochs = args.num_train_epochs

    # NOTE(review): passing a ModelCheckpoint instance through
    # ``checkpoint_callback=`` appears to be accepted only by older
    # PyTorch Lightning releases — confirm against the pinned version.
    trainer = Trainer(
        logger=tb_logger,
        checkpoint_callback=checkpoint_callback,
        callbacks=[lr_logger],
        **_trainer_kwargs(args),
    )

    if args.load_ckpt:
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.load_ckpt, map_location=model.device)
        model.load_state_dict(checkpoint['state_dict'])

    if args.eval_only:
        logger.info("Running evaluation only")
        dm.setup('test')
        trainer.test(model, datamodule=dm)  # also loads training dataloader
    else:
        logger.info("Running training")
        dm.setup('fit')
        trainer.fit(model, dm)
# Run the training / evaluation entry point only when this file is executed
# as a script, not when it is imported.
if __name__ == "__main__":
    main()