File size: 2,728 Bytes
cfb7702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from scene import Scene
from pathlib import Path
from PIL import Image
import numpy as np
import sys
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
from gaussian_renderer import GaussianModel
from mediapy import write_video
from tqdm import tqdm
from einops import rearrange
from utils.colormaps import apply_depth_colormap


@torch.no_grad()
def render_spiral(dataset, opt, pipe, model_path):
    """Render a per-camera depth-colormap video for a trained Gaussian scene.

    Loads the latest checkpoint of the scene, renders the *depth* channel for
    every training camera, colorizes it, and writes the frames to
    ``./depth_spirals/<model_name>.mp4``.

    Args:
        dataset: extracted ``ModelParams`` namespace (``sh_degree``,
            ``white_background``, ``model_path`` are read here).
        opt: extracted ``OptimizationParams`` namespace (``random_background``
            is read here).
        pipe: extracted ``PipelineParams`` namespace, forwarded to ``render``.
        model_path: unused; the output name is derived from
            ``dataset.model_path`` instead (kept for interface compatibility).
    """
    gaussians = GaussianModel(dataset.sh_degree)
    # load_iteration=-1 selects the most recent saved iteration.
    scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False)
    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
    viewpoint_stack = scene.getTrainCameras().copy()
    views = []
    for view_cam in tqdm(viewpoint_stack):
        bg = torch.rand((3), device="cuda") if opt.random_background else background
        render_pkg = render(view_cam, gaussians, pipe, bg)
        # Only the depth and alpha channels are needed; the other entries of
        # render_pkg (viewspace_points, visibility_filter, radii) are training
        # artifacts and are not used when rendering a video.
        depth = render_pkg["depth"]
        views.append(
            rearrange(
                apply_depth_colormap(
                    # depth[0] is (h, w); add a trailing channel axis for the
                    # colormap helper.
                    depth[0][..., None],
                    accumulation=render_pkg["alpha"][0][..., None],
                ),
                "h w c -> c h w",
            )
        )
    views = torch.stack(views)

    # Bug fix: the output directory may not exist yet; write_video does not
    # create intermediate directories.
    makedirs("./depth_spirals", exist_ok=True)
    write_video(
        f"./depth_spirals/{Path(dataset.model_path).stem}.mp4",
        rearrange(views.cpu().numpy(), "t c h w -> t h w c"),
        fps=3,
    )


if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Training script parameters")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    args = parser.parse_args(sys.argv[1:])
    print("Rendering " + args.model_path)
    lp = lp.extract(args)
    # Real images are not needed for rendering; stand in with blank frames so
    # Scene construction does not try to load the training set from disk.
    # NOTE(review): `args.num_frames` is not defined by this parser — it is
    # presumably added by ModelParams; verify against the arguments module.
    fake_image = Image.fromarray(np.zeros([512, 512, 3], dtype=np.uint8))
    lp.images = [fake_image] * args.num_frames

    # Initialize system state (RNG)
    # Bug fix: safe_state was imported and --quiet was parsed, but the call
    # was missing, leaving the RNG seed and quiet mode uninitialized.
    safe_state(args.quiet)

    render_spiral(
        lp,
        op.extract(args),
        pp.extract(args),
        model_path=args.model_path,
    )