import os
from pathlib import Path

import torch
import torchvision.transforms as T
from lightning_fabric import seed_everything
from PIL import Image

from src.dataset import DATASET_REGISTRY
from src.decoder import DECODER_REGISTRY
from src.utils.opt import Opts
from src.utils.renderer import (
    OctreeRender_trilinear_fast,
    evaluation_feature,
    evaluation_feature_path,
)


def inference(cfg, render_mode: str, image=None):
    """Render stylized views from a pretrained style-TensoRF checkpoint.

    render_mode selects the camera set: "render_train", "render_test", or
    "render_path". If `image` (a PIL image) is given, it is used as the style
    reference instead of cfg["global"]["style_img"].
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Rendering on device: {device}")

    # Rebuild the TensoRF decoder from the checkpoint, switch it to the
    # feature/style branches, then load the trained weights.
    ckpt = torch.load(cfg["model"]["tensorf"]["ckpt"], map_location=device)
    kwargs = ckpt["kwargs"]
    kwargs.update({"device": device})
    tensorf = DECODER_REGISTRY.get(cfg["model"]["tensorf"]["model_name"])(**kwargs)
    tensorf.change_to_feature_mod(cfg["model"]["tensorf"]["lamb_sh"], device)
    tensorf.change_to_style_mod(device)
    tensorf.load(ckpt)
    tensorf.eval()
    tensorf.rayMarch_weight_thres = cfg["model"]["tensorf"]["rm_weight_mask_thre"]

    # Root folder for rendered outputs.
    logfolder = "./checkpoints"
    renderer = OctreeRender_trilinear_fast

    # Prepare the style image: resize to 256x256, convert to a tensor, and add
    # a batch dimension. Use the passed-in PIL image if given, otherwise the
    # configured style image; .to(device) keeps CPU-only machines working.
    trans = T.Compose([T.Resize(size=(256, 256)), T.ToTensor()])
    if image is not None:
        style_img = trans(image).to(device)[None, ...]
    else:
        style_img = trans(Image.open(cfg["global"]["style_img"])).to(device)[None, ...]
    style_name = Path(cfg["global"]["style_img"]).stem

    if render_mode == "render_train":
        dataset = DATASET_REGISTRY.get(cfg["dataset"]["name"])(
            **cfg["dataset"]["train"]["params"],
        )
        os.makedirs(f'{logfolder}/{cfg["global"]["expname"]}/imgs_train_all/{style_name}', exist_ok=True)
        result = evaluation_feature(dataset, tensorf, renderer, cfg["sampler"]["params"]["chunk_size"],
                           f'{logfolder}/{cfg["global"]["expname"]}/imgs_train_all/{style_name}',
                           N_vis=-1, N_samples=-1, white_bg=dataset.white_bg, ndc_ray=cfg["model"]["tensorf"]["ndc_ray"],
                           style_img=style_img, device=device)

    if render_mode == "render_test":
        dataset = DATASET_REGISTRY.get(cfg["dataset"]["name"])(
            **cfg["dataset"]["val"]["params"],
        )
        os.makedirs(f'{logfolder}/{cfg["global"]["expname"]}/imgs_train_all/{style_name}', exist_ok=True)
        result = evaluation_feature(dataset, tensorf, renderer, cfg["sampler"]["params"]["chunk_size"],
                           f'{logfolder}/{cfg["global"]["expname"]}/imgs_train_all/{style_name}',
                           N_vis=-1, N_samples=-1, white_bg=dataset.white_bg, ndc_ray=cfg["model"]["tensorf"]["ndc_ray"],
                           style_img=style_img, device=device)

    if render_mode == "render_path":
        dataset = DATASET_REGISTRY.get(cfg["dataset"]["name"])(
            **cfg["dataset"]["val"]["params"],
        )
        c2ws = dataset.render_path
        os.makedirs(f'{logfolder}/{cfg["global"]["expname"]}/imgs_path_all/{style_name}', exist_ok=True)
        result = evaluation_feature_path(dataset, tensorf, c2ws, renderer, cfg["sampler"]["params"]["chunk_size"],
                                f'{logfolder}/{cfg["global"]["expname"]}/imgs_path_all/{style_name}',
                                N_vis=-1, N_samples=-1, white_bg=dataset.white_bg, ndc_ray=cfg["model"]["tensorf"]["ndc_ray"],
                                style_img=style_img, device=device)
    return result

if __name__ == "__main__":
    cfg = Opts(cfg="configs/style_inference.yml").parse_args()
    seed_everything(seed=cfg["global"]["SEED"])
    inference(cfg, "render_test")