# Spaces:
# Runtime error
# Runtime error
import os
from pathlib import Path

import torch
import torchvision.transforms as T
from lightning_fabric import seed_everything
from PIL import Image, ImageFile

from src.dataset import DATASET_REGISTRY
from src.decoder import DECODER_REGISTRY
from src.utils.opt import Opts
from src.utils.renderer import (
    OctreeRender_trilinear_fast,
    evaluation_feature,
    evaluation_feature_path,
)
def inference(cfg, render_mode: str, image=None):
    """Render a stylized scene with a pretrained TensoRF-style decoder.

    Args:
        cfg: Parsed config mapping; reads the "model", "dataset", "sampler"
            and "global" sections exactly as accessed below.
        render_mode: One of "render_train", "render_test" or "render_path";
            selects the dataset split / camera path to render.
        image: Optional PIL image to use as the style image. When omitted,
            the image at cfg["global"]["style_img"] is loaded instead.

    Returns:
        The value returned by the evaluation helper for the chosen mode.

    Raises:
        ValueError: If ``render_mode`` is not a supported mode (the original
            code fell through and raised NameError on ``result``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Rebuild the decoder from the checkpoint's stored constructor kwargs,
    # then switch it into feature/style mode and load the weights.
    ckpt = torch.load(cfg["model"]["tensorf"]["ckpt"], map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    print(device)
    tensorf = DECODER_REGISTRY.get(cfg["model"]["tensorf"]["model_name"])(**kwargs)
    tensorf.change_to_feature_mod(cfg["model"]["tensorf"]["lamb_sh"], device)
    tensorf.change_to_style_mod(device)
    tensorf.load(ckpt)
    tensorf.eval()
    tensorf.rayMarch_weight_thres = cfg["model"]["tensorf"]["rm_weight_mask_thre"]

    # NOTE(review): dirname("./checkpoints") evaluates to "." — outputs land in
    # the current directory; confirm whether "./checkpoints" itself was intended.
    logfolder = os.path.dirname("./checkpoints")
    renderer = OctreeRender_trilinear_fast

    # Style image: caller-supplied PIL image, or the configured file path.
    trans = T.Compose([T.Resize(size=(256, 256)), T.ToTensor()])
    src_img = image if image else Image.open(cfg["global"]["style_img"])
    # Bug fix: the config-path branch previously called .cuda() unconditionally,
    # crashing on CPU-only hosts; .to(device) handles both cases.
    style_img = trans(src_img).to(device)[None, ...]
    # The output-folder name always comes from the configured style path, even
    # when `image` is supplied (kept for backward compatibility).
    style_name = Path(cfg["global"]["style_img"]).stem

    if render_mode == "render_train":
        dataset = DATASET_REGISTRY.get(cfg["dataset"]["name"])(
            **cfg["dataset"]["train"]["params"],
        )
        out_dir = f'{logfolder}/{cfg["global"]["expname"]}/imgs_train_all/{style_name}'
        os.makedirs(out_dir, exist_ok=True)
        result = evaluation_feature(
            dataset, tensorf, renderer, cfg["sampler"]["params"]["chunk_size"],
            out_dir, N_vis=-1, N_samples=-1, white_bg=dataset.white_bg,
            ndc_ray=cfg["model"]["tensorf"]["ndc_ray"],
            style_img=style_img, device=device,
        )
    elif render_mode == "render_test":
        dataset = DATASET_REGISTRY.get(cfg["dataset"]["name"])(
            **cfg["dataset"]["val"]["params"],
        )
        # NOTE(review): the test split writes into imgs_train_all/ — verify
        # this directory name is intentional and not a copy-paste slip.
        out_dir = f'{logfolder}/{cfg["global"]["expname"]}/imgs_train_all/{style_name}'
        os.makedirs(out_dir, exist_ok=True)
        result = evaluation_feature(
            dataset, tensorf, renderer, cfg["sampler"]["params"]["chunk_size"],
            out_dir, N_vis=-1, N_samples=-1, white_bg=dataset.white_bg,
            ndc_ray=cfg["model"]["tensorf"]["ndc_ray"],
            style_img=style_img, device=device,
        )
    elif render_mode == "render_path":
        dataset = DATASET_REGISTRY.get(cfg["dataset"]["name"])(
            **cfg["dataset"]["val"]["params"],
        )
        c2ws = dataset.render_path
        out_dir = f'{logfolder}/{cfg["global"]["expname"]}/imgs_path_all/{style_name}'
        os.makedirs(out_dir, exist_ok=True)
        result = evaluation_feature_path(
            dataset, tensorf, c2ws, renderer, cfg["sampler"]["params"]["chunk_size"],
            out_dir, N_vis=-1, N_samples=-1, white_bg=dataset.white_bg,
            ndc_ray=cfg["model"]["tensorf"]["ndc_ray"],
            style_img=style_img, device=device,
        )
    else:
        # Fail loudly instead of the original NameError on `result`.
        raise ValueError(f"Unknown render_mode: {render_mode!r}")
    return result
if __name__ == "__main__": | |
cfg = Opts(cfg="configs/style_inference.yml").parse_args() | |
seed_everything(seed=cfg["global"]["SEED"]) | |
inference(cfg, "render_test") | |