import argparse
import os
import pathlib

import yaml
from dataset import DataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning_module import (
    PretrainLightningModule,
    SSLStepLightningModule,
    SSLDualLightningModule,
)
from utils import configure_args


def get_arg():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", required=True, type=pathlib.Path)
    parser.add_argument(
        "--stage", required=True, type=str, choices=["pretrain", "ssl-step", "ssl-dual"]
    )
    parser.add_argument("--run_name", required=True, type=str)
    parser.add_argument("--corpus_type", default=None, type=str)
    parser.add_argument("--source_path", default=None, type=pathlib.Path)
    parser.add_argument("--aux_path", default=None, type=pathlib.Path)
    parser.add_argument("--preprocessed_path", default=None, type=pathlib.Path)
    parser.add_argument("--n_train", default=None, type=int)
    parser.add_argument("--n_val", default=None, type=int)
    parser.add_argument("--n_test", default=None, type=int)
    parser.add_argument("--epoch", default=None, type=int)
    parser.add_argument("--load_pretrained", action="store_true")
    parser.add_argument("--pretrained_path", default=None, type=pathlib.Path)
    parser.add_argument("--early_stopping", action="store_true")
    parser.add_argument("--alpha", default=None, type=float)
    parser.add_argument("--beta", default=None, type=float)
    parser.add_argument("--learning_rate", default=None, type=float)
    parser.add_argument(
        "--feature_loss_type", default=None, type=str, choices=["mae", "mse"]
    )
    parser.add_argument("--debug", action="store_true")
    return parser.parse_args()


def train(args, config, output_path):
    debug = args.debug

    # Log training metrics both as CSV and to TensorBoard, under the run's output directory.
    csvlogger = CSVLogger(save_dir=str(output_path), name="train_log")
    tblogger = TensorBoardLogger(save_dir=str(output_path), name="tf_log")

    # Save a weights-only checkpoint every epoch and keep all of them (save_top_k=-1).
    checkpoint_callback = ModelCheckpoint(
        dirpath=str(output_path),
        save_weights_only=True,
        save_top_k=-1,
        every_n_epochs=1,
        monitor="val_loss",
    )
    callbacks = [checkpoint_callback]
    if config["train"]["early_stopping"]:
        # Optionally stop training when val_loss has not improved for 15 validation checks.
        earlystop_callback = EarlyStopping(
            monitor="val_loss", min_delta=0.0, patience=15, mode="min"
        )
        callbacks.append(earlystop_callback)

    # Debug mode shortens the run: a single epoch on a small fraction of the data.
    trainer = Trainer(
        max_epochs=1 if debug else config["train"]["epoch"],
        gpus=-1,
        deterministic=False,
        auto_select_gpus=True,
        benchmark=True,
        default_root_dir=os.getcwd(),
        limit_train_batches=0.01 if debug else 1.0,
        limit_val_batches=0.5 if debug else 1.0,
        callbacks=callbacks,
        logger=[csvlogger, tblogger],
        gradient_clip_val=config["train"]["grad_clip_thresh"],
        flush_logs_every_n_steps=config["train"]["logger_step"],
        val_check_interval=0.5,
    )

    # Select the LightningModule matching the requested training stage.
    if config["general"]["stage"] == "pretrain":
        model = PretrainLightningModule(config)
    elif config["general"]["stage"] == "ssl-step":
        model = SSLStepLightningModule(config)
    elif config["general"]["stage"] == "ssl-dual":
        model = SSLDualLightningModule(config)
    else:
        raise NotImplementedError()

    datamodule = DataModule(config)

    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    args = get_arg()

    config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader)

    # Outputs (logs, checkpoints) go to <output_path>/<run_name> from the config.
    output_path = pathlib.Path(config["general"]["output_path"]) / args.run_name
    os.makedirs(output_path, exist_ok=True)

    # Merge command-line overrides with the YAML config.
    config, args = configure_args(config, args)

    train(args=args, config=config, output_path=output_path)
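

# Example invocation (sketch only; the config filename and run name below are
# illustrative, not paths shipped with the repository):
#
#   python train.py --config_path configs/pretrain.yaml --stage pretrain --run_name my_run
#
# Add --debug for a quick smoke test (1 epoch on a small fraction of the data).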