Spaces:
Runtime error
Runtime error
import os | |
import shutil | |
import subprocess | |
import pytorch_lightning | |
from craftsman.utils.config import dump_config | |
from craftsman.utils.misc import parse_version | |
if parse_version(pytorch_lightning.__version__) > parse_version("1.8"): | |
from pytorch_lightning.callbacks import Callback | |
else: | |
from pytorch_lightning.callbacks.base import Callback | |
from pytorch_lightning.callbacks.progress import TQDMProgressBar | |
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn | |
class VersionedCallback(Callback):
    """Base callback that writes its artifacts into a versioned sub-directory.

    Args:
        save_root: Root directory under which snapshots are stored.
        version: Explicit version. An int ``n`` maps to the sub-directory
            ``version_<n>``; a str is used verbatim as the sub-directory name.
            ``None`` picks the next free ``version_<n>`` index under
            ``save_root`` on first access.
        use_version: When False, ``savedir`` is ``save_root`` itself and no
            version sub-directory is used.
    """

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__()
        self.save_root = save_root
        self._version = version
        self.use_version = use_version

    @property
    def version(self) -> "int | str":
        """Get the experiment version.

        Returns:
            The experiment version if specified, else the next free version
            index. The computed index is cached, so later accesses (and every
            path derived from it) stay consistent.
        """
        if self._version is None:
            self._version = self._get_next_version()
        return self._version

    def _get_next_version(self) -> int:
        """Return one more than the largest ``version_<n>`` entry in ``save_root``.

        Entries whose suffix after ``version_`` is not an integer (for example
        ``version_debug``) are ignored instead of aborting the scan. Returns 0
        when ``save_root`` does not exist or holds no numbered versions.
        """
        if not os.path.isdir(self.save_root):
            return 0
        existing_versions = []
        for name in os.listdir(self.save_root):
            if not name.startswith("version_"):
                continue
            # Drop any extension, then keep the token right after the first "_".
            suffix = os.path.splitext(name)[0].split("_")[1]
            try:
                existing_versions.append(int(suffix))
            except ValueError:
                continue
        return max(existing_versions, default=-1) + 1

    @property
    def savedir(self):
        """Directory this callback writes into (``save_root`` or a version sub-dir)."""
        if not self.use_version:
            return self.save_root
        version = self.version
        subdir = version if isinstance(version, str) else f"version_{version}"
        return os.path.join(self.save_root, subdir)
class CodeSnapshotCallback(VersionedCallback):
    """Copy the git-visible source tree into ``savedir`` when fitting starts.

    Two sets of files are copied, preserving their relative layout:
    tracked files (excluding anything under ``load/``), and untracked files
    that are not ignored by the standard git exclude rules.
    """

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)

    @staticmethod
    def _git_ls_files(*args):
        """Run ``git ls-files -z <args>`` and return the listed paths as str.

        ``-z`` makes git emit raw NUL-separated paths. Without it, git C-quotes
        names containing non-ASCII or special characters (``core.quotePath``),
        and such quoted names never match a file on disk.

        Raises:
            subprocess.CalledProcessError: git exited non-zero (e.g. not a repo).
            OSError: the git executable could not be started.
        """
        out = subprocess.check_output(["git", "ls-files", "-z", *args])
        return [os.fsdecode(path) for path in out.split(b"\0") if path]

    def get_file_list(self):
        """Return the sorted, de-duplicated list of file paths to snapshot."""
        # hard code, TODO: use config to exclude folders or files
        tracked = self._git_ls_files("--", ":!:load/*")
        untracked = self._git_ls_files("--others", "--exclude-standard")
        return sorted(set(tracked) | set(untracked))

    @rank_zero_only
    def save_code_snapshot(self):
        """Copy every file from ``get_file_list`` into ``savedir``.

        Decorated with ``rank_zero_only`` so that, in multi-process runs, only
        one process copies and ranks do not race on the same destination files.
        Paths that are missing or are not regular files are skipped.
        """
        savedir = self.savedir
        os.makedirs(savedir, exist_ok=True)
        for rel_path in self.get_file_list():
            if not os.path.isfile(rel_path):
                continue
            dst = os.path.join(savedir, rel_path)
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            shutil.copyfile(rel_path, dst)

    def on_fit_start(self, trainer, pl_module):
        """Take the code snapshot; any failure is downgraded to a warning."""
        try:
            self.save_code_snapshot()
        except Exception as err:  # best-effort: never abort training over a snapshot
            rank_zero_warn(
                "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
                f" ({type(err).__name__}: {err})"
            )
