Spaces:
Runtime error
Runtime error
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import os | |
import torch | |
from random import randint | |
from PIL import Image | |
from mediapy import read_video | |
from utils.loss_utils import l1_loss, ssim, lpips | |
from gaussian_renderer import render, network_gui | |
import sys | |
from scene import Scene, GaussianModel | |
from utils.general_utils import safe_state | |
import uuid | |
from tqdm import tqdm | |
from utils.image_utils import psnr | |
from argparse import ArgumentParser, Namespace | |
from arguments import ModelParams, PipelineParams, OptimizationParams | |
from scripts.sampling.simple_mv_latent_sample import sample_one | |
# TensorBoard is optional: when it cannot be imported, prepare_output_and_logger
# returns no writer and all logging calls guarded by `if tb_writer` are skipped.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    TENSORBOARD_FOUND = False
else:
    TENSORBOARD_FOUND = True
def _serve_network_gui(iteration, dataset, opt, pipe, gaussians, background):
    """Service the optional remote viewer before one training iteration.

    Tries to connect when no viewer is attached. While a viewer is connected,
    this loops: it receives a request, renders ``custom_cam`` when one is sent,
    and replies with the image bytes (or ``None``). It returns once the viewer
    asks for training to continue, or once the connection is dropped. Any
    error while talking to the viewer drops the connection instead of aborting
    training.
    """
    if network_gui.conn is None:
        network_gui.try_connect()
    while network_gui.conn is not None:
        try:
            net_image_bytes = None
            # The viewer can also toggle the Python-side SH / covariance paths
            # directly on `pipe`.
            (
                custom_cam,
                do_training,
                pipe.convert_SHs_python,
                pipe.compute_cov3D_python,
                keep_alive,
                scaling_modifier,
            ) = network_gui.receive()
            if custom_cam is not None:
                net_image = render(
                    custom_cam, gaussians, pipe, background, scaling_modifier
                )["render"]
                # Quantize to uint8; permute(1, 2, 0) moves the leading axis
                # last (presumably CHW -> HWC for the viewer — confirm).
                net_image_bytes = memoryview(
                    (torch.clamp(net_image, min=0, max=1.0) * 255)
                    .byte()
                    .permute(1, 2, 0)
                    .contiguous()
                    .cpu()
                    .numpy()
                )
            network_gui.send(net_image_bytes, dataset.source_path)
            if do_training and (
                (iteration < int(opt.iterations)) or not keep_alive
            ):
                return
        except Exception:
            # Best effort: a misbehaving viewer must never kill a training run.
            network_gui.conn = None


