# V3D: mesh_recon/systems/neus_pinhole.py
# Author: heheyas, commit cfb7702 ("init").
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_efficient_distloss import flatten_eff_distloss
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug
import models
from models.utils import cleanup
from models.ray_utils import get_rays
import systems
from systems.base import BaseSystem
from systems.criterions import PSNR, binary_cross_entropy
import pdb
def ranking_loss(error, penalize_ratio=0.7, extra_weights=None, type="mean"):
    """Robust reduction that only keeps the smallest errors.

    The ``int(penalize_ratio * len(error))`` smallest entries of ``error`` are
    kept. The remaining (largest) entries are treated as outliers and ignored.

    Args:
        error: per-element errors. Expected to be 1-D: sorting happens along
            the last dim while selection happens along dim 0.
        penalize_ratio: fraction of elements to keep.
        extra_weights: optional tensor aligned with ``error``. Each kept error
            is multiplied by the weight at the same (original) position.
        type: ``"mean"`` or ``"sum"`` reduction over the kept, weighted errors.

    Returns:
        A scalar tensor. With ``"mean"`` and nothing kept (e.g. empty input)
        this is NaN, as ``torch.mean`` returns for an empty tensor.

    Raises:
        ValueError: if ``type`` is neither ``"mean"`` nor ``"sum"``.
    """
    num_kept = int(penalize_ratio * error.shape[0])
    # Positions (in the original tensor) of the smallest errors, ascending.
    kept = torch.sort(error).indices[:num_kept]
    # Index the ORIGINAL tensor with the sort indices. Indexing the already
    # sorted tensor with them would pick arbitrary elements and misalign
    # them with ``extra_weights``.
    s_error = torch.index_select(error, 0, kept)
    if extra_weights is not None:
        s_error = s_error * torch.index_select(extra_weights, 0, kept)
    if type == "mean":
        return torch.mean(s_error)
    if type == "sum":
        return torch.sum(s_error)
    raise ValueError(
        f"Unsupported reduction type: {type!r} (expected 'mean' or 'sum')"
    )
