aet_demo / train.py
saeki
fix
7b918f7
import argparse
import os
import pathlib
import yaml
from dataset import DataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning_module import (
PretrainLightningModule,
SSLStepLightningModule,
SSLDualLightningModule,
)
from utils import configure_args
def get_arg():
    """Parse the command-line options for one training run.

    ``--config_path``, ``--stage`` and ``--run_name`` are mandatory. Every
    other option is an optional override: value options default to ``None``
    and the ``store_true`` flags default to ``False``. How these overrides are
    merged into the YAML config is decided by ``configure_args`` (not visible
    here).

    Returns:
        argparse.Namespace: one attribute per option, named after the flag
        without its leading dashes.
    """
    # (flag, add_argument keyword options), registered in exactly this order so
    # the --help listing and the resulting namespace layout stay stable.
    option_table = (
        ("--config_path", {"required": True, "type": pathlib.Path}),
        (
            "--stage",
            {
                "required": True,
                "type": str,
                "choices": ["pretrain", "ssl-step", "ssl-dual"],
            },
        ),
        ("--run_name", {"required": True, "type": str}),
        ("--corpus_type", {"default": None, "type": str}),
        ("--source_path", {"default": None, "type": pathlib.Path}),
        ("--aux_path", {"default": None, "type": pathlib.Path}),
        ("--preprocessed_path", {"default": None, "type": pathlib.Path}),
        ("--n_train", {"default": None, "type": int}),
        ("--n_val", {"default": None, "type": int}),
        ("--n_test", {"default": None, "type": int}),
        ("--epoch", {"default": None, "type": int}),
        ("--load_pretrained", {"action": "store_true"}),
        ("--pretrained_path", {"default": None, "type": pathlib.Path}),
        ("--early_stopping", {"action": "store_true"}),
        ("--alpha", {"default": None, "type": float}),
        ("--beta", {"default": None, "type": float}),
        ("--learning_rate", {"default": None, "type": float}),
        (
            "--feature_loss_type",
            {"default": None, "type": str, "choices": ["mae", "mse"]},
        ),
        ("--debug", {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def train(args, config, output_path):
    """Fit the Lightning module selected by ``config["general"]["stage"]``.

    Args:
        args: parsed CLI namespace; only ``args.debug`` is read here. In debug
            mode the run is limited to one epoch, 1% of the training batches
            and 50% of the validation batches.
        config: nested config mapping. Reads ``config["train"]`` keys
            ``early_stopping``, ``epoch`` (not read in debug mode),
            ``grad_clip_thresh`` and ``logger_step``, plus
            ``config["general"]["stage"]``. The whole mapping is also passed
            to the chosen LightningModule and to ``DataModule``.
        output_path: directory that receives the checkpoints, the CSV logs
            (``train_log``) and the TensorBoard logs (``tf_log``).

    Raises:
        NotImplementedError: if the stage is not "pretrain", "ssl-step" or
            "ssl-dual".
    """
    is_debug = args.debug
    log_root = str(output_path)

    loggers = [
        CSVLogger(save_dir=log_root, name="train_log"),
        TensorBoardLogger(save_dir=log_root, name="tf_log"),
    ]

    # Keep every checkpoint (save_top_k=-1), weights only. "val_loss" is
    # presumably logged by the LightningModule — confirm in lightning_module.
    callbacks = [
        ModelCheckpoint(
            dirpath=log_root,
            save_weights_only=True,
            save_top_k=-1,
            every_n_epochs=1,
            monitor="val_loss",
        )
    ]

    train_cfg = config["train"]
    if train_cfg["early_stopping"]:
        # Stop once val_loss has not decreased for 15 consecutive checks.
        callbacks.append(
            EarlyStopping(monitor="val_loss", min_delta=0.0, patience=15, mode="min")
        )

    max_epochs = 1 if is_debug else train_cfg["epoch"]
    trainer = Trainer(
        max_epochs=max_epochs,
        gpus=-1,  # Lightning 1.x API: -1 selects all available GPUs
        deterministic=False,
        auto_select_gpus=True,
        benchmark=True,
        default_root_dir=os.getcwd(),
        limit_train_batches=0.01 if is_debug else 1.0,
        limit_val_batches=0.5 if is_debug else 1.0,
        callbacks=callbacks,
        logger=loggers,
        gradient_clip_val=train_cfg["grad_clip_thresh"],
        flush_logs_every_n_steps=train_cfg["logger_step"],
        val_check_interval=0.5,  # run validation twice per training epoch
    )

    stage = config["general"]["stage"]
    stage_modules = (
        ("pretrain", PretrainLightningModule),
        ("ssl-step", SSLStepLightningModule),
        ("ssl-dual", SSLDualLightningModule),
    )
    module_cls = next(
        (cls for stage_name, cls in stage_modules if stage == stage_name), None
    )
    if module_cls is None:
        raise NotImplementedError()
    model = module_cls(config)

    datamodule = DataModule(config)
    trainer.fit(model, datamodule=datamodule)
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the YAML config, create the
    # per-run output directory, let configure_args merge the CLI overrides
    # into the config, then train.
    args = get_arg()
    # A context manager closes the config file as soon as it has been parsed.
    # YAML is UTF-8 by default, so decode it explicitly rather than relying
    # on the platform's locale encoding.
    with open(args.config_path, "r", encoding="utf-8") as config_file:
        # NOTE: FullLoader is kept for compatibility with existing configs. It
        # is not intended for untrusted YAML; switch to yaml.safe_load if the
        # configs never use Python-specific tags.
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    output_path = pathlib.Path(config["general"]["output_path"]) / args.run_name
    output_path.mkdir(parents=True, exist_ok=True)
    config, args = configure_args(config, args)
    train(args=args, config=config, output_path=output_path)