import argparse
import contextlib
import importlib.util
import logging
import os
import sys
import time
import traceback

import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.utilities.rank_zero import rank_zero_only

import craftsman
from craftsman.systems.base import BaseSystem
from craftsman.utils.callbacks import (
    CodeSnapshotCallback,
    ConfigSnapshotCallback,
    CustomProgressBar,
    ProgressCallback,
)
from craftsman.utils.config import ExperimentConfig, load_config
from craftsman.utils.misc import get_rank
from craftsman.utils.typing import Optional


class ColoredFilter(logging.Filter):
    """
    A logging filter to add color to certain log levels.
    """

    RESET = "\033[0m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"

    COLORS = {
        "WARNING": YELLOW,
        "INFO": GREEN,
        "DEBUG": BLUE,
        "CRITICAL": MAGENTA,
        "ERROR": RED,
    }

    def __init__(self):
        super().__init__()

    def filter(self, record):
        if record.levelname in self.COLORS:
            color_start = self.COLORS[record.levelname]
            record.levelname = f"{color_start}[{record.levelname}]"
            record.msg = f"{record.msg}{self.RESET}"
        return True


def load_custom_module(module_path):
    module_name = os.path.basename(module_path)
    if os.path.isfile(module_path):
        # for a single-file module, use the path without its extension as the name
        sp = os.path.splitext(module_path)
        module_name = sp[0]
    try:
        if os.path.isfile(module_path):
            module_spec = importlib.util.spec_from_file_location(
                module_name, module_path
            )
        else:
            # for a package directory, load it via its __init__.py
            module_spec = importlib.util.spec_from_file_location(
                module_name, os.path.join(module_path, "__init__.py")
            )
        module = importlib.util.module_from_spec(module_spec)
        sys.modules[module_name] = module
        module_spec.loader.exec_module(module)
        return True
    except Exception as e:
        print(traceback.format_exc())
        print(f"Cannot import {module_path} module for custom nodes:", e)
        return False


def load_custom_modules():
    node_paths = ["custom"]
    node_import_times = []
    for custom_node_path in node_paths:
        # skip search paths that do not exist
        if not os.path.exists(custom_node_path):
            continue
        possible_modules = os.listdir(custom_node_path)
        if "__pycache__" in possible_modules:
            possible_modules.remove("__pycache__")

        for possible_module in possible_modules:
            module_path = os.path.join(custom_node_path, possible_module)
            if (
                os.path.isfile(module_path)
                and os.path.splitext(module_path)[1] != ".py"
            ):
                continue
            if module_path.endswith(".disabled"):
                continue
            time_before = time.perf_counter()
            success = load_custom_module(module_path)
            node_import_times.append(
                (time.perf_counter() - time_before, module_path, success)
            )

    if len(node_import_times) > 0:
        print("\nImport times for custom modules:")
        for n in sorted(node_import_times):
            import_message = "" if n[2] else " (IMPORT FAILED)"
            print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
        print()


def main(args, extras) -> None:
    # set CUDA_VISIBLE_DEVICES if needed
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else []
    selected_gpus = [0]

    torch.set_float32_matmul_precision("high")

    # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
    # As far as Pytorch Lightning is concerned, we always use all available GPUs
    # (possibly filtered by CUDA_VISIBLE_DEVICES).
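    # Illustrative example of the scheme described above: running with `--gpu 1,3`
    # sets CUDA_VISIBLE_DEVICES=1,3, so Lightning sees exactly two GPUs (as
    # cuda:0 and cuda:1), and devices=-1 below tells it to use both.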
    devices = -1
    if len(env_gpus) > 0:
        n_gpus = len(env_gpus)
    else:
        selected_gpus = list(args.gpu.split(","))
        n_gpus = len(selected_gpus)
        print(f"Using {n_gpus} GPUs: {selected_gpus}")
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if args.typecheck:
        from jaxtyping import install_import_hook

        install_import_hook("craftsman", "typeguard.typechecked")

    logger = logging.getLogger("pytorch_lightning")

    if args.verbose:
        logger.setLevel(logging.DEBUG)

    for handler in logger.handlers:
        if handler.stream == sys.stderr:  # type: ignore
            if not args.gradio:
                handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
                handler.addFilter(ColoredFilter())
            else:
                handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))

    load_custom_modules()

    # parse YAML config to OmegaConf
    cfg: ExperimentConfig
    cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)

    # set a different seed for each device
    pl.seed_everything(cfg.seed + get_rank(), workers=True)

    dm = craftsman.find(cfg.data_type)(cfg.data)

    system: BaseSystem = craftsman.find(cfg.system_type)(
        cfg.system, resumed=cfg.resume is not None
    )
    system.set_save_dir(os.path.join(cfg.trial_dir, "save"))

    if args.gradio:
        fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
        fh.setLevel(logging.INFO)
        if args.verbose:
            fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        logger.addHandler(fh)

    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(
                dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
            ),
            LearningRateMonitor(logging_interval="step"),
            CodeSnapshotCallback(
                os.path.join(cfg.trial_dir, "code"), use_version=False
            ),
            ConfigSnapshotCallback(
                args.config,
                cfg,
                os.path.join(cfg.trial_dir, "configs"),
                use_version=False,
            ),
        ]
    if args.gradio:
        callbacks += [
            ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
        ]
    else:
        callbacks += [CustomProgressBar(refresh_rate=1)]

    def write_to_text(file, lines):
        with open(file, "w") as f:
            for line in lines:
                f.write(line + "\n")

    loggers = []
    if args.train:
        # make tensorboard logging dir to suppress warning
        rank_zero_only(
            lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
        )()
        loggers += [
            TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
            CSVLogger(cfg.trial_dir, name="csv_logs"),
        ] + system.get_loggers()
        rank_zero_only(
            lambda: write_to_text(
                os.path.join(cfg.trial_dir, "cmd.txt"),
                ["python " + " ".join(sys.argv), str(args)],
            )
        )()

    trainer = Trainer(
        callbacks=callbacks,
        logger=loggers,
        inference_mode=False,
        accelerator="gpu",
        devices=devices,
        # profiler="advanced",
        **cfg.trainer,
    )

    def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
        if ckpt_path is None:
            return
        ckpt = torch.load(ckpt_path, map_location="cpu")
        system.set_resume_status(ckpt["epoch"], ckpt["global_step"])

    if args.train:
        trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
        trainer.test(system, datamodule=dm)
        if args.gradio:
            # also export assets if in gradio mode
            trainer.predict(system, datamodule=dm)
    elif args.validate:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.test:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.export:
        set_system_status(system, cfg.resume)
        trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)


if __name__ == "__main__":
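    # Note: parse_known_args() below returns unrecognized tokens in `extras`,
    # which main() forwards to load_config(cli_args=extras). Presumably these
    # are OmegaConf-style `key=value` overrides (e.g. `trainer.max_epochs=100`);
    # the exact override syntax is an assumption based on that call, not
    # something this file defines.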
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument(
        "--gpu",
        default="0",
        help="GPU(s) to be used. 0 means use the 1st available GPU. "
        "1,2 means use the 2nd and 3rd available GPU. "
        "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
        "this argument is ignored and all available GPUs are always used.",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--export", action="store_true")

    parser.add_argument(
        "--gradio", action="store_true", help="if true, run in gradio mode"
    )
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )
    parser.add_argument(
        "--typecheck",
        action="store_true",
        help="whether to enable dynamic type checking",
    )

    args, extras = parser.parse_known_args()

    if args.gradio:
        with contextlib.redirect_stdout(sys.stderr):
            main(args, extras)
    else:
        main(args, extras)
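
# Example invocations (the config path below is illustrative, not a file this
# script ships with; `launch.py` matches the name used in the --gpu help text):
#   python launch.py --config configs/my-experiment.yaml --train --gpu 0
#   python launch.py --config configs/my-experiment.yaml --test --gpu 0
#   CUDA_VISIBLE_DEVICES=0,1 python launch.py --config configs/my-experiment.yaml --train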