@systems.register("pinhole-neus-system")
class PinholeNeuSSystem(BaseSystem):
    """NeuS system registered as ``"pinhole-neus-system"``.

    Supervises the model with per-pixel colors, world-space normals and
    foreground masks sampled from ``self.dataset``. Renders validation/test
    images and exports a mesh at the end of validation and test epochs.

    Two ways to print to console:
    1. self.print: correctly handle progress bar
    2. rank_zero_info: use the logging module
    """

    def prepare(self):
        """Create the metrics, the initial ray budget and the cosine helper."""
        self.criterions = {"psnr": PSNR()}
        # Target number of samples per training step: rays x samples per ray,
        # including optional background samples. Drives dynamic ray sampling
        # in ``training_step``.
        self.train_num_samples = self.config.model.train_num_rays * (
            self.config.model.num_samples_per_ray
            + self.config.model.get("num_samples_per_ray_bg", 0)
        )
        self.train_num_rays = self.config.model.train_num_rays
        self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)

    def forward(self, batch):
        """Render ``batch["rays"]`` with the model and return its output dict."""
        return self.model(batch["rays"])

    def preprocess_data(self, batch, stage):
        """Add rays and supervision targets for ``stage`` to ``batch`` in place.

        Image indices come from ``batch["index"]`` when present. Otherwise they
        are drawn at random: one image per ray if
        ``config.model.batch_image_sampling`` is set, else a single image for
        all rays.

        In the ``"train"`` stage, ``self.train_num_rays`` random pixels are
        sampled. In any other stage the full image(s) at ``index`` are used,
        with the camera taken from ``index[0]``.

        Keys added:
            ``rays``: origins concatenated with unit directions.
            ``rgb``, ``normal``, ``fg_mask``, ``rgb_mask``.
            ``cosines``: cosine similarity between ray direction and target
                normal.
            ``view_weights``: ``None`` outside training.

        Also sets ``self.model.background_color``.

        Raises:
            ValueError: if ``dataset.directions`` is neither 3-D nor 4-D.
            NotImplementedError: for an unknown
                ``config.model.background_color`` during training.
        """
        dataset = self.dataset
        device = dataset.all_images.device
        directions_ndim = dataset.directions.ndim
        if directions_ndim not in (3, 4):
            raise ValueError(
                "dataset.directions must be (H, W, 3) or (N, H, W, 3), "
                f"got a {directions_ndim}-D tensor"
            )

        if "index" in batch:  # validation / testing
            index = batch["index"]
        else:
            num_images = (
                self.train_num_rays if self.config.model.batch_image_sampling else 1
            )
            index = torch.randint(
                0, len(dataset.all_images), size=(num_images,), device=device
            )

        if stage in ["train"]:
            c2w = dataset.all_c2w[index]
            x = torch.randint(0, dataset.w, size=(self.train_num_rays,), device=device)
            y = torch.randint(0, dataset.h, size=(self.train_num_rays,), device=device)
            if directions_ndim == 3:  # (H, W, 3)
                directions = dataset.directions[y, x]
            else:  # (N, H, W, 3)
                directions = dataset.directions[index, y, x]
            rays_o, rays_d = get_rays(directions, c2w)
            rgb = (
                dataset.all_images[index, y, x]
                .view(-1, dataset.all_images.shape[-1])
                .to(self.rank)
            )
            normal = (
                dataset.all_normals_world[index, y, x]
                .view(-1, dataset.all_normals_world.shape[-1])
                .to(self.rank)
            )
            fg_mask = dataset.all_fg_masks[index, y, x].view(-1).to(self.rank)
            rgb_mask = dataset.all_rgb_masks[index, y, x].view(-1).to(self.rank)
            view_weights = dataset.view_weights[index, y, x].view(-1).to(self.rank)
        else:
            c2w = dataset.all_c2w[index][0]
            if directions_ndim == 3:  # (H, W, 3)
                directions = dataset.directions
            else:  # (N, H, W, 3)
                directions = dataset.directions[index][0]
            rays_o, rays_d = get_rays(directions, c2w)
            rgb = (
                dataset.all_images[index]
                .view(-1, dataset.all_images.shape[-1])
                .to(self.rank)
            )
            # Reshape normals by their own channel count (as in the train
            # branch). The image channel count mis-shapes them whenever the
            # two differ.
            normal = (
                dataset.all_normals_world[index]
                .view(-1, dataset.all_normals_world.shape[-1])
                .to(self.rank)
            )
            fg_mask = dataset.all_fg_masks[index].view(-1).to(self.rank)
            rgb_mask = dataset.all_rgb_masks[index].view(-1).to(self.rank)
            view_weights = None

        cosines = self.cos(rays_d, normal)
        rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1)

        if stage in ["train"]:
            background_color = self.config.model.background_color
            if background_color == "white":
                self.model.background_color = torch.ones(
                    (3,), dtype=torch.float32, device=self.rank
                )
            elif background_color == "black":
                self.model.background_color = torch.zeros(
                    (3,), dtype=torch.float32, device=self.rank
                )
            elif background_color == "random":
                self.model.background_color = torch.rand(
                    (3,), dtype=torch.float32, device=self.rank
                )
            else:
                raise NotImplementedError
        else:
            # Evaluation always uses a white background.
            self.model.background_color = torch.ones(
                (3,), dtype=torch.float32, device=self.rank
            )

        if dataset.apply_mask:
            # Composite the ground truth over the same background the model uses.
            rgb = rgb * fg_mask[..., None] + self.model.background_color * (
                1 - fg_mask[..., None]
            )

        batch.update(
            {
                "rays": rays,
                "rgb": rgb,
                "normal": normal,
                "fg_mask": fg_mask,
                "rgb_mask": rgb_mask,
                "cosines": cosines,
                "view_weights": view_weights,
            }
        )

    def training_step(self, batch, batch_idx):
        """Compute the total training loss for one batch of rays.

        Each term is logged and scaled by its ``config.system.loss.lambda_*``
        weight, evaluated through ``self.C``.

        Returns:
            ``{"loss": loss}``.

        Raises:
            ValueError: if an enabled loss term needs a key missing from the
                model output.
        """
        out = self(batch)
        loss_cfg = self.config.system.loss

        fg_mask = batch["fg_mask"]
        view_weights = batch["view_weights"]
        # Zero cosines above -0.1. This is done out of place so that
        # batch["cosines"] is not mutated.
        cosines = batch["cosines"].masked_fill(batch["cosines"] > -0.1, 0)
        # Normal supervision applies only to foreground rays with cosine < -0.1.
        normal_mask = (fg_mask > 0) & (cosines < -0.1)
        # Color supervision applies only to rays the renderer marks valid that
        # also have a positive dataset rgb mask.
        rgb_mask = out["rays_valid_full"][..., 0] & (batch["rgb_mask"] > 0)
        loss = 0.0

        # update train_num_rays
        if self.config.model.dynamic_ray_sampling:
            # Rescale the ray count so that rays x samples tracks the sample
            # budget, smoothed with a 0.9/0.1 moving average and capped.
            # max(..., 1) guards against dividing by zero when no samples
            # were taken.
            num_samples = max(out["num_samples_full"].sum().item(), 1)
            train_num_rays = int(
                self.train_num_rays * (self.train_num_samples / num_samples)
            )
            self.train_num_rays = min(
                int(self.train_num_rays * 0.9 + train_num_rays * 0.1),
                self.config.model.max_train_num_rays,
            )

        pred_rgb = out["comp_rgb_full"][rgb_mask]
        gt_rgb = batch["rgb"][rgb_mask]

        # Per-ray squared color error (summed over channels), reduced by
        # ranking_loss.
        errors_rgb_mse = F.mse_loss(pred_rgb, gt_rgb, reduction="none")
        loss_rgb_mse = ranking_loss(
            errors_rgb_mse.sum(dim=1), penalize_ratio=0.7, type="mean"
        )
        self.log("train/loss_rgb_mse", loss_rgb_mse, prog_bar=True, rank_zero_only=True)
        loss += loss_rgb_mse * self.C(loss_cfg.lambda_rgb_mse)

        # Per-ray L1 color error, weighted by the per-pixel view weights.
        errors_rgb_l1 = F.l1_loss(pred_rgb, gt_rgb, reduction="none")
        loss_rgb_l1 = ranking_loss(
            errors_rgb_l1.sum(dim=1),
            extra_weights=view_weights[rgb_mask],
            penalize_ratio=0.8,
        )
        self.log("train/loss_rgb", loss_rgb_l1)
        loss += loss_rgb_l1 * self.C(loss_cfg.lambda_rgb_l1)

        # 1 - cos(predicted normal, target normal), weighted per ray by
        # exp(|cosine|) normalised over all rays in the batch.
        normal_errors = 1 - F.cosine_similarity(
            out["comp_normal"], batch["normal"], dim=1
        )
        normal_weights = torch.exp(cosines.abs())
        normal_errors = normal_errors * normal_weights / normal_weights.sum()
        loss_normal = ranking_loss(
            normal_errors[normal_mask], penalize_ratio=0.8, type="sum"
        )
        self.log("train/loss_normal", loss_normal, prog_bar=True, rank_zero_only=True)
        loss += loss_normal * self.C(loss_cfg.lambda_normal)

        loss_eikonal = (
            (torch.linalg.norm(out["sdf_grad_samples"], ord=2, dim=-1) - 1.0) ** 2
        ).mean()
        self.log("train/loss_eikonal", loss_eikonal, prog_bar=True, rank_zero_only=True)
        loss += loss_eikonal * self.C(loss_cfg.lambda_eikonal)

        # Clamp opacity into [1e-3, 1 - 1e-3], presumably to keep the BCE
        # log terms finite.
        opacity = torch.clamp(out["opacity"].squeeze(-1), 1.0e-3, 1.0 - 1.0e-3)
        loss_mask = binary_cross_entropy(opacity, fg_mask.float(), reduction="none")
        loss_mask = ranking_loss(
            loss_mask, penalize_ratio=0.9, extra_weights=view_weights
        )
        self.log("train/loss_mask", loss_mask, prog_bar=True, rank_zero_only=True)
        loss += loss_mask * (
            self.C(loss_cfg.lambda_mask) if self.dataset.has_mask else 0.0
        )

        # BCE of opacity against itself.
        loss_opaque = binary_cross_entropy(opacity, opacity)
        self.log("train/loss_opaque", loss_opaque)
        loss += loss_opaque * self.C(loss_cfg.lambda_opaque)

        loss_sparsity = torch.exp(
            -loss_cfg.sparsity_scale * out["random_sdf"].abs()
        ).mean()
        self.log(
            "train/loss_sparsity", loss_sparsity, prog_bar=True, rank_zero_only=True
        )
        loss += loss_sparsity * self.C(loss_cfg.lambda_sparsity)

        if self.C(loss_cfg.lambda_curvature) > 0:
            if "sdf_laplace_samples" not in out:
                raise ValueError(
                    "Need geometry.grad_type='finite_difference' to get SDF Laplace samples"
                )
            loss_curvature = out["sdf_laplace_samples"].abs().mean()
            self.log("train/loss_curvature", loss_curvature)
            loss += loss_curvature * self.C(loss_cfg.lambda_curvature)

        # distortion loss proposed in MipNeRF360
        # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss
        if self.C(loss_cfg.lambda_distortion) > 0:
            loss_distortion = flatten_eff_distloss(
                out["weights"], out["points"], out["intervals"], out["ray_indices"]
            )
            self.log("train/loss_distortion", loss_distortion)
            loss += loss_distortion * self.C(loss_cfg.lambda_distortion)

        if (
            self.config.model.learned_background
            and self.C(loss_cfg.lambda_distortion_bg) > 0
        ):
            loss_distortion_bg = flatten_eff_distloss(
                out["weights_bg"],
                out["points_bg"],
                out["intervals_bg"],
                out["ray_indices_bg"],
            )
            self.log("train/loss_distortion_bg", loss_distortion_bg)
            loss += loss_distortion_bg * self.C(loss_cfg.lambda_distortion_bg)

        if self.C(loss_cfg.lambda_3d_normal_smooth) > 0:
            if "random_sdf_grad" not in out:
                raise ValueError(
                    "random_sdf_grad is required for normal smooth loss, no normal is found in the output."
                )
            if "normal_perturb" not in out:
                raise ValueError(
                    "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output."
                )
            normals_3d = out["random_sdf_grad"]
            normals_perturb_3d = out["normal_perturb"]
            loss_3d_normal_smooth = (normals_3d - normals_perturb_3d).abs().mean()
            self.log(
                "train/loss_3d_normal_smooth", loss_3d_normal_smooth, prog_bar=True
            )
            loss += loss_3d_normal_smooth * self.C(loss_cfg.lambda_3d_normal_smooth)

        # Model-provided regularizers; each needs a matching lambda_<name>.
        losses_model_reg = self.model.regularizations(out)
        for name, value in losses_model_reg.items():
            self.log(f"train/loss_{name}", value)
            loss += value * self.C(loss_cfg[f"lambda_{name}"])

        self.log("train/inv_s", out["inv_s"], prog_bar=True)

        for name, value in loss_cfg.items():
            if name.startswith("lambda"):
                self.log(f"train_params/{name}", self.C(value))

        self.log("train/num_rays", float(self.train_num_rays), prog_bar=True)

        return {"loss": loss}

    def _eval_image_grid(self, batch, out):
        """Return the panels passed to ``save_image_grid`` for one eval image.

        Panel order:
            1. ground-truth RGB
            2. ``comp_rgb_full``
            3. ``comp_rgb_bg`` and ``comp_rgb``, only when
               ``config.model.learned_background`` is set
            4. depth (grayscale)
            5. ``comp_normal`` with data range (-1, 1)
        """
        W, H = self.dataset.img_wh

        def rgb_panel(img):
            # A fresh kwargs dict per panel, so no dict is shared between panels.
            return {
                "type": "rgb",
                "img": img.view(H, W, 3),
                "kwargs": {"data_format": "HWC"},
            }

        panels = [rgb_panel(batch["rgb"]), rgb_panel(out["comp_rgb_full"])]
        if self.config.model.learned_background:
            panels += [rgb_panel(out["comp_rgb_bg"]), rgb_panel(out["comp_rgb"])]
        panels += [
            {"type": "grayscale", "img": out["depth"].view(H, W), "kwargs": {}},
            {
                "type": "rgb",
                "img": out["comp_normal"].view(H, W, 3),
                "kwargs": {"data_format": "HWC", "data_range": (-1, 1)},
            },
        ]
        return panels

    def _render_and_save(self, batch, filename):
        """Render ``batch`` and save its image grid to ``filename``.

        Returns:
            ``{"psnr": psnr, "index": batch["index"]}``.
        """
        out = self(batch)
        psnr = self.criterions["psnr"](
            out["comp_rgb_full"].to(batch["rgb"]), batch["rgb"]
        )
        self.save_image_grid(filename, self._eval_image_grid(batch, out))
        return {"psnr": psnr, "index": batch["index"]}

    @staticmethod
    def _mean_psnr(outputs):
        """Average PSNR over unique image indices in all-gathered step outputs.

        Supports two ``index`` layouts:
            - 1-D ``index`` (DP): one index per step output.
            - 2-D ``index`` (DDP): presumably one row per process.

        A repeated index keeps its last PSNR, so each image is counted once.
        """
        psnr_by_index = {}
        for step_out in outputs:
            if step_out["index"].ndim == 1:  # DP
                psnr_by_index[step_out["index"].item()] = step_out["psnr"]
            else:  # DDP
                for oi, index in enumerate(step_out["index"]):
                    psnr_by_index[index[0].item()] = step_out["psnr"][oi]
        return torch.mean(torch.stack(list(psnr_by_index.values())))

    def validation_step(self, batch, batch_idx):
        """Render one validation image, save a preview grid, and return its PSNR."""
        return self._render_and_save(
            batch, f"it{self.global_step}-{batch['index'][0].item()}.png"
        )

    def validation_epoch_end(self, out):
        """Gather validation outputs, then log ``val/psnr`` and export a mesh."""
        out = self.all_gather(out)
        if self.trainer.is_global_zero:
            psnr = self._mean_psnr(out)
            self.log("val/psnr", psnr, prog_bar=True, rank_zero_only=True)
            # Only rank zero writes the mesh, so ranks don't write the same file.
            self.export()

    def test_step(self, batch, batch_idx):
        """Render one test image, save its grid under the test folder, return PSNR."""
        return self._render_and_save(
            batch, f"it{self.global_step}-test/{batch['index'][0].item()}.png"
        )

    def test_epoch_end(self, out):
        """
        Synchronize devices.
        Generate image sequence using test outputs.
        """
        out = self.all_gather(out)
        if self.trainer.is_global_zero:
            psnr = self._mean_psnr(out)
            self.log("test/psnr", psnr, prog_bar=True, rank_zero_only=True)
            self.save_img_sequence(
                f"it{self.global_step}-test",
                f"it{self.global_step}-test",
                r"(\d+)\.png",  # raw string: "\d" is not a valid str escape
                save_format="mp4",
                fps=30,
            )
            self.export()

    def export(self):
        """Extract a mesh using ``config.export`` and save it as an ``.obj`` file.

        The file name encodes the global step and the isosurface method and
        resolution from ``config.model.geometry.isosurface``.
        """
        mesh = self.model.export(self.config.export)
        isosurface = self.config.model.geometry.isosurface
        self.save_mesh(
            f"it{self.global_step}-{isosurface.method}{isosurface.resolution}.obj",
            ortho_scale=self.config.export.ortho_scale,
            **mesh,
        )