class ConfigSnapshotCallback(VersionedCallback):
    """Save the experiment config into ``savedir`` when fitting starts.

    Two files are written:

    - ``parsed.yaml``: ``config`` serialised via ``dump_config``.
    - ``raw.yaml``: a byte-for-byte copy of the file at ``config_path``.

    Args:
        config_path: Path of the original config file.
        config: Config object passed to ``dump_config``. Presumably the parsed
            form of ``config_path`` — confirm against the caller.
        save_root, version, use_version: Forwarded to ``VersionedCallback``.
    """

    def __init__(self, config_path, config, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)
        self.config_path = config_path
        self.config = config

    @rank_zero_only
    def save_config_snapshot(self):
        """Write ``parsed.yaml`` and ``raw.yaml`` into ``savedir``.

        Decorated with ``rank_zero_only`` so that, in multi-process runs, only
        one process writes and ranks do not overwrite the same files concurrently.
        """
        savedir = self.savedir
        os.makedirs(savedir, exist_ok=True)
        dump_config(os.path.join(savedir, "parsed.yaml"), self.config)
        shutil.copyfile(self.config_path, os.path.join(savedir, "raw.yaml"))

    def on_fit_start(self, trainer, pl_module):
        self.save_config_snapshot()
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that hides Lightning's ``v_num`` (experiment version) entry."""

    def get_metrics(self, *args, **kwargs):
        shown = super().get_metrics(*args, **kwargs)
        # The version number is noise in the bar; drop it only if present.
        if "v_num" in shown:
            del shown["v_num"]
        return shown
class ProgressCallback(Callback):
    """Mirror a one-line status message for the current stage into a text file.

    The file at ``save_path`` holds only the most recent message, because each
    write truncates it first. An external process can therefore poll the file
    for the current status.

    Args:
        save_path: Path of the status file (created/overwritten on first write).
    """

    def __init__(self, save_path):
        super().__init__()
        self.save_path = save_path
        self._file_handle = None

    @property
    def file_handle(self):
        """Lazily opened text handle on ``save_path`` (mode ``"w"``)."""
        if self._file_handle is None:
            self._file_handle = open(self.save_path, "w", encoding="utf-8")
        return self._file_handle

    def write(self, msg: str) -> None:
        """Replace the file contents with ``msg`` and flush so readers see it now."""
        handle = self.file_handle
        handle.seek(0)
        handle.truncate()
        handle.write(msg)
        handle.flush()

    def close(self) -> None:
        """Close the status file if open; a later ``write`` reopens it."""
        if self._file_handle is not None:
            self._file_handle.close()
            self._file_handle = None

    def teardown(self, trainer, pl_module, stage=None):
        # Release the handle at the end of each fit/validate/test/predict run
        # instead of leaking it for the lifetime of the process.
        self.close()

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        step = pl_module.true_global_step
        max_steps = trainer.max_steps
        # Lightning uses max_steps=-1 (None in older releases) when training is
        # bounded by epochs. Avoid a negative percentage or ZeroDivisionError.
        if max_steps is None or max_steps <= 0:
            self.write(f"Generation progress: step {step}")
            return
        self.write(f"Generation progress: {step / max_steps * 100:.2f}%")

    def on_validation_start(self, trainer, pl_module):
        self.write("Rendering validation image ...")

    def on_test_start(self, trainer, pl_module):
        self.write("Rendering video ...")

    def on_predict_start(self, trainer, pl_module):
        self.write("Exporting mesh assets ...")