import os
from datetime import datetime

import pytorch_lightning as pl
# from torch.utils.tensorboard import SummaryWriter
import torch

from src.decoder import DECODER_REGISTRY
from src.utils.loss import TVLoss


class StyleRF(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.init_model()

    def init_model(self):
        logfolder = f'{self.cfg["global"]["base_dir"]}/{self.cfg["global"]["expname"]}{datetime.now().strftime("-%Y%m%d-%H%M%S")}'
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Init log folder
        os.makedirs(logfolder, exist_ok=True)
        # summary_writer = SummaryWriter(logfolder)

        # Init parameters: the density field comes from a pre-trained TensoRF checkpoint
        assert self.cfg["model"]["tensorf"]["ckpt"] is not None, \
            'A pre-trained checkpoint is required to obtain the density field!'
        model_type = self.cfg["model"]["type"]  # renamed from `type` to avoid shadowing the builtin
        ckpt = torch.load(self.cfg["model"]["tensorf"]["ckpt"], map_location=device)
        kwargs = ckpt['kwargs']
        kwargs.update({'device': device})

        self.ndc_ray = self.cfg["model"]["tensorf"]["ndc_ray"]

        if model_type == "feature":
            # Feature stage: load the pre-trained radiance field first, then switch
            # the appearance branch to feature rendering.
            self.tensorf = DECODER_REGISTRY.get(self.cfg["model"]["tensorf"]["model_name"])(**kwargs)
            self.tensorf.load(ckpt)
            self.tensorf.change_to_feature_mod(self.cfg["model"]["tensorf"]["lamb_sh"], device)
        elif model_type == "style":
            # Style stage: the checkpoint is expected to have been saved in feature
            # mode, so switch to the feature branch *before* loading its weights,
            # then attach the style-transfer module.
            self.tensorf = DECODER_REGISTRY.get(self.cfg["model"]["tensorf"]["model_name"])(**kwargs)
            self.tensorf.change_to_feature_mod(self.cfg["model"]["tensorf"]["lamb_sh"], device)
            self.tensorf.load(ckpt)
            self.tensorf.change_to_style_mod(device)

        self.tensorf.rayMarch_weight_thres = self.cfg["model"]["tensorf"]["rm_weight_mask_thre"]

        # TV regularization weight for the feature grids (kept for use in the loss)
        self.TV_weight_feature = self.cfg["model"]["tensorf"]["TV_weight_feature"]
        self.tvreg = TVLoss()
        print(f"initial TV_weight_feature: {self.TV_weight_feature}")

    def forward(self, batch):
        pass

    def training_step(self, batch, batch_idx):
        pass
        # # 2. Calculate loss
        # loss = self.compute_loss(forwarded_batch=forwarded_batch, input_batch=batch)
        # # 3. Update monitor
        # self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        # return {"loss": loss}

    # def validation_step(self, batch, batch_idx):
    #     # 1. Get embeddings from model
    #     forwarded_batch = self.forward(batch)
    #     # 2. Calculate loss
    #     loss = self.compute_loss(forwarded_batch=forwarded_batch, input_batch=batch)
    #     # 3. Update metric for each batch
    #     self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
    #     self.metric_evaluator.append(
    #         g_emb=forwarded_batch["pc_embedding_feats"].float().clone().detach(),
    #         q_emb=forwarded_batch["query_embedding_feats"].float().clone().detach(),
    #         query_ids=batch["query_ids"],
    #         gallery_ids=batch["point_cloud_ids"],
    #         target_ids=batch["point_cloud_ids"],
    #     )
    #
    #     return {"loss": loss}
    #
    # def validation_epoch_end(self, outputs) -> None:
    #     """
    #     Callback at validation epoch end to do additional work with the
    #     output of the validation steps; note that this is called before
    #     `training_epoch_end()`.
    #     Args:
    #         outputs: outputs of the validation steps
    #     """
    #     self.log_dict(
    #         self.metric_evaluator.evaluate(),
    #         prog_bar=True,
    #         on_step=False,
    #         on_epoch=True,
    #     )
    #     self.metric_evaluator.reset()
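

# --- Usage sketch ---------------------------------------------------------
# A minimal, illustrative way to construct the module. The config layout
# mirrors the keys read in `init_model`, but every concrete value below
# (paths, model name, thresholds) is a hypothetical placeholder, not a
# tested configuration; the model name must match an entry registered in
# DECODER_REGISTRY and the ckpt must exist on disk.
if __name__ == "__main__":
    cfg = {
        "global": {"base_dir": "./logs", "expname": "stylerf_demo"},
        "model": {
            "type": "feature",  # "feature" or "style"
            "tensorf": {
                "ckpt": "./checkpoints/tensorf_pretrained.th",  # hypothetical path
                "model_name": "TensorVMSplit",  # assumed registry key
                "ndc_ray": False,
                "lamb_sh": 0.0,
                "rm_weight_mask_thre": 0.0001,
                "TV_weight_feature": 1.0,
            },
        },
    }
    model = StyleRF(cfg)  # loads the checkpoint and prepares the feature branch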