Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,069 Bytes
0f079b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
import shutil
import subprocess
import pytorch_lightning
from craftsman.utils.config import dump_config
from craftsman.utils.misc import parse_version
if parse_version(pytorch_lightning.__version__) > parse_version("1.8"):
from pytorch_lightning.callbacks import Callback
else:
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
class VersionedCallback(Callback):
    """Base callback that resolves a per-run output directory.

    Outputs go to ``save_root/version_<N>`` (or ``save_root/<version>`` when
    ``version`` is a string), or directly to ``save_root`` when
    ``use_version`` is False. When no version is given, the next free integer
    version under ``save_root`` is picked lazily on first access and cached.

    Args:
        save_root: Root directory that holds the versioned subdirectories.
        version: Explicit version (int or str); ``None`` means auto-detect.
        use_version: If False, ``savedir`` is ``save_root`` itself.
    """

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__()
        self.save_root = save_root
        self._version = version
        self.use_version = use_version

    @property
    def version(self) -> "int | str":
        """Get the experiment version.

        Returns:
            The experiment version if specified else the next version. The
            auto-detected value is computed once and cached.
        """
        if self._version is None:
            self._version = self._get_next_version()
        return self._version

    def _get_next_version(self) -> int:
        """Return one more than the highest ``version_<int>`` entry in ``save_root``.

        Entries whose suffix after ``version_`` is not an integer (e.g.
        ``version_latest``) are ignored rather than raising ``ValueError``.
        Returns 0 when ``save_root`` is missing or holds no versions.
        """
        if not os.path.isdir(self.save_root):
            return 0
        existing_versions = []
        for name in os.listdir(self.save_root):
            if not name.startswith("version_"):
                continue
            # Same parsing as before: drop any extension, then take the token
            # between the first and second underscore.
            suffix = os.path.splitext(name)[0].split("_")[1]
            try:
                existing_versions.append(int(suffix))
            except ValueError:
                # Not a numeric version directory; it cannot collide with ours.
                continue
        return max(existing_versions, default=-1) + 1

    @property
    def savedir(self):
        """Directory this callback writes into (see class docstring)."""
        if not self.use_version:
            return self.save_root
        version = self.version
        subdir = version if isinstance(version, str) else f"version_{version}"
        return os.path.join(self.save_root, subdir)
class CodeSnapshotCallback(VersionedCallback):
    """Copy the git-visible source files into ``savedir`` when fitting starts.

    The snapshot holds files tracked by git plus untracked files that are not
    ignored (``--exclude-standard``). Paths matching ``exclude_patterns`` are
    left out of both lists.

    Args:
        save_root, version, use_version: see ``VersionedCallback``.
        exclude_patterns: git pathspec globs (relative to the current
            directory) whose files are not snapshotted. Defaults to the
            previously hard-coded ``load/*``.
    """

    def __init__(
        self, save_root, version=None, use_version=True, exclude_patterns=("load/*",)
    ):
        super().__init__(save_root, version, use_version)
        self.exclude_patterns = tuple(exclude_patterns)

    def _git_ls_files(self, *options):
        """Return the paths printed by ``git ls-files <options>``, minus exclusions.

        ``-z`` makes git emit NUL-terminated, unquoted paths. Without it, git
        C-quotes non-ASCII or special-character names, which then do not exist
        on disk and are silently skipped. The command is passed as an argument
        list, so no shell is involved.
        """
        excludes = [f":!:{pattern}" for pattern in self.exclude_patterns]
        output = subprocess.check_output(
            ["git", "ls-files", "-z", *options, "--", *excludes]
        )
        return [os.fsdecode(raw) for raw in output.split(b"\0") if raw]

    def get_file_list(self):
        """Return the sorted, de-duplicated tracked + untracked-not-ignored files."""
        tracked = self._git_ls_files()
        untracked = self._git_ls_files("--others", "--exclude-standard")
        return sorted(set(tracked) | set(untracked))

    @rank_zero_only
    def save_code_snapshot(self):
        """Copy every listed regular file into ``savedir``, keeping relative paths."""
        os.makedirs(self.savedir, exist_ok=True)
        for f in self.get_file_list():
            # Tracked-but-deleted files and directory entries (e.g. submodules)
            # cannot be copied with copyfile.
            if not os.path.exists(f) or os.path.isdir(f):
                continue
            dst = os.path.join(self.savedir, f)
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            shutil.copyfile(f, dst)

    def on_fit_start(self, trainer, pl_module):
        """Take the snapshot on a best-effort basis; a failure only warns."""
        try:
            self.save_code_snapshot()
        except Exception:
            # Not `except:` -- KeyboardInterrupt/SystemExit must still propagate.
            rank_zero_warn(
                "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
            )
