heheyas
init
cfb7702
import sys
import argparse
import os
import time
import logging
from datetime import datetime
def _build_arg_parser():
    """Create the command-line parser for this launcher.

    Exactly one of ``--train`` / ``--validate`` / ``--test`` / ``--predict``
    must be given. Unknown arguments are not rejected by this parser;
    ``main`` collects them via ``parse_known_args`` and forwards them to
    ``load_config`` as config overrides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument("--gpu", default="0", help="GPU(s) to be used")
    parser.add_argument(
        "--resume", default=None, help="path to the weights to be resumed"
    )
    parser.add_argument(
        "--resume_weights_only",
        action="store_true",
        help="specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only",
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--predict", action="store_true")
    # group.add_argument('--export', action='store_true') # TODO: a separate export action
    parser.add_argument("--exp_dir", default="./exp")
    parser.add_argument("--runs_dir", default="./runs")
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )
    return parser


def main():
    """Parse the CLI, build data module / system / Trainer, and run one action.

    The action is selected by ``--train`` (fit, then test), ``--validate``,
    ``--test`` or ``--predict``. Invalid command lines (missing/conflicting
    mode flags, an empty ``--gpu`` list, multi-GPU on Windows) terminate via
    ``parser.error`` (``SystemExit`` with status 2) before any heavy import.

    The project modules and ``pytorch_lightning`` are imported only after
    ``CUDA_VISIBLE_DEVICES`` has been set, so the process sees just the GPUs
    chosen with ``--gpu``.
    """
    parser = _build_arg_parser()
    # `extras` holds unrecognised arguments; they become config overrides.
    args, extras = parser.parse_known_args()

    # Accept "0,1", "0, 1" and "0,1," alike. Blank entries used to be counted
    # as devices, making `devices=` disagree with what CUDA actually exposes.
    gpu_ids = [gpu.strip() for gpu in args.gpu.split(",") if gpu.strip()]
    if not gpu_ids:
        parser.error("--gpu must list at least one GPU index, e.g. --gpu 0 or --gpu 0,1")
    n_gpus = len(gpu_ids)

    # Check the platform constraint up front rather than after the slow
    # dataset/system construction. parser.error (unlike `assert`) is not
    # stripped when Python runs with -O.
    if sys.platform == "win32":
        # does not support multi-gpu on windows
        if n_gpus != 1:
            parser.error("multi-GPU is not supported on Windows; pass a single --gpu index")
        strategy = "dp"
    else:
        strategy = "ddp_find_unused_parameters_false"

    # set CUDA_VISIBLE_DEVICES then import pytorch-lightning
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)

    import datasets
    import systems
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
    from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
    from utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
    )
    from utils.misc import load_config

    # parse YAML config to OmegaConf
    config = load_config(args.config, cli_args=extras)
    config.cmd_args = vars(args)
    config.trial_name = config.get("trial_name") or (
        config.tag + datetime.now().strftime("@%Y%m%d-%H%M%S")
    )
    config.exp_dir = config.get("exp_dir") or os.path.join(args.exp_dir, config.name)
    # Each output directory defaults to <exp_dir>/<trial_name>/<sub> unless the
    # config already sets `<sub>_dir` to a truthy value.
    for sub in ("save", "ckpt", "code", "config"):
        key = f"{sub}_dir"
        config[key] = config.get(key) or os.path.join(
            config.exp_dir, config.trial_name, sub
        )

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)
    if args.resume_weights_only and not args.resume:
        # Previously ignored silently; surface the likely user mistake.
        logger.warning(
            "--resume_weights_only has no effect without --resume; no weights will be restored"
        )

    if "seed" not in config:
        # Wall-clock-derived seed in [0, 1000).
        # NOTE(review): if the multi-GPU launcher re-executes this script per
        # rank, each rank would compute its own seed/timestamp here - confirm.
        config.seed = int(time.time() * 1000) % 1000
    pl.seed_everything(config.seed)

    dm = datasets.make(config.dataset.name, config.dataset)
    system = systems.make(
        config.system.name,
        config,
        # Weights-only resume: the checkpoint path goes to the system
        # (presumably restoring weights itself) and is NOT passed to fit().
        load_from_checkpoint=args.resume if args.resume_weights_only else None,
    )

    # Checkpointing, LR monitoring, snapshots and experiment loggers are only
    # wanted for training runs.
    callbacks = []
    loggers = []
    if args.train:
        callbacks += [
            ModelCheckpoint(dirpath=config.ckpt_dir, **config.checkpoint),
            LearningRateMonitor(logging_interval="step"),
            # CodeSnapshotCallback(
            #     config.code_dir, use_version=False
            # ),
            ConfigSnapshotCallback(config, config.config_dir, use_version=False),
            CustomProgressBar(refresh_rate=1),
        ]
        loggers += [
            TensorBoardLogger(
                args.runs_dir, name=config.name, version=config.trial_name
            ),
            CSVLogger(config.exp_dir, name=config.trial_name, version="csv_logs"),
        ]

    trainer = Trainer(
        devices=n_gpus,
        accelerator="gpu",
        callbacks=callbacks,
        logger=loggers,
        strategy=strategy,
        **config.trainer,
    )

    if args.train:
        if args.resume and not args.resume_weights_only:
            # Full resume: restore weights AND training state via ckpt_path.
            # FIXME: different behavior in pytorch-lightning>1.9 ?
            trainer.fit(system, datamodule=dm, ckpt_path=args.resume)
        else:
            trainer.fit(system, datamodule=dm)
        trainer.test(system, datamodule=dm)
    elif args.validate:
        trainer.validate(system, datamodule=dm, ckpt_path=args.resume)
    elif args.test:
        trainer.test(system, datamodule=dm, ckpt_path=args.resume)
    elif args.predict:
        trainer.predict(system, datamodule=dm, ckpt_path=args.resume)
if __name__ == "__main__":
    # Run the launcher only when this file is executed as a script, not on import.
    main()