import os
import shutil
import subprocess
from typing import Union

import pytorch_lightning
from craftsman.utils.config import dump_config
from craftsman.utils.misc import parse_version

if parse_version(pytorch_lightning.__version__) > parse_version("1.8"):
    from pytorch_lightning.callbacks import Callback
else:
    from pytorch_lightning.callbacks.base import Callback

from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn


class VersionedCallback(Callback):
    """Base callback that resolves an output directory under ``save_root``.

    Args:
        save_root: Directory that holds the per-run sub-directories.
        version: Explicit version. An ``int`` maps to ``version_<n>``; a
            ``str`` is used verbatim as the sub-directory name. When ``None``
            the next free integer version is picked lazily.
        use_version: When ``False``, ``savedir`` is ``save_root`` itself.
    """

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__()
        self.save_root = save_root
        self._version = version
        self.use_version = use_version

    @property
    def version(self) -> Union[int, str]:
        """Get the experiment version.

        Returns:
            The experiment version if specified else the next version.
        """
        if self._version is None:
            # Resolved once and cached so every later access agrees.
            self._version = self._get_next_version()
        return self._version

    def _get_next_version(self):
        """Return one more than the highest ``version_<n>`` found in ``save_root``.

        Returns ``0`` when ``save_root`` does not exist or holds no numbered
        version entries. Entries whose suffix is not an integer (for example
        ``version_best``) are ignored instead of raising ``ValueError``.
        """
        if not os.path.isdir(self.save_root):
            return 0

        existing_versions = []
        for entry in os.listdir(self.save_root):
            bn = os.path.basename(entry)
            if not bn.startswith("version_"):
                continue
            dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "")
            try:
                existing_versions.append(int(dir_ver))
            except ValueError:
                # Not a numbered version directory; skip it.
                continue

        return max(existing_versions, default=-1) + 1

    @property
    def savedir(self):
        """Directory this callback writes into.

        ``save_root`` when versioning is disabled, otherwise
        ``save_root/<version>`` for a string version or
        ``save_root/version_<n>`` for an integer one.
        """
        if not self.use_version:
            return self.save_root
        version = self.version
        subdir = version if isinstance(version, str) else f"version_{version}"
        return os.path.join(self.save_root, subdir)


class CodeSnapshotCallback(VersionedCallback):
    """Copy the files reported by ``git ls-files`` into ``savedir`` when fitting starts."""

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)

    def get_file_list(self):
        """Return the sorted, de-duplicated paths of tracked and untracked files.

        Tracked files matching ``load/*`` are excluded; untracked files honour
        the standard git ignore rules.

        Raises:
            subprocess.CalledProcessError: If ``git`` exits with a non-zero status.
            OSError: If the ``git`` executable cannot be started.
        """
        # hard code, TODO: use config to exclude folders or files
        tracked = subprocess.check_output(
            ["git", "ls-files", "--", ":!:load/*"]
        ).splitlines()
        untracked = subprocess.check_output(
            ["git", "ls-files", "--others", "--exclude-standard"]
        ).splitlines()
        return sorted(b.decode() for b in set(tracked) | set(untracked))

    @rank_zero_only
    def save_code_snapshot(self):
        """Copy every listed file into ``savedir``, keeping relative paths."""
        os.makedirs(self.savedir, exist_ok=True)
        for f in self.get_file_list():
            # Skip entries that are missing on disk or are directories.
            if not os.path.exists(f) or os.path.isdir(f):
                continue
            os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
            shutil.copyfile(f, os.path.join(self.savedir, f))

    def on_fit_start(self, trainer, pl_module):
        """Take the snapshot; on failure emit a warning instead of raising."""
        try:
            self.save_code_snapshot()
        except Exception:
            # Best effort only. ``Exception`` (not a bare ``except``) lets
            # KeyboardInterrupt and SystemExit propagate.
            rank_zero_warn(
                "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
            )


class ConfigSnapshotCallback(VersionedCallback):
    """Write the parsed config and a copy of the raw config file into ``savedir``.

    Args:
        config_path: Path of the config file copied to ``raw.yaml``.
        config: Config object passed to ``dump_config`` for ``parsed.yaml``.
        save_root: See ``VersionedCallback``.
        version: See ``VersionedCallback``.
        use_version: See ``VersionedCallback``.
    """

    def __init__(self, config_path, config, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)
        self.config_path = config_path
        self.config = config

    @rank_zero_only
    def save_config_snapshot(self):
        """Create ``savedir`` and write ``parsed.yaml`` and ``raw.yaml`` into it."""
        os.makedirs(self.savedir, exist_ok=True)
        dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config)
        shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml"))

    def on_fit_start(self, trainer, pl_module):
        """Save the config snapshot when fitting starts."""
        self.save_config_snapshot()


class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that omits the ``v_num`` entry from its metrics."""

    def get_metrics(self, *args, **kwargs):
        """Return the parent's metrics with ``v_num`` removed if present."""
        # don't show the version number
        items = super().get_metrics(*args, **kwargs)
        items.pop("v_num", None)
        return items


class ProgressCallback(Callback):
    """Keep a single status message in the file at ``save_path``.

    Each write replaces the previous file content.

    Args:
        save_path: Path of the status file.
    """

    def __init__(self, save_path):
        super().__init__()
        self.save_path = save_path
        self._file_handle = None

    @property
    def file_handle(self):
        """Writable handle for ``save_path``, opened on first access."""
        if self._file_handle is None:
            self._file_handle = open(self.save_path, "w")
        return self._file_handle

    @rank_zero_only
    def write(self, msg: str) -> None:
        """Replace the file content with ``msg`` and flush it to disk."""
        self.file_handle.seek(0)
        self.file_handle.truncate()
        self.file_handle.write(msg)
        self.file_handle.flush()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        """Write training progress as a percentage of ``trainer.max_steps``."""
        # NOTE(review): assumes trainer.max_steps is positive; 0 would raise
        # ZeroDivisionError and a negative value yields a negative percentage.
        # Confirm callers always configure max_steps.
        self.write(
            f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%"
        )

    @rank_zero_only
    def on_validation_start(self, trainer, pl_module):
        """Write the validation status message."""
        self.write("Rendering validation image ...")

    @rank_zero_only
    def on_test_start(self, trainer, pl_module):
        """Write the test status message."""
        self.write("Rendering video ...")

    @rank_zero_only
    def on_predict_start(self, trainer, pl_module):
        """Write the predict status message."""
        self.write("Exporting mesh assets ...")

    def teardown(self, trainer, pl_module, stage=None):
        """Close the status file so the handle is not leaked.

        The handle is reset to ``None``, so a later write reopens the file
        through ``file_handle``.
        """
        if self._file_handle is not None:
            self._file_handle.close()
            self._file_handle = None