Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch_efficient_distloss import flatten_eff_distloss | |
import pytorch_lightning as pl | |
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug | |
import models | |
from models.utils import cleanup | |
from models.ray_utils import get_rays | |
import systems | |
from systems.base import BaseSystem | |
from systems.criterions import PSNR, binary_cross_entropy | |
import pdb | |
def ranking_loss(error, penalize_ratio=0.7, extra_weights=None, type="mean"): | |
error, indices = torch.sort(error) | |
# only sum relatively small errors | |
s_error = torch.index_select( | |
error, 0, index=indices[: int(penalize_ratio * indices.shape[0])] | |
) | |
if extra_weights is not None: | |
weights = torch.index_select( | |
extra_weights, 0, index=indices[: int(penalize_ratio * indices.shape[0])] | |
) | |
s_error = s_error * weights | |
if type == "mean": | |
return torch.mean(s_error) | |
elif type == "sum": | |
return torch.sum(s_error) | |
class PinholeNeuSSystem(BaseSystem): | |
""" | |
Two ways to print to console: | |
1. self.print: correctly handle progress bar | |
2. rank_zero_info: use the logging module | |
""" | |
def prepare(self): | |
self.criterions = {"psnr": PSNR()} | |
self.train_num_samples = self.config.model.train_num_rays * ( | |
self.config.model.num_samples_per_ray | |
+ self.config.model.get("num_samples_per_ray_bg", 0) | |
) | |
self.train_num_rays = self.config.model.train_num_rays | |
self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) | |
def forward(self, batch): | |
return self.model(batch["rays"]) | |
def preprocess_data(self, batch, stage): | |
if "index" in batch: # validation / testing | |
index = batch["index"] | |
else: | |
if self.config.model.batch_image_sampling: | |
index = torch.randint( | |
0, | |
len(self.dataset.all_images), | |
size=(self.train_num_rays,), | |
device=self.dataset.all_images.device, | |
) | |
else: | |
index = torch.randint( | |
0, | |
len(self.dataset.all_images), | |
size=(1,), | |
device=self.dataset.all_images.device, | |
) | |
if stage in ["train"]: | |
c2w = self.dataset.all_c2w[index] | |
x = torch.randint( | |
0, | |
self.dataset.w, | |
size=(self.train_num_rays,), | |
device=self.dataset.all_images.device, | |
) | |
y = torch.randint( | |
0, | |
self.dataset.h, | |
size=(self.train_num_rays,), | |
device=self.dataset.all_images.device, | |
) | |
if self.dataset.directions.ndim == 3: # (H, W, 3) | |
directions = self.dataset.directions[y, x] | |
# origins = self.dataset.origins[y, x] | |
elif self.dataset.directions.ndim == 4: # (N, H, W, 3) | |
directions = self.dataset.directions[index, y, x] | |
# origins = self.dataset.origins[index, y, x] | |
rays_o, rays_d = get_rays(directions, c2w) | |
rgb = ( | |
self.dataset.all_images[index, y, x] | |
.view(-1, self.dataset.all_images.shape[-1]) | |
.to(self.rank) | |
) | |
normal = ( | |
self.dataset.all_normals_world[index, y, x] | |
.view(-1, self.dataset.all_normals_world.shape[-1]) | |
.to(self.rank) | |
) | |
fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) | |
rgb_mask = self.dataset.all_rgb_masks[index, y, x].view(-1).to(self.rank) | |
view_weights = self.dataset.view_weights[index, y, x].view(-1).to(self.rank) | |
else: | |
c2w = self.dataset.all_c2w[index][0] | |
if self.dataset.directions.ndim == 3: # (H, W, 3) | |
directions = self.dataset.directions | |
# origins = self.dataset.origins | |
elif self.dataset.directions.ndim == 4: # (N, H, W, 3) | |
directions = self.dataset.directions[index][0] | |
# origins = self.dataset.origins[index][0] | |
rays_o, rays_d = get_rays(directions, c2w) | |
rgb = ( | |
self.dataset.all_images[index] | |
.view(-1, self.dataset.all_images.shape[-1]) | |
.to(self.rank) | |
) | |
normal = ( | |
self.dataset.all_normals_world[index] | |
.view(-1, self.dataset.all_images.shape[-1]) | |
.to(self.rank) | |
) | |
fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) | |
rgb_mask = self.dataset.all_rgb_masks[index].view(-1).to(self.rank) | |
view_weights = None | |
cosines = self.cos(rays_d, normal) | |
rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) | |
if stage in ["train"]: | |
if self.config.model.background_color == "white": | |
self.model.background_color = torch.ones( | |
(3,), dtype=torch.float32, device=self.rank | |
) | |
elif self.config.model.background_color == "black": | |
self.model.background_color = torch.zeros( | |
(3,), dtype=torch.float32, device=self.rank | |
) | |
elif self.config.model.background_color == "random": | |
self.model.background_color = torch.rand( | |
(3,), dtype=torch.float32, device=self.rank | |
) | |
else: | |
raise NotImplementedError | |
else: | |
self.model.background_color = torch.ones( | |
(3,), dtype=torch.float32, device=self.rank | |
) | |
if self.dataset.apply_mask: | |
rgb = rgb * fg_mask[..., None] + self.model.background_color * ( | |
1 - fg_mask[..., None] | |
) | |
batch.update( | |
{ | |
"rays": rays, | |
"rgb": rgb, | |
"normal": normal, | |
"fg_mask": fg_mask, | |
"rgb_mask": rgb_mask, | |
"cosines": cosines, | |
"view_weights": view_weights, | |
} | |
) | |
def training_step(self, batch, batch_idx): | |
out = self(batch) | |
cosines = batch["cosines"] | |
fg_mask = batch["fg_mask"] | |
rgb_mask = batch["rgb_mask"] | |
view_weights = batch["view_weights"] | |
cosines[cosines > -0.1] = 0 | |
mask = (fg_mask > 0) & (cosines < -0.1) | |
rgb_mask = out["rays_valid_full"][..., 0] & (rgb_mask > 0) | |
grad_cosines = self.cos(batch["rays"][..., 3:], out["comp_normal"]).detach() | |
# grad_cosines = cosines | |
loss = 0.0 | |
# update train_num_rays | |
if self.config.model.dynamic_ray_sampling: | |
train_num_rays = int( | |
self.train_num_rays | |
* (self.train_num_samples / out["num_samples_full"].sum().item()) | |
) | |
self.train_num_rays = min( | |
int(self.train_num_rays * 0.9 + train_num_rays * 0.1), | |
self.config.model.max_train_num_rays, | |
) | |
erros_rgb_mse = F.mse_loss( | |
out["comp_rgb_full"][rgb_mask], batch["rgb"][rgb_mask], reduction="none" | |
) | |
# erros_rgb_mse = erros_rgb_mse * torch.exp(grad_cosines.abs())[:, None][rgb_mask] / torch.exp(grad_cosines.abs()[rgb_mask]).sum() | |
# loss_rgb_mse = ranking_loss(erros_rgb_mse.sum(dim=1), penalize_ratio=0.7, type='sum') | |
loss_rgb_mse = ranking_loss( | |
erros_rgb_mse.sum(dim=1), penalize_ratio=0.7, type="mean" | |
) | |
self.log("train/loss_rgb_mse", loss_rgb_mse, prog_bar=True, rank_zero_only=True) | |
loss += loss_rgb_mse * self.C(self.config.system.loss.lambda_rgb_mse) | |
loss_rgb_l1 = F.l1_loss( | |
out["comp_rgb_full"][rgb_mask], batch["rgb"][rgb_mask], reduction="none" | |
) | |
loss_rgb_l1 = ranking_loss( | |
loss_rgb_l1.sum(dim=1), | |
extra_weights=view_weights[rgb_mask], | |
penalize_ratio=0.8, | |
) | |
self.log("train/loss_rgb", loss_rgb_l1) | |
loss += loss_rgb_l1 * self.C(self.config.system.loss.lambda_rgb_l1) | |
normal_errors = 1 - F.cosine_similarity( | |
out["comp_normal"], batch["normal"], dim=1 | |
) | |
# normal_errors = normal_errors * cosines.abs() / cosines.abs().sum() | |
normal_errors = ( | |
normal_errors * torch.exp(cosines.abs()) / torch.exp(cosines.abs()).sum() | |
) | |
loss_normal = ranking_loss( | |
normal_errors[mask], | |
penalize_ratio=0.8, | |
# extra_weights=view_weights[mask], | |
type="sum", | |
) | |
self.log("train/loss_normal", loss_normal, prog_bar=True, rank_zero_only=True) | |
loss += loss_normal * self.C(self.config.system.loss.lambda_normal) | |
loss_eikonal = ( | |
(torch.linalg.norm(out["sdf_grad_samples"], ord=2, dim=-1) - 1.0) ** 2 | |
).mean() | |
self.log("train/loss_eikonal", loss_eikonal, prog_bar=True, rank_zero_only=True) | |
loss += loss_eikonal * self.C(self.config.system.loss.lambda_eikonal) | |
opacity = torch.clamp(out["opacity"].squeeze(-1), 1.0e-3, 1.0 - 1.0e-3) | |
loss_mask = binary_cross_entropy( | |
opacity, batch["fg_mask"].float(), reduction="none" | |
) | |
loss_mask = ranking_loss( | |
loss_mask, penalize_ratio=0.9, extra_weights=view_weights | |
) | |
self.log("train/loss_mask", loss_mask, prog_bar=True, rank_zero_only=True) | |
loss += loss_mask * ( | |
self.C(self.config.system.loss.lambda_mask) | |
if self.dataset.has_mask | |
else 0.0 | |
) | |
loss_opaque = binary_cross_entropy(opacity, opacity) | |
self.log("train/loss_opaque", loss_opaque) | |
loss += loss_opaque * self.C(self.config.system.loss.lambda_opaque) | |
loss_sparsity = torch.exp( | |
-self.config.system.loss.sparsity_scale * out["random_sdf"].abs() | |
).mean() | |
self.log( | |
"train/loss_sparsity", loss_sparsity, prog_bar=True, rank_zero_only=True | |
) | |
loss += loss_sparsity * self.C(self.config.system.loss.lambda_sparsity) | |
if self.C(self.config.system.loss.lambda_curvature) > 0: | |
assert ( | |
"sdf_laplace_samples" in out | |
), "Need geometry.grad_type='finite_difference' to get SDF Laplace samples" | |
loss_curvature = out["sdf_laplace_samples"].abs().mean() | |
self.log("train/loss_curvature", loss_curvature) | |
loss += loss_curvature * self.C(self.config.system.loss.lambda_curvature) | |
# distortion loss proposed in MipNeRF360 | |
# an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss | |
if self.C(self.config.system.loss.lambda_distortion) > 0: | |
loss_distortion = flatten_eff_distloss( | |
out["weights"], out["points"], out["intervals"], out["ray_indices"] | |
) | |
self.log("train/loss_distortion", loss_distortion) | |
loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) | |
if ( | |
self.config.model.learned_background | |
and self.C(self.config.system.loss.lambda_distortion_bg) > 0 | |
): | |
loss_distortion_bg = flatten_eff_distloss( | |
out["weights_bg"], | |
out["points_bg"], | |
out["intervals_bg"], | |
out["ray_indices_bg"], | |
) | |
self.log("train/loss_distortion_bg", loss_distortion_bg) | |
loss += loss_distortion_bg * self.C( | |
self.config.system.loss.lambda_distortion_bg | |
) | |
if self.C(self.config.system.loss.lambda_3d_normal_smooth) > 0: | |
if "random_sdf_grad" not in out: | |
raise ValueError( | |
"random_sdf_grad is required for normal smooth loss, no normal is found in the output." | |
) | |
if "normal_perturb" not in out: | |
raise ValueError( | |
"normal_perturb is required for normal smooth loss, no normal_perturb is found in the output." | |
) | |
normals_3d = out["random_sdf_grad"] | |
normals_perturb_3d = out["normal_perturb"] | |
loss_3d_normal_smooth = (normals_3d - normals_perturb_3d).abs().mean() | |
self.log( | |
"train/loss_3d_normal_smooth", loss_3d_normal_smooth, prog_bar=True | |
) | |
loss += loss_3d_normal_smooth * self.C( | |
self.config.system.loss.lambda_3d_normal_smooth | |
) | |
losses_model_reg = self.model.regularizations(out) | |
for name, value in losses_model_reg.items(): | |
self.log(f"train/loss_{name}", value) | |
loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) | |
loss += loss_ | |
self.log("train/inv_s", out["inv_s"], prog_bar=True) | |
for name, value in self.config.system.loss.items(): | |
if name.startswith("lambda"): | |
self.log(f"train_params/{name}", self.C(value)) | |
self.log("train/num_rays", float(self.train_num_rays), prog_bar=True) | |
return {"loss": loss} | |
""" | |
# aggregate outputs from different devices (DP) | |
def training_step_end(self, out): | |
pass | |
""" | |
""" | |
# aggregate outputs from different iterations | |
def training_epoch_end(self, out): | |
pass | |
""" | |
def validation_step(self, batch, batch_idx): | |
out = self(batch) | |
psnr = self.criterions["psnr"]( | |
out["comp_rgb_full"].to(batch["rgb"]), batch["rgb"] | |
) | |
W, H = self.dataset.img_wh | |
self.save_image_grid( | |
f"it{self.global_step}-{batch['index'][0].item()}.png", | |
[ | |
{ | |
"type": "rgb", | |
"img": batch["rgb"].view(H, W, 3), | |
"kwargs": {"data_format": "HWC"}, | |
}, | |
{ | |
"type": "rgb", | |
"img": out["comp_rgb_full"].view(H, W, 3), | |
"kwargs": {"data_format": "HWC"}, | |
}, | |
] | |
+ ( | |
[ | |
{ | |
"type": "rgb", | |
"img": out["comp_rgb_bg"].view(H, W, 3), | |
"kwargs": {"data_format": "HWC"}, | |
}, | |
{ | |
"type": "rgb", | |
"img": out["comp_rgb"].view(H, W, 3), | |
"kwargs": {"data_format": "HWC"}, | |
}, | |
] | |
if self.config.model.learned_background | |
else [] | |
) | |
+ [ | |
{"type": "grayscale", "img": out["depth"].view(H, W), "kwargs": {}}, | |
{ | |
"type": "rgb", | |
"img": out["comp_normal"].view(H, W, 3), | |
"kwargs": {"data_format": "HWC", "data_range": (-1, 1)}, | |
}, | |
], | |
) | |
return {"psnr": psnr, "index": batch["index"]} | |
""" | |
# aggregate outputs from different devices when using DP | |
def validation_step_end(self, out): | |
pass | |
""" | |
def validation_epoch_end(self, out): | |
out = self.all_gather(out) | |
if self.trainer.is_global_zero: | |
out_set = {} | |
for step_out in out: | |
# DP | |
if step_out["index"].ndim == 1: | |
out_set[step_out["index"].item()] = {"psnr": step_out["psnr"]} | |
# DDP | |
else: | |
for oi, index in enumerate(step_out["index"]): | |
out_set[index[0].item()] = {"psnr": step_out["psnr"][oi]} | |
psnr = torch.mean(torch.stack([o["psnr"] for o in out_set.values()])) | |
self.log("val/psnr", psnr, prog_bar=True, rank_zero_only=True) | |
self.export() | |
def test_step(self, batch, batch_idx): | |
out = self(batch) | |
psnr = self.criterions["psnr"]( | |
out["comp_rgb_full"].to(batch["rgb"]), batch["rgb"] | |
) | |
W, H = self.dataset.img_wh | |
self.save_image_grid( | |
f"it{self.global_step}-test/{batch['index'][0].item()}.png", | |
[ | |
{ | |
"type": "rgb", | |
"img": batch["rgb"].view(H, W, 3), | |
"kwargs": {"data_format": "HWC"}, | |
}, | |
{ | |
"type": "rgb", | |
"img": out["comp_rgb_full"].view(H, W, 3), | |
"kwargs": {"data_format": "HWC"}, | |
}, | |
] | |
+ ( | |
[ | |
{ | |
"type": "rgb", | |
"img": out["comp_rgb_bg"].view(H, W, 3), | |
"kwargs": {"data_format": "HWC"}, | |
}, | |
{ | |
"type": "rgb", | |
"img": out["comp_rgb"].view(H, W, 3), | |
"kwargs": {"data_format": "HWC"}, | |
}, | |
] | |
if self.config.model.learned_background | |
else [] | |
) | |
+ [ | |
{"type": "grayscale", "img": out["depth"].view(H, W), "kwargs": {}}, | |
{ | |
"type": "rgb", | |
"img": out["comp_normal"].view(H, W, 3), | |
"kwargs": {"data_format": "HWC", "data_range": (-1, 1)}, | |
}, | |
], | |
) | |
return {"psnr": psnr, "index": batch["index"]} | |
def test_epoch_end(self, out): | |
""" | |
Synchronize devices. | |
Generate image sequence using test outputs. | |
""" | |
out = self.all_gather(out) | |
if self.trainer.is_global_zero: | |
out_set = {} | |
for step_out in out: | |
# DP | |
if step_out["index"].ndim == 1: | |
out_set[step_out["index"].item()] = {"psnr": step_out["psnr"]} | |
# DDP | |
else: | |
for oi, index in enumerate(step_out["index"]): | |
out_set[index[0].item()] = {"psnr": step_out["psnr"][oi]} | |
psnr = torch.mean(torch.stack([o["psnr"] for o in out_set.values()])) | |
self.log("test/psnr", psnr, prog_bar=True, rank_zero_only=True) | |
self.save_img_sequence( | |
f"it{self.global_step}-test", | |
f"it{self.global_step}-test", | |
"(\d+)\.png", | |
save_format="mp4", | |
fps=30, | |
) | |
self.export() | |
def export(self): | |
mesh = self.model.export(self.config.export) | |
self.save_mesh( | |
f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", | |
ortho_scale=self.config.export.ortho_scale, | |
**mesh, | |
) | |