class ConfigSnapshotCallback(VersionedCallback):
    """Snapshot the experiment configuration when fitting starts.

    Two files are written into ``savedir``: ``parsed.yaml`` (``self.config``
    serialized through ``dump_config``) and ``raw.yaml`` (a copy of the file
    at ``config_path``).
    """

    def __init__(self, config_path, config, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)
        self.config_path = config_path
        self.config = config

    @rank_zero_only
    def save_config_snapshot(self):
        """Write the parsed and raw config copies (rank zero only)."""
        target_dir = self.savedir
        os.makedirs(target_dir, exist_ok=True)
        parsed_path = os.path.join(target_dir, "parsed.yaml")
        raw_path = os.path.join(target_dir, "raw.yaml")
        dump_config(parsed_path, self.config)
        shutil.copyfile(self.config_path, raw_path)

    def on_fit_start(self, trainer, pl_module):
        """Lightning hook: store the config snapshot before training begins."""
        self.save_config_snapshot()
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar whose metrics omit the ``v_num`` (version number) entry."""

    def get_metrics(self, *args, **kwargs):
        """Return the parent's metrics dict with any ``v_num`` key removed."""
        metrics = super().get_metrics(*args, **kwargs)
        if "v_num" in metrics:
            del metrics["v_num"]
        return metrics
class ProgressCallback(Callback):
    """Keep a one-line, human-readable status message in ``save_path``.

    Each message replaces the previous contents, so the file always holds only
    the latest status. Writes happen on rank zero only. The file is opened
    lazily and closed in ``teardown``.
    """

    def __init__(self, save_path):
        super().__init__()
        self.save_path = save_path
        self._file_handle = None

    @property
    def file_handle(self):
        """Lazily open ``save_path`` for writing (truncating it) on first use."""
        if self._file_handle is None:
            self._file_handle = open(self.save_path, "w", encoding="utf-8")
        return self._file_handle

    @rank_zero_only
    def write(self, msg: str) -> None:
        """Replace the file contents with ``msg`` and flush it immediately."""
        handle = self.file_handle
        handle.seek(0)
        handle.truncate()
        handle.write(msg)
        handle.flush()

    def teardown(self, trainer, pl_module, stage=None):
        """Close the status file when a Lightning stage finishes.

        A later write (e.g. a subsequent ``trainer.test``) simply reopens it.
        """
        if self._file_handle is not None:
            self._file_handle.close()
            self._file_handle = None

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        max_steps = trainer.max_steps
        # When training is bounded only by epochs, Lightning leaves max_steps at
        # -1 (None on older versions). A step percentage is meaningless then,
        # and dividing would either crash or report a negative value.
        if max_steps is None or max_steps <= 0:
            return
        self.write(
            f"Generation progress: {pl_module.true_global_step / max_steps * 100:.2f}%"
        )

    @rank_zero_only
    def on_validation_start(self, trainer, pl_module):
        self.write("Rendering validation image ...")

    @rank_zero_only
    def on_test_start(self, trainer, pl_module):
        self.write("Rendering video ...")

    @rank_zero_only
    def on_predict_start(self, trainer, pl_module):
        self.write("Exporting mesh assets ...")
|