deepkyu's picture
initial commit
1ba3df3
raw
history blame
No virus
2.63 kB
import argparse
import glob
from pathlib import Path
from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from lightning import FontLightningModule
from utils import save_files
def load_configuration(path_config):
    """Load the setting file and merge the hyperparameter configs it references.

    The setting file at ``path_config`` must contain a ``config`` section whose
    ``dataset``, ``model`` and ``logging`` entries are paths to config files.
    These are loaded and merged in that order (later files override keys of
    earlier ones) into ``hp``.

    If ``config.lightning`` is set, that file is loaded too and ``(hp, pl_config)``
    is returned, where ``pl_config`` is the file's top-level ``pl_config`` node
    when present, otherwise the whole file. If it is not set, only ``hp`` is
    returned (not a tuple), so callers that unpack two values need the entry.

    Args:
        path_config: Path of the top-level setting file.

    Returns:
        ``(hp, pl_config)`` with a lightning entry, otherwise ``hp``.
    """
    setting = OmegaConf.load(path_config)

    # load hyperparameter (dataset, then model, then logging; later ones win)
    hp = OmegaConf.load(setting.config.dataset)
    hp = OmegaConf.merge(hp, OmegaConf.load(setting.config.model))
    hp = OmegaConf.merge(hp, OmegaConf.load(setting.config.logging))

    # with lightning setting
    # Use .get()/`in` rather than hasattr(): on some OmegaConf versions attribute
    # access to a missing key returns None, which makes hasattr() always True.
    lightning_path = setting.config.get('lightning', None)
    if lightning_path is not None:
        pl_config = OmegaConf.load(lightning_path)
        if 'pl_config' in pl_config:
            return hp, pl_config.pl_config
        return hp, pl_config

    # without lightning setting
    return hp
def parse_args(argv=None):
    """Parse the training script's command-line arguments.

    Args:
        argv: List of argument strings to parse (without the program name).
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            matching the previous no-argument behavior.

    Returns:
        argparse.Namespace with ``config`` (setting file path), ``gpus``
        (GPU id string, default ``'0,1'``) and ``resume_checkpoint_path``
        (``None`` unless given).
    """
    parser = argparse.ArgumentParser(description='Code to train font style transfer')
    parser.add_argument("--config", type=str, default="./config/setting.yaml",
                        help="Config file for training")
    parser.add_argument('-g', '--gpus', type=str, default='0,1',
                        help="Comma-separated GPU ids to use (e.g. '0,1,2,3'). Defaults to '0,1'.")
    parser.add_argument('-p', '--resume_checkpoint_path', type=str, default=None,
                        help="path of checkpoint for resuming")
    return parser.parse_args(argv)
def main():
    """Train the font style-transfer model.

    Parses CLI arguments and loads ``hp`` / ``pl_config`` via
    ``load_configuration``. It then builds the ``FontLightningModule`` and copies
    the files matched by the glob patterns in ``hp.logging.savefiles`` into the
    log directory. Finally it creates the TensorBoard logger and checkpoint
    callback, and runs ``Trainer.fit``.

    The setting file must reference a lightning config: the two-value unpacking
    below needs it, and ``pl_config.checkpoint.callback`` / ``pl_config.trainer``
    are read from it.
    """
    args = parse_args()
    hp, pl_config = load_configuration(args.config)
    logging_dir = Path(hp.logging.log_dir)

    # call lightning module
    font_pl = FontLightningModule(hp)

    # set logging
    # NOTE(review): hp is modified *after* the module was constructed; the change
    # only reaches the module if it keeps a reference to hp — confirm before reordering.
    tensorboard_dir = logging_dir / 'tensorboard'
    hp.logging['log_dir'] = tensorboard_dir
    savefiles = [path for pattern in hp.logging.savefiles for path in glob.glob(pattern)]
    # parents=True: log_dir itself may not exist yet on a fresh run.
    tensorboard_dir.mkdir(parents=True, exist_ok=True)
    save_files(str(logging_dir), savefiles)

    # set tensorboard logger
    # NOTE(review): the logger writes under logging_dir/<seed>, not tensorboard_dir.
    logger = TensorBoardLogger(str(logging_dir), name=str(hp.logging.seed))

    # set checkpoint callback
    weights_save_path = logging_dir / 'checkpoint' / str(hp.logging.seed)
    # parents=True: 'checkpoint/' does not exist yet on a fresh run, and plain
    # mkdir would raise FileNotFoundError.
    weights_save_path.mkdir(parents=True, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        dirpath=str(weights_save_path),
        **pl_config.checkpoint.callback
    )

    # set lightning trainer
    trainer_kwargs = dict(pl_config.trainer)
    if args.resume_checkpoint_path is not None:
        # Honor -p/--resume_checkpoint_path; the CLI value overrides the config.
        # `gpus=` below implies PyTorch Lightning 1.x, whose Trainer accepts this kwarg.
        trainer_kwargs['resume_from_checkpoint'] = args.resume_checkpoint_path
    trainer = pl.Trainer(
        logger=logger,
        gpus=-1 if args.gpus is None else args.gpus,
        callbacks=[checkpoint_callback],
        **trainer_kwargs
    )

    # let's train
    trainer.fit(font_pl)
# Script entry point: start training only when run directly, not when imported.
if __name__ == "__main__":
    main()