def training(
    dataset,
    opt,
    pipe,
    testing_iterations,
    saving_iterations,
    checkpoint_iterations,
    checkpoint,
    debug_from,
):
    """Optimize the Gaussian model of a scene against its training cameras.

    Args:
        dataset: extracted model parameters. They are used to build the
            ``Scene`` and to set up the output folder. ``sh_degree``,
            ``white_background`` and ``source_path`` are read here.
        opt: optimization parameters: total ``iterations``, loss weights
            (``lambda_dssim``, ``lambda_lpips``), ``random_background`` and the
            densification / opacity-reset schedule.
        pipe: render pipeline options forwarded to ``render``. ``pipe.debug``
            is switched on from iteration ``debug_from + 1`` onward.
        testing_iterations: iterations at which ``training_report`` evaluates
            the test cameras and a sample of the train cameras.
        saving_iterations: iterations at which ``scene.save`` is called.
        checkpoint_iterations: iterations at which a resumable
            ``(gaussians.capture(), iteration)`` checkpoint is written.
        checkpoint: optional path to such a checkpoint to resume from.
        debug_from: see ``pipe`` above; ``-1`` never enables debugging.
    """
    first_iter = 0
    tb_writer = prepare_output_and_logger(dataset)
    gaussians = GaussianModel(dataset.sh_degree)
    scene = Scene(dataset, gaussians)
    gaussians.training_setup(opt)
    if checkpoint:
        # NOTE(review): torch.load unpickles the file, so only resume from
        # checkpoints produced by a trusted run.
        (model_params, first_iter) = torch.load(checkpoint)
        gaussians.restore(model_params, opt)

    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
    first_iter += 1
    for iteration in range(first_iter, opt.iterations + 1):
        _serve_network_gui(iteration, dataset, opt, pipe, gaussians, background)

        iter_start.record()
        gaussians.update_learning_rate(iteration)

        # Every 1000 its we increase the levels of SH up to a maximum degree
        if iteration % 1000 == 0:
            gaussians.oneupSHdegree()

        # Pick a random camera without replacement. Refill the stack once
        # every training camera has been used.
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))

        # Render
        if (iteration - 1) == debug_from:
            pipe.debug = True
        bg = torch.rand((3), device="cuda") if opt.random_background else background
        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
        image, viewspace_point_tensor, visibility_filter, radii = (
            render_pkg["render"],
            render_pkg["viewspace_points"],
            render_pkg["visibility_filter"],
            render_pkg["radii"],
        )

        # Loss: L1 / D-SSIM blend, optional LPIPS term, opacity penalty.
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (
            1.0 - ssim(image, gt_image)
        )
        if opt.lambda_lpips > 0:
            loss += opt.lambda_lpips * lpips(image, gt_image)
        # Fixed-weight (0.1) penalty on mean opacity; presumably pushes
        # Gaussians toward transparency so they can be pruned — confirm.
        loss += torch.mean(gaussians.get_opacity) * 0.1
        loss.backward()

        iter_end.record()

        with torch.no_grad():
            # Progress bar
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            # Log and save
            training_report(
                tb_writer,
                iteration,
                Ll1,
                loss,
                l1_loss,
                iter_start.elapsed_time(iter_end),
                testing_iterations,
                scene,
                render,
                (pipe, background),
            )
            if iteration in saving_iterations:
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                scene.save(iteration)

            # Densification
            if iteration < opt.densify_until_iter:
                # Keep track of max radii in image-space for pruning
                gaussians.max_radii2D[visibility_filter] = torch.max(
                    gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
                )
                gaussians.add_densification_stats(
                    viewspace_point_tensor, visibility_filter
                )

                if (
                    iteration > opt.densify_from_iter
                    and iteration % opt.densification_interval == 0
                ):
                    # Screen-size pruning only kicks in after the first opacity reset.
                    size_threshold = (
                        20 if iteration > opt.opacity_reset_interval else None
                    )
                    # 0.005 is densify_and_prune's second positional argument
                    # (a minimum opacity in upstream 3DGS — confirm).
                    gaussians.densify_and_prune(
                        opt.densify_grad_threshold,
                        0.005,
                        scene.cameras_extent,
                        size_threshold,
                    )

                if iteration % opt.opacity_reset_interval == 0 or (
                    dataset.white_background and iteration == opt.densify_from_iter
                ):
                    gaussians.reset_opacity()

            # Optimizer step (skipped on the final iteration).
            if iteration < opt.iterations:
                gaussians.optimizer.step()
                gaussians.optimizer.zero_grad(set_to_none=True)

            if iteration in checkpoint_iterations:
                print("\n[ITER {}] Saving Checkpoint".format(iteration))
                torch.save(
                    (gaussians.capture(), iteration),
                    os.path.join(scene.model_path, "chkpnt" + str(iteration) + ".pth"),
                )
def _literal_config(args):
    """Return ``vars(args)`` restricted to values whose repr is a Python literal."""
    import ast

    kept = {}
    for key, value in vars(args).items():
        try:
            ast.literal_eval(repr(value))
        except (ValueError, TypeError, SyntaxError):
            # e.g. in-memory PIL images: repr is "<PIL.Image... at 0x...>".
            continue
        kept[key] = value
    return kept


def prepare_output_and_logger(args):
    """Create the output folder, record the run config and open a TensorBoard writer.

    When ``args.model_path`` is empty, a folder under ``./output/`` is named
    from the first 10 characters of ``$OAR_JOB_ID`` (if set) or a random
    UUID. That path is written back to ``args.model_path``.

    The configuration is saved to ``<model_path>/cfg_args`` as the repr of an
    ``argparse.Namespace``. Only attributes whose repr is a valid Python
    literal are recorded. In-memory objects, such as the list of PIL frames
    attached as ``args.images`` by this script, would otherwise make the file
    unparseable and non-deterministic.

    Returns:
        A ``SummaryWriter`` logging to ``args.model_path``, or ``None`` when
        TensorBoard is unavailable.
    """
    if not args.model_path:
        unique_str = os.getenv("OAR_JOB_ID") or str(uuid.uuid4())
        args.model_path = os.path.join("./output/", unique_str[0:10])

    # Set up output folder
    print("Output folder: {}".format(args.model_path))
    os.makedirs(args.model_path, exist_ok=True)
    with open(
        os.path.join(args.model_path, "cfg_args"), "w", encoding="utf-8"
    ) as cfg_log_f:
        cfg_log_f.write(str(Namespace(**_literal_config(args))))

    # Create Tensorboard writer
    tb_writer = None
    if TENSORBOARD_FOUND:
        tb_writer = SummaryWriter(args.model_path)
    else:
        print("Tensorboard not available: not logging progress")
    return tb_writer
