Spaces:
Runtime error
Runtime error
import os | |
import subprocess | |
import shutil | |
from utils.misc import dump_config, parse_version | |
import pytorch_lightning | |
if parse_version(pytorch_lightning.__version__) > parse_version('1.8'): | |
from pytorch_lightning.callbacks import Callback | |
else: | |
from pytorch_lightning.callbacks.base import Callback | |
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn | |
from pytorch_lightning.callbacks.progress import TQDMProgressBar | |
class VersionedCallback(Callback):
    """Base callback that resolves an (optionally versioned) output directory.

    Args:
        save_root: Directory under which output is written.
        version: Explicit version. A string is used verbatim as the
            sub-directory name; any other value is rendered as
            ``version_{version}``. When ``None``, the next free integer
            version under ``save_root`` is picked on first access.
        use_version: When false, ``savedir`` is ``save_root`` itself and no
            version sub-directory is used.
    """

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__()
        self.save_root = save_root
        self._version = version
        self.use_version = use_version

    # ``version`` and ``savedir`` are read as plain attributes (see
    # ``savedir`` below, which evaluates ``self.version``), so they must be
    # properties; as ordinary methods the attribute access would yield a
    # bound method instead of the value.
    @property
    def version(self) -> int:
        """Get the experiment version.

        An explicitly supplied version is returned unchanged (it may be a
        string). Otherwise the next version is computed once and cached.

        Returns:
            The experiment version if specified else the next version.
        """
        if self._version is None:
            self._version = self._get_next_version()
        return self._version

    def _get_next_version(self):
        """Return one more than the highest ``version_<n>`` entry in ``save_root``.

        Returns 0 when ``save_root`` is not a directory or holds no such
        entry. Entries whose suffix cannot be parsed as an integer are
        skipped instead of aborting the scan.
        """
        if not os.path.isdir(self.save_root):
            return 0
        existing_versions = []
        for entry in os.listdir(self.save_root):
            name = os.path.basename(entry)
            if not name.startswith("version_"):
                continue
            suffix = os.path.splitext(name)[0].split("_")[1].replace("/", "")
            try:
                existing_versions.append(int(suffix))
            except ValueError:
                # e.g. "version_tmp": not a numbered version directory.
                continue
        return max(existing_versions, default=-1) + 1

    @property
    def savedir(self):
        """Directory this callback writes to.

        Returns:
            ``save_root`` when versioning is disabled; otherwise
            ``save_root/<version>`` for a string version or
            ``save_root/version_<version>`` for any other value.
        """
        if not self.use_version:
            return self.save_root
        version = self.version
        subdir = version if isinstance(version, str) else f"version_{version}"
        return os.path.join(self.save_root, subdir)
class CodeSnapshotCallback(VersionedCallback):
    """Copy the files git knows about into ``savedir`` when fitting starts."""

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)

    def get_file_list(self):
        """List tracked files plus untracked, non-ignored files.

        Runs ``git ls-files`` and ``git ls-files --others --exclude-standard``
        in the current working directory and merges their output.

        Returns:
            Sorted list of de-duplicated paths as printed by git.

        Raises:
            subprocess.CalledProcessError: If git exits with a non-zero status.
            OSError: If the git executable cannot be started.
        """
        # Argument lists avoid going through a shell for fixed commands.
        commands = (
            ["git", "ls-files"],
            ["git", "ls-files", "--others", "--exclude-standard"],
        )
        paths = set()
        for command in commands:
            paths.update(subprocess.check_output(command).splitlines())
        return sorted(path.decode() for path in paths)

    # Only the rank-zero process writes the snapshot, so processes of a
    # distributed run do not copy the same files concurrently.
    @rank_zero_only
    def save_code_snapshot(self):
        """Copy every listed regular file into ``savedir``, keeping relative paths."""
        target_root = self.savedir
        os.makedirs(target_root, exist_ok=True)
        for rel_path in self.get_file_list():
            # Skips directories and paths that no longer exist on disk
            # (e.g. tracked files deleted from the working tree).
            if not os.path.isfile(rel_path):
                continue
            destination = os.path.join(target_root, rel_path)
            os.makedirs(os.path.dirname(destination), exist_ok=True)
            shutil.copyfile(rel_path, destination)

    def on_fit_start(self, trainer, pl_module):
        """Take the snapshot on a best-effort basis; a failure only warns."""
        try:
            self.save_code_snapshot()
        except Exception:
            # Deliberately not a bare ``except`` so KeyboardInterrupt and
            # SystemExit still propagate.
            rank_zero_warn("Code snapshot is not saved. Please make sure you have git installed and are in a git repository.")
class ConfigSnapshotCallback(VersionedCallback):
    """Store the run configuration in ``savedir`` when fitting starts.

    Args:
        config: Configuration object. It is passed to ``dump_config`` and its
            ``cmd_args['config']`` entry is read as the path of the file to
            copy (presumably the config file given on the command line;
            confirm against the caller).
        save_root: Directory under which output is written.
        version: Explicit version, or ``None`` to pick the next free one.
        use_version: When false, files are written directly to ``save_root``.
    """

    def __init__(self, config, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)
        self.config = config

    # Only the rank-zero process writes the files, so processes of a
    # distributed run do not write the same paths concurrently.
    @rank_zero_only
    def save_config_snapshot(self):
        """Write ``parsed.yaml`` via ``dump_config`` and copy the source file to ``raw.yaml``."""
        target_dir = self.savedir
        os.makedirs(target_dir, exist_ok=True)
        dump_config(os.path.join(target_dir, 'parsed.yaml'), self.config)
        shutil.copyfile(self.config.cmd_args['config'], os.path.join(target_dir, 'raw.yaml'))

    def on_fit_start(self, trainer, pl_module):
        """Save the configuration snapshot; errors propagate to the caller."""
        self.save_config_snapshot()
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that omits the ``v_num`` entry from its metrics."""

    def get_metrics(self, *args, **kwargs):
        """Return the parent's metrics with ``v_num`` removed if present."""
        metrics = super().get_metrics(*args, **kwargs)
        # Hide the version number from the progress bar.
        metrics.pop("v_num", None)
        return metrics