# Source: AnTo2209, commit e32c848 ("refactor")
import os
import torch
from lightning_fabric import seed_everything
import pytorch_lightning as pl
from pytorch_lightning.loggers.wandb import WandbLogger
import datetime
import wandb
from src.callback import CALLBACK_REGISTRY
from src.loop.feature_training_loop import FeatureTrainingLoop
from src.loop.style_training_loop import StyleTrainingLoop
from src.model import MODEL_REGISTRY
from src.utils.opt import Opts
from src.utils.renderer import OctreeRender_trilinear_fast
def train(config):
    """Train the model described by *config* and return the checkpoint directory.

    Builds the model from MODEL_REGISTRY, wires up a W&B logger and the
    configured callbacks, swaps in a custom fit loop depending on the model
    type ("feature" or "style"), runs training, and returns the path where
    checkpoints were written.

    Args:
        config: Parsed configuration mapping with at least the "global",
            "model", "trainer", and "callbacks" sections used below.

    Returns:
        Path to the run's checkpoint directory:
        <save_dir>/<project_name>/<wandb run id>/checkpoints

    Raises:
        NotImplementedError: If config["model"]["type"] is neither
            "feature" nor "style".
    """
    model = MODEL_REGISTRY.get(config["model"]["name"])(config)
    epoch = config["trainer"]["n_iters"]

    # Timestamped run name so repeated runs do not collide in W&B.
    time_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    run_name = f"{config['global']['name']}-{time_str}"
    wandb_logger = WandbLogger(
        project=config["global"]["project_name"],
        name=run_name,
        save_dir=config["global"]["save_dir"],
        entity=config["global"]["username"],
    )
    wandb_logger.watch(model)
    wandb_logger.experiment.config.update(config)

    # Instantiate every configured callback from the registry.
    callbacks = [
        CALLBACK_REGISTRY.get(mcfg["name"])(**mcfg["params"])
        for mcfg in config["callbacks"]
    ]
    trainer = pl.Trainer(
        default_root_dir="src",
        check_val_every_n_epoch=config["trainer"]["evaluate_interval"],
        log_every_n_steps=config["trainer"]["log_interval"],
        enable_checkpointing=True,
        accelerator="gpu" if torch.cuda.is_available() else "auto",
        devices=-1,
        sync_batchnorm=True if torch.cuda.is_available() else False,
        precision=16 if config["trainer"]["use_fp16"] else 32,
        fast_dev_run=config["trainer"]["debug"],
        logger=wandb_logger,
        callbacks=callbacks,
        num_sanity_val_steps=-1,  # Sanity full validation required for visualization callbacks
        deterministic=False,
        auto_lr_find=True,
    )
    print("Trainer: ", trainer)

    # BUG FIX: the original read the module-level global `cfg` here, which
    # only exists when this file is run as a script; calling train() from
    # another module raised NameError. Use the `config` parameter instead.
    model_type = config["model"]["type"]
    if model_type == "feature":
        trainer.fit_loop = FeatureTrainingLoop(epoch=epoch, cfg=config, renderer=OctreeRender_trilinear_fast)
    elif model_type == "style":
        trainer.fit_loop = StyleTrainingLoop(epoch=epoch, cfg=config, renderer=OctreeRender_trilinear_fast)
    else:
        raise NotImplementedError(f"Unsupported model type: {model_type!r}")

    trainer.fit(model, ckpt_path=config["global"]["resume"])

    # os.path.join is variadic; no need to nest calls.
    return os.path.join(
        config["global"]["save_dir"],
        config["global"]["project_name"],
        wandb.run.id,
        "checkpoints",
    )
if __name__ == "__main__":
    # Parse the base config file (style baseline) plus any CLI overrides.
    # NOTE(review): `cfg` is module-level and is also read as a global inside
    # train() at L51/L53 in the original — keep the name `cfg`.
    cfg = Opts(cfg="configs/style_baseline.yml").parse_args()
    # Seed all RNGs (Python, NumPy, torch) for reproducibility.
    seed_everything(seed=cfg["global"]["SEED"])
    train(cfg)