Spaces:
Runtime error
Runtime error
import torch | |
from scene import Scene | |
from pathlib import Path | |
from PIL import Image | |
import numpy as np | |
import sys | |
import os | |
from tqdm import tqdm | |
from os import makedirs | |
from gaussian_renderer import render | |
import torchvision | |
from utils.general_utils import safe_state | |
from argparse import ArgumentParser | |
from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams | |
from gaussian_renderer import GaussianModel | |
from mediapy import write_video | |
from tqdm import tqdm | |
from einops import rearrange | |
from utils.colormaps import apply_depth_colormap | |
def render_spiral(dataset, opt, pipe, model_path): | |
gaussians = GaussianModel(dataset.sh_degree) | |
scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False) | |
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] | |
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") | |
viewpoint_stack = scene.getTrainCameras().copy() | |
views = [] | |
for view_cam in tqdm(viewpoint_stack): | |
bg = torch.rand((3), device="cuda") if opt.random_background else background | |
render_pkg = render(view_cam, gaussians, pipe, bg) | |
image, viewspace_point_tensor, visibility_filter, radii = ( | |
render_pkg["depth"], | |
render_pkg["viewspace_points"], | |
render_pkg["visibility_filter"], | |
render_pkg["radii"], | |
) | |
views.append( | |
rearrange( | |
apply_depth_colormap( | |
image[0][..., None], | |
accumulation=render_pkg["alpha"][0][..., None], | |
), | |
"h w c -> c h w", | |
) | |
) | |
views = torch.stack(views) | |
write_video( | |
f"./depth_spirals/{Path(dataset.model_path).stem}.mp4", | |
rearrange(views.cpu().numpy(), "t c h w -> t h w c"), | |
fps=3, | |
) | |
if __name__ == "__main__": | |
# Set up command line argument parser | |
parser = ArgumentParser(description="Training script parameters") | |
lp = ModelParams(parser) | |
op = OptimizationParams(parser) | |
pp = PipelineParams(parser) | |
parser.add_argument("--iteration", default=-1, type=int) | |
parser.add_argument("--skip_train", action="store_true") | |
parser.add_argument("--skip_test", action="store_true") | |
parser.add_argument("--quiet", action="store_true") | |
args = parser.parse_args(sys.argv[1:]) | |
print("Rendering " + args.model_path) | |
lp = lp.extract(args) | |
fake_image = Image.fromarray(np.zeros([512, 512, 3], dtype=np.uint8)) | |
lp.images = [fake_image] * args.num_frames | |
# Initialize system state (RNG) | |
render_spiral( | |
lp, | |
op.extract(args), | |
pp.extract(args), | |
model_path=args.model_path, | |
) | |