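"""
Base Lightning system for CraftsMan: wires up structured config parsing,
checkpoint/weight loading, resume bookkeeping, and per-batch update hooks.
Concrete systems subclass BaseSystem and implement the *_step methods.
"""
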
import os
from dataclasses import dataclass, field

import pytorch_lightning as pl
import torch.nn.functional as F

import craftsman
from craftsman.utils.base import (
    Updateable,
    update_end_if_possible,
    update_if_possible,
)
from craftsman.utils.scheduler import parse_optimizer, parse_scheduler
from craftsman.utils.config import parse_structured
from craftsman.utils.misc import C, cleanup, get_device, load_module_weights
from craftsman.utils.saving import SaverMixin
from craftsman.utils.typing import *


class BaseSystem(pl.LightningModule, Updateable, SaverMixin):
    @dataclass
    class Config:
        loggers: dict = field(default_factory=dict)
        loss: dict = field(default_factory=dict)
        optimizer: dict = field(default_factory=dict)
        scheduler: Optional[dict] = None
        weights: Optional[str] = None
        weights_ignore_modules: Optional[List[str]] = None
        cleanup_after_validation_step: bool = False
        cleanup_after_test_step: bool = False
        pretrained_model_path: Optional[str] = None
        strict_load: bool = True

    cfg: Config
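
    # The `cfg` passed to __init__ is typically a dict/OmegaConf section from a
    # YAML experiment config, validated against Config above. Illustrative
    # values only (the optimizer schema is an assumption based on the
    # threestudio-style parse_optimizer imported here):
    #
    #     loss: {lambda_recon: 1.0}
    #     optimizer: {name: AdamW, args: {lr: 1.e-4}}
    #     weights: null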

    def __init__(self, cfg, resumed=False) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self._save_dir: Optional[str] = None
        self._resumed: bool = resumed
        self._resumed_eval: bool = False
        self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0}
        if "loggers" in cfg:
            self.create_loggers(cfg.loggers)
        self.configure()
        if self.cfg.weights is not None:
            self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules)
        self.post_configure()

    def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None):
        state_dict, epoch, global_step = load_module_weights(
            weights, ignore_modules=ignore_modules, map_location="cpu"
        )
        self.load_state_dict(state_dict, strict=False)
        # restore step-dependent states
        self.do_update_step(epoch, global_step, on_load_weights=True)
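
    # Usage sketch (hypothetical checkpoint path and module name): skip loading
    # certain submodules' weights, e.g.
    #     system.load_weights("ckpts/last.ckpt", ignore_modules=["denoiser"])
    # strict=False above tolerates the keys dropped via ignore_modules.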

    def set_resume_status(self, current_epoch: int, global_step: int):
        # restore correct epoch and global step in eval
        self._resumed_eval = True
        self._resumed_eval_status["current_epoch"] = current_epoch
        self._resumed_eval_status["global_step"] = global_step

    @property
    def resumed(self):
        # whether loaded from a resumed checkpoint
        return self._resumed

    @property
    def true_global_step(self):
        # report the restored step when resuming for eval only, where
        # Lightning's own global_step starts again from 0
        if self._resumed_eval:
            return self._resumed_eval_status["global_step"]
        else:
            return self.global_step

    @property
    def true_current_epoch(self):
        if self._resumed_eval:
            return self._resumed_eval_status["current_epoch"]
        else:
            return self.current_epoch

    def configure(self) -> None:
        pass

    def post_configure(self) -> None:
        """
        Executed after weights are loaded.
        """
        pass

    def C(self, value: Any) -> float:
        return C(value, self.true_current_epoch, self.true_global_step)
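
    # `C` resolves a possibly step-scheduled config value (e.g. a loss weight)
    # to a float at the current epoch/step. Assuming craftsman's `C` keeps the
    # threestudio semantics, both a constant (1.0) and an interpolated schedule
    # ([start_step, start_value, end_value, end_step], e.g. [0, 1.0, 0.1, 5000])
    # are accepted:
    #
    #     lambda_kl = self.C(self.cfg.loss["lambda_kl"])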

    def configure_optimizers(self):
        optim = parse_optimizer(self.cfg.optimizer, self)
        ret = {
            "optimizer": optim,
        }
        if self.cfg.scheduler is not None:
            ret.update(
                {
                    "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim),
                }
            )
        return ret
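
    # A matching scheduler entry might look like (a sketch; the exact schema
    # depends on craftsman's parse_scheduler, assumed to follow threestudio's):
    #
    #     scheduler:
    #         name: CosineAnnealingLR
    #         interval: step
    #         args: {T_max: 100000, eta_min: 1.e-5}
    #
    # The returned dict follows Lightning's configure_optimizers contract:
    # {"optimizer": ..., "lr_scheduler": ...}.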

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_train_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.train_dataloader.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)

    def on_validation_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.val_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        if self.cfg.cleanup_after_validation_step:
            # cleanup to save vram
            cleanup()

    def on_validation_epoch_end(self):
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_test_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.test_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        if self.cfg.cleanup_after_test_step:
            # cleanup to save vram
            cleanup()

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_predict_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.predict_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        # note: predict reuses the test-step cleanup flag; there is no separate
        # cleanup_after_predict_step option
        if self.cfg.cleanup_after_test_step:
            # cleanup to save vram
            cleanup()

    def on_predict_epoch_end(self):
        pass

    def preprocess_data(self, batch, stage):
        """
        Override to preprocess batches per stage ("train", "validation",
        "test", "predict"). Implementing on_after_batch_transfer on the
        DataModule would achieve the same, but on_after_batch_transfer
        does not support DP.
        """
        pass

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        self.preprocess_data(batch, "train")
        self.dataset = self.trainer.train_dataloader.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "validation")
        self.dataset = self.trainer.val_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "test")
        self.dataset = self.trainer.test_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "predict")
        self.dataset = self.trainer.predict_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        pass

    def on_before_optimizer_step(self, optimizer):
        """
        Gradient-related debugging can go here, for example:

            from lightning.pytorch.utilities import grad_norm
            norms = grad_norm(self.geometry, norm_type=2)
            print(norms)
        """
        pass
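

# Minimal sketch of a concrete subclass (hypothetical class and attribute
# names; the real systems live under craftsman/systems/ and implement these
# hooks with the actual models and losses):
#
#     class MySystem(BaseSystem):
#         def configure(self):
#             self.model = torch.nn.Linear(3, 3)
#
#         def training_step(self, batch, batch_idx):
#             loss = F.mse_loss(self.model(batch["x"]), batch["y"])
#             self.log("train/loss", loss)
#             return {"loss": loss}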