# Origin: Hugging Face file view — author AnTo2209, commit da29c37 ("add inference"), 3.67 kB.
import os
from pathlib import Path
import torch
from lightning_fabric import seed_everything
from PIL import Image, ImageFile
from src.dataset import DATASET_REGISTRY
from src.decoder import DECODER_REGISTRY
from src.utils.opt import Opts
import torchvision.transforms as T
from src.utils.renderer import evaluation_feature, evaluation_feature_path, OctreeRender_trilinear_fast
# Supported render modes -> (split key under cfg["dataset"], output sub-directory,
# whether to render along ``dataset.render_path`` instead of the dataset's own views).
_RENDER_MODES = {
    "render_train": ("train", "imgs_train_all", False),
    # NOTE(review): "render_test" saves into "imgs_train_all". "imgs_test_all" was
    # probably intended, but the directory is kept so anything reading those
    # files keeps working. Confirm before changing.
    "render_test": ("val", "imgs_train_all", False),
    "render_path": ("val", "imgs_path_all", True),
}


def _load_tensorf(cfg, device):
    """Build the decoder named in ``cfg`` and restore it from the configured checkpoint.

    Calls ``change_to_feature_mod`` / ``change_to_style_mod`` before loading the
    checkpoint state, switches to eval mode, and sets ``rayMarch_weight_thres``
    from ``cfg["model"]["tensorf"]["rm_weight_mask_thre"]``.
    """
    tensorf_cfg = cfg["model"]["tensorf"]
    # NOTE: torch.load unpickles the file, so the configured checkpoint must be trusted.
    ckpt = torch.load(tensorf_cfg["ckpt"], map_location=device)
    kwargs = ckpt["kwargs"]
    kwargs.update({"device": device})
    print(device)
    tensorf = DECODER_REGISTRY.get(tensorf_cfg["model_name"])(**kwargs)
    tensorf.change_to_feature_mod(tensorf_cfg["lamb_sh"], device)
    tensorf.change_to_style_mod(device)
    tensorf.load(ckpt)
    tensorf.eval()
    tensorf.rayMarch_weight_thres = tensorf_cfg["rm_weight_mask_thre"]
    return tensorf


def _load_style_image(cfg, image, device):
    """Return the style image resized to 256x256 as a tensor with a leading batch axis of 1.

    Uses ``image`` when it is not ``None``; otherwise reads the file at
    ``cfg["global"]["style_img"]``. The tensor is placed on ``device``.
    """
    trans = T.Compose([T.Resize(size=(256, 256)), T.ToTensor()])
    if image is not None:
        style = trans(image)
    else:
        # Context manager releases the file handle that Image.open keeps open;
        # Resize loads the pixel data while the file is still open.
        with Image.open(cfg["global"]["style_img"]) as img:
            style = trans(img)
    # .to(device) works on CPU-only hosts as well as CUDA ones.
    return style.to(device)[None, ...]


def inference(cfg, render_mode: str, image=None):
    """Render views with the checkpointed decoder conditioned on a style image.

    The output directory is
    ``./<cfg["global"]["expname"]>/<subdir>/<style file stem>``. It is created if
    missing and passed to the evaluation helper (presumably renders are written
    there — not visible from this file).

    Args:
        cfg: Nested config mapping. Reads ``model.tensorf`` (ckpt, model_name,
            lamb_sh, rm_weight_mask_thre, ndc_ray), ``global`` (style_img,
            expname), ``dataset`` (name and ``train``/``val`` params) and
            ``sampler.params.chunk_size``.
        render_mode: ``"render_train"`` uses the train dataset params,
            ``"render_test"`` the val params, and ``"render_path"`` the val
            dataset's ``render_path`` camera poses.
        image: Optional style image (presumably a PIL image — anything
            torchvision's ``Resize``/``ToTensor`` accept). When ``None``, the
            file at ``cfg["global"]["style_img"]`` is loaded instead.

    Returns:
        Whatever ``evaluation_feature`` / ``evaluation_feature_path`` return.

    Raises:
        ValueError: if ``render_mode`` is not a supported mode.
    """
    # Validate up front so an unsupported mode fails fast with a clear error,
    # before the expensive checkpoint load.
    if render_mode not in _RENDER_MODES:
        raise ValueError(
            f"unknown render_mode {render_mode!r}; "
            f"expected one of {sorted(_RENDER_MODES)}"
        )
    split, subdir, along_path = _RENDER_MODES[render_mode]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensorf = _load_tensorf(cfg, device)

    # Outputs go under the current directory (os.path.dirname("./checkpoints") == ".").
    # NOTE(review): "./checkpoints" itself may have been the intended root.
    logfolder = "."
    renderer = OctreeRender_trilinear_fast

    style_img = _load_style_image(cfg, image, device)
    # The folder is named after the configured style file, even when an
    # in-memory ``image`` is supplied.
    style_name = Path(cfg["global"]["style_img"]).stem

    dataset = DATASET_REGISTRY.get(cfg["dataset"]["name"])(
        **cfg["dataset"][split]["params"],
    )
    c2ws = dataset.render_path if along_path else None
    save_path = f'{logfolder}/{cfg["global"]["expname"]}/{subdir}/{style_name}'
    os.makedirs(save_path, exist_ok=True)

    chunk_size = cfg["sampler"]["params"]["chunk_size"]
    render_kwargs = dict(
        N_vis=-1,
        N_samples=-1,
        white_bg=dataset.white_bg,
        ndc_ray=cfg["model"]["tensorf"]["ndc_ray"],
        style_img=style_img,
        device=device,
    )
    if along_path:
        return evaluation_feature_path(
            dataset, tensorf, c2ws, renderer, chunk_size, save_path, **render_kwargs
        )
    return evaluation_feature(
        dataset, tensorf, renderer, chunk_size, save_path, **render_kwargs
    )
if __name__ == "__main__":
    # Script entry: load the inference config, fix the RNG seed, and render the
    # "render_test" mode.
    opts = Opts(cfg="configs/style_inference.yml")
    cfg = opts.parse_args()
    seed = cfg["global"]["SEED"]
    seed_everything(seed=seed)
    inference(cfg, render_mode="render_test")