import os
from pathlib import Path

import torch
import torchvision.transforms as T
from lightning_fabric import seed_everything
from PIL import Image, ImageFile

from src.dataset import DATASET_REGISTRY
from src.decoder import DECODER_REGISTRY
from src.utils.opt import Opts
from src.utils.renderer import (
    evaluation_feature,
    evaluation_feature_path,
    OctreeRender_trilinear_fast,
)

# Maps each supported render mode to
# (split key under cfg["dataset"], output sub-directory name).
# NOTE(review): "render_test" writes into "imgs_train_all", the same folder as
# "render_train". This looks like a copy-paste leftover (perhaps
# "imgs_test_all" was intended), but it is kept as-is so existing output
# locations do not move -- confirm before changing.
_RENDER_MODES = {
    "render_train": ("train", "imgs_train_all"),
    "render_test": ("val", "imgs_train_all"),
    "render_path": ("val", "imgs_path_all"),
}


def inference(cfg, render_mode: str, image=None):
    """Render a stylized scene from a trained TensoRF checkpoint.

    Args:
        cfg: Nested mapping read with keys ``model.tensorf`` (``ckpt``,
            ``model_name``, ``lamb_sh``, ``rm_weight_mask_thre``,
            ``ndc_ray``), ``global`` (``style_img``, ``expname``),
            ``dataset`` (``name`` and ``train``/``val`` ``params``) and
            ``sampler.params.chunk_size``.
        render_mode: One of ``"render_train"``, ``"render_test"`` or
            ``"render_path"``.
        image: Optional style image passed straight to the torchvision
            transform (presumably a PIL image -- confirm against callers).
            When ``None`` the image at ``cfg["global"]["style_img"]`` is
            opened instead.

    Returns:
        Whatever ``evaluation_feature`` (train/test modes) or
        ``evaluation_feature_path`` (path mode) returns.

    Raises:
        ValueError: If ``render_mode`` is not a supported mode.
    """
    # Fail fast, before the (expensive) checkpoint load, instead of hitting an
    # UnboundLocalError on the return statement at the very end.
    if render_mode not in _RENDER_MODES:
        raise ValueError(
            f"Unknown render_mode {render_mode!r}; "
            f"expected one of {sorted(_RENDER_MODES)}"
        )
    split, out_subdir = _RENDER_MODES[render_mode]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensorf_cfg = cfg["model"]["tensorf"]

    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(tensorf_cfg["ckpt"], map_location=device)
    kwargs = ckpt["kwargs"]
    kwargs.update({"device": device})
    print(device)

    tensorf = DECODER_REGISTRY.get(tensorf_cfg["model_name"])(**kwargs)
    tensorf.change_to_feature_mod(tensorf_cfg["lamb_sh"], device)
    tensorf.change_to_style_mod(device)
    tensorf.load(ckpt)
    tensorf.eval()
    tensorf.rayMarch_weight_thres = tensorf_cfg["rm_weight_mask_thre"]

    # NOTE(review): os.path.dirname("./checkpoints") evaluates to ".", so
    # results land under "./<expname>/..." rather than "./checkpoints/...".
    # Kept unchanged to preserve existing output paths -- confirm intent.
    logfolder = os.path.dirname("./checkpoints")
    renderer = OctreeRender_trilinear_fast

    trans = T.Compose([T.Resize(size=(256, 256)), T.ToTensor()])
    # `is not None` rather than truthiness: array-like images raise on bool().
    if image is not None:
        style_source = image
    else:
        style_source = Image.open(cfg["global"]["style_img"])
    # .to(device) instead of an unconditional .cuda(), so the config-file
    # fallback also works on CPU-only machines; [None, ...] adds a leading
    # batch axis.
    style_img = trans(style_source).to(device)[None, ...]

    # The output folder is always named after the configured style image,
    # even when an explicit `image` was supplied.
    style_name = Path(cfg["global"]["style_img"]).stem

    dataset = DATASET_REGISTRY.get(cfg["dataset"]["name"])(
        **cfg["dataset"][split]["params"],
    )
    save_dir = f'{logfolder}/{cfg["global"]["expname"]}/{out_subdir}/{style_name}'
    os.makedirs(save_dir, exist_ok=True)

    chunk_size = cfg["sampler"]["params"]["chunk_size"]
    common_kwargs = dict(
        N_vis=-1,
        N_samples=-1,
        white_bg=dataset.white_bg,
        ndc_ray=tensorf_cfg["ndc_ray"],
        style_img=style_img,
        device=device,
    )

    if render_mode == "render_path":
        # Path mode renders along the camera poses exposed by the dataset.
        c2ws = dataset.render_path
        return evaluation_feature_path(
            dataset, tensorf, c2ws, renderer, chunk_size, save_dir,
            **common_kwargs,
        )
    return evaluation_feature(
        dataset, tensorf, renderer, chunk_size, save_dir,
        **common_kwargs,
    )


if __name__ == "__main__":
    cfg = Opts(cfg="configs/style_inference.yml").parse_args()
    seed_everything(seed=cfg["global"]["SEED"])
    inference(cfg, "render_test")