# Spaces: Runtime error
# (Appears to be a page-status banner captured twice during extraction; not part of the program.)
import os | |
import torch | |
from lightning_fabric import seed_everything | |
import pytorch_lightning as pl | |
from pytorch_lightning.loggers.wandb import WandbLogger | |
import datetime | |
import wandb | |
from src.callback import CALLBACK_REGISTRY | |
from src.loop.feature_training_loop import FeatureTrainingLoop | |
from src.loop.style_training_loop import StyleTrainingLoop | |
from src.model import MODEL_REGISTRY | |
from src.utils.opt import Opts | |
from src.utils.renderer import OctreeRender_trilinear_fast | |
def train(config):
    """Build the model, logger, callbacks and trainer from ``config`` and run fitting.

    Keys read from ``config``:
        * ``model.name`` -- looked up in ``MODEL_REGISTRY``; the result is called with ``config``.
        * ``model.type`` -- ``"feature"`` or ``"style"``; selects the fit loop class.
        * ``trainer.n_iters``, ``trainer.evaluate_interval``, ``trainer.log_interval``,
          ``trainer.use_fp16``, ``trainer.debug`` -- forwarded to the loop / ``pl.Trainer``.
        * ``global.name``, ``global.project_name``, ``global.save_dir``,
          ``global.username``, ``global.resume`` -- run naming, W&B setup, resume checkpoint.
        * ``callbacks`` -- iterable of ``{"name": ..., "params": {...}}`` entries resolved
          through ``CALLBACK_REGISTRY``.

    Args:
        config: Nested mapping holding the keys listed above.

    Returns:
        The path ``<save_dir>/<project_name>/<wandb run id>/checkpoints`` as a string
        (presumably the directory the checkpoints are written to -- confirm against
        the logger/checkpoint configuration).

    Raises:
        NotImplementedError: If ``config["model"]["type"]`` is neither ``"feature"``
            nor ``"style"``.
    """
    model = MODEL_REGISTRY.get(config["model"]["name"])(config)
    epoch = config["trainer"]["n_iters"]

    # A timestamp suffix keeps repeated runs of the same experiment distinguishable.
    time_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    run_name = f"{config['global']['name']}-{time_str}"

    wandb_logger = WandbLogger(
        project=config["global"]["project_name"],
        name=run_name,
        save_dir=config["global"]["save_dir"],
        entity=config["global"]["username"],
    )
    wandb_logger.watch(model)
    wandb_logger.experiment.config.update(config)

    callbacks = [
        CALLBACK_REGISTRY.get(mcfg["name"])(**mcfg["params"])
        for mcfg in config["callbacks"]
    ]

    use_cuda = torch.cuda.is_available()
    trainer = pl.Trainer(
        default_root_dir="src",
        check_val_every_n_epoch=config["trainer"]["evaluate_interval"],
        log_every_n_steps=config["trainer"]["log_interval"],
        enable_checkpointing=True,
        accelerator="gpu" if use_cuda else "auto",
        devices=-1,
        sync_batchnorm=use_cuda,
        precision=16 if config["trainer"]["use_fp16"] else 32,
        fast_dev_run=config["trainer"]["debug"],
        logger=wandb_logger,
        callbacks=callbacks,
        num_sanity_val_steps=-1,  # Sanity full validation required for visualization callbacks
        deterministic=False,
        auto_lr_find=True,
    )
    print("Trainer: ", trainer)

    # Read the model type from the ``config`` argument, not from a module-level
    # ``cfg``: the global only exists when this file is executed as a script,
    # so importing and calling ``train`` elsewhere would fail with NameError.
    model_type = config["model"]["type"]
    if model_type == "feature":
        trainer.fit_loop = FeatureTrainingLoop(
            epoch=epoch, cfg=config, renderer=OctreeRender_trilinear_fast
        )
    elif model_type == "style":
        trainer.fit_loop = StyleTrainingLoop(
            epoch=epoch, cfg=config, renderer=OctreeRender_trilinear_fast
        )
    else:
        raise NotImplementedError(f"Unsupported model type: {model_type!r}")

    trainer.fit(model, ckpt_path=config["global"]["resume"])

    return os.path.join(
        config["global"]["save_dir"],
        config["global"]["project_name"],
        wandb.run.id,
        "checkpoints",
    )
if __name__ == "__main__":
    # Script entry point: parse options starting from the baseline style config,
    # seed the random number generators, then launch training.
    cfg = Opts(cfg="configs/style_baseline.yml").parse_args()
    global_cfg = cfg["global"]
    seed_everything(seed=global_cfg["SEED"])
    train(cfg)