def training_report(
    tb_writer,
    iteration,
    Ll1,
    loss,
    l1_loss,
    elapsed,
    testing_iterations,
    scene: Scene,
    renderFunc,
    renderArgs,
):
    """Log training scalars every call and run an evaluation on test iterations.

    The training scalars are the L1 loss, the total loss and the iteration
    time; they go to ``tb_writer`` when one is given. When ``iteration`` is in
    ``testing_iterations``, the evaluation renders every test camera plus five
    train cameras (indices 5, 10, ..., 25, modulo the camera count) with
    ``renderFunc(viewpoint, scene.gaussians, *renderArgs)``. It prints the mean
    L1 and PSNR per split and logs render / ground-truth images (ground truth
    only on the first testing iteration), scalars, an opacity histogram and
    the point count.
    """
    if tb_writer:
        for tag, value in (
            ("train_loss_patches/l1_loss", Ll1.item()),
            ("train_loss_patches/total_loss", loss.item()),
            ("iter_time", elapsed),
        ):
            tb_writer.add_scalar(tag, value, iteration)

    # Everything below only runs on evaluation iterations.
    if iteration not in testing_iterations:
        return

    torch.cuda.empty_cache()
    test_cameras = scene.getTestCameras()
    train_cameras = scene.getTrainCameras()
    sampled_train = [
        train_cameras[k % len(train_cameras)] for k in range(5, 30, 5)
    ]
    validation_configs = (
        {"name": "test", "cameras": test_cameras},
        {"name": "train", "cameras": sampled_train},
    )

    for config in validation_configs:
        cameras = config["cameras"]
        if not cameras or not len(cameras) > 0:
            continue

        l1_sum = 0.0
        psnr_sum = 0.0
        for position, viewpoint in enumerate(cameras):
            rendered = renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"]
            image = torch.clamp(rendered, 0.0, 1.0)
            gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)

            if tb_writer and (position < 5):
                tb_writer.add_images(
                    config["name"]
                    + "_view_{}/render".format(viewpoint.image_name),
                    image[None],
                    global_step=iteration,
                )
                # Ground truth never changes, so it is only logged once.
                if iteration == testing_iterations[0]:
                    tb_writer.add_images(
                        config["name"]
                        + "_view_{}/ground_truth".format(viewpoint.image_name),
                        gt_image[None],
                        global_step=iteration,
                    )

            l1_sum += l1_loss(image, gt_image).mean().double()
            psnr_sum += psnr(image, gt_image).mean().double()

        camera_count = len(cameras)
        psnr_test = psnr_sum / camera_count
        l1_test = l1_sum / camera_count
        print(
            "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(
                iteration, config["name"], l1_test, psnr_test
            )
        )
        if tb_writer:
            tb_writer.add_scalar(
                config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration
            )
            tb_writer.add_scalar(
                config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration
            )

    if tb_writer:
        tb_writer.add_histogram(
            "scene/opacity_histogram", scene.gaussians.get_opacity, iteration
        )
        tb_writer.add_scalar(
            "total_points", scene.gaussians.get_xyz.shape[0], iteration
        )
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Training script parameters")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--video", type=str, default="")
    parser.add_argument("--ip", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=6009)
    parser.add_argument("--debug_from", type=int, default=-1)
    parser.add_argument("--detect_anomaly", action="store_true", default=False)
    parser.add_argument(
        "--test_iterations", nargs="+", type=int, default=[7_000, 30_000]
    )
    parser.add_argument(
        "--save_iterations", nargs="+", type=int, default=[7_000, 30_000]
    )
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--start_checkpoint", type=str, default=None)
    # The next four options (and --seed) only fed the disabled sample_one call below.
    parser.add_argument("--border_ratio", type=float, default=0.3)
    parser.add_argument("--min_guidance_scale", type=float, default=1.0)
    parser.add_argument("--max_guidance_scale", type=float, default=2.5)
    parser.add_argument("--sigma_max", type=float, default=None)
    args = parser.parse_args(sys.argv[1:])

    # The training images come exclusively from --video. Without it,
    # read_video("") would fail with an obscure error deep in the reader, so
    # fail fast with a usage message instead. (No file-existence check:
    # read_video may also accept URLs.)
    if not args.video:
        parser.error("--video is required (path to the multi-view video to fit)")

    # Always save the final iteration.
    args.save_iterations.append(args.iterations)

    print("Optimizing " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    # Start GUI server, configure and run training
    network_gui.init(args.ip, args.port)
    torch.autograd.set_detect_anomaly(args.detect_anomaly)

    print("=====Start generating MV Images=====")
    # NOTE: multi-view generation via
    #   sample_one(args.image, args.ckpt_path, seed=args.seed,
    #              border_ratio=args.border_ratio,
    #              min_guidance_scale=args.min_guidance_scale,
    #              max_guidance_scale=args.max_guidance_scale,
    #              sigma_max=args.sigma_max)
    # is disabled; the frames of --video are used instead. Re-enabling it
    # would also require --image and --ckpt_path arguments, which this
    # parser does not define.
    images = [Image.fromarray(frame) for frame in read_video(args.video)]
    print("=====Finish generating MV Images=====")

    # Extract into a new name rather than shadowing the ModelParams object.
    dataset = lp.extract(args)
    dataset.images = images
    training(
        dataset,
        op.extract(args),
        pp.extract(args),
        args.test_iterations,
        args.save_iterations,
        args.checkpoint_iterations,
        args.start_checkpoint,
        args.debug_from,
    )

    # All done
    print("\nTraining complete.")