from dataclasses import dataclass, field

import numpy as np
import torch
from skimage import measure
from einops import repeat, rearrange

import craftsman
from craftsman.systems.base import BaseSystem
from craftsman.utils.ops import generate_dense_grid_points
from craftsman.utils.typing import *
from craftsman.utils.misc import get_rank


@craftsman.register("shape-autoencoder-system")
class ShapeAutoEncoderSystem(BaseSystem):
    @dataclass
    class Config(BaseSystem.Config):
        shape_model_type: Optional[str] = None
        shape_model: dict = field(default_factory=dict)

        sample_posterior: bool = True

    cfg: Config

    def configure(self):
        super().configure()

        # Instantiate the registered shape VAE from the config.
        self.shape_model = craftsman.find(self.cfg.shape_model_type)(
            self.cfg.shape_model
        )

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Assemble query points and supervision targets, choosing the loss
        # by supervision type: MSE for SDF values, BCE-with-logits for occupancy.
        if "xyz" in batch:
            if "sdf" in batch:
                bs = batch["sdf"].shape[0]
                rand_points = torch.cat(
                    [batch["xyz"].view(bs, -1, 3), batch["patch_xyz"].view(bs, -1, 3)],
                    dim=1,
                )
                target = torch.cat(
                    [batch["sdf"].view(bs, -1, 1), batch["patch_sdf"].view(bs, -1, 1)],
                    dim=1,
                ).squeeze(-1)
                criteria = torch.nn.MSELoss()
            elif "occupancy" in batch:
                bs = batch["occupancy"].shape[0]
                rand_points = torch.cat(
                    [batch["xyz"].view(bs, -1, 3), batch["patch_xyz"].view(bs, -1, 3)],
                    dim=1,
                )
                target = torch.cat(
                    [
                        batch["occupancy"].view(bs, -1, 1),
                        batch["patch_occupancy"].view(bs, -1, 1),
                    ],
                    dim=1,
                ).squeeze(-1)
                criteria = torch.nn.BCEWithLogitsLoss()
            else:
                raise NotImplementedError
        else:
            rand_points = batch["rand_points"]
            if "sdf" in batch:
                target = batch["sdf"]
                criteria = torch.nn.MSELoss()
            elif "occupancies" in batch:
                target = batch["occupancies"]
                criteria = torch.nn.BCEWithLogitsLoss()
            else:
                raise NotImplementedError

        # Forward pass: encode the surface point cloud (xyz plus extra point
        # features) and decode logits at the query points.
        _, latents, posterior, logits = self.shape_model(
            batch["surface"][..., :3 + self.cfg.shape_model.point_feats],
            rand_points,
            sample_posterior=self.cfg.sample_posterior,
        )

        if self.cfg.sample_posterior:
            # KL regularization of the latent posterior, averaged over the batch.
            loss_kl = posterior.kl()
            loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
            return {
                "loss_logits": criteria(logits, target).mean(),
                "loss_kl": loss_kl,
                "logits": logits,
                "target": target,
                "latents": latents,
            }
        else:
            return {
                "loss_logits": criteria(logits, target).mean(),
                "logits": logits,
                "target": target,
                "latents": latents,
            }

    def training_step(self, batch, batch_idx):
        """Run a forward pass and accumulate the weighted training losses.

        Args:
            batch: dictionary of surface samples, query points, and targets.
            batch_idx: index of the current batch.

        Returns:
            dict with the total weighted loss under the key "loss".
        """
        out = self(batch)

        loss = 0.0
        for name, value in out.items():
            if name.startswith("loss_"):
                self.log(f"train/{name}", value)
                # Weight each loss term by its corresponding lambda_* coefficient.
                loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        return {"loss": loss}

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        self.eval()
        out = self(batch)
        # self.save_state_dict("latest-weights", self.state_dict())

        # Extract a mesh from the predicted latents and save it for inspection.
        mesh_v_f, has_surface = self.shape_model.extract_geometry(out["latents"])
        self.save_mesh(
            f"it{self.true_global_step}/{batch['uid'][0]}.obj",
            mesh_v_f[0][0],
            mesh_v_f[0][1],
        )

        # Binarize the predicted logits at threshold 0 and compare against targets.
        threshold = 0
        outputs = out["logits"]
        labels = out["target"]
        pred = torch.zeros_like(outputs)
        pred[outputs >= threshold] = 1

        accuracy = (pred == labels).float().sum(dim=1) / labels.shape[1]
        accuracy = accuracy.mean()

        intersection = (pred * labels).sum(dim=1)
        union = (pred + labels).gt(0).sum(dim=1)
        # Epsilon in the denominator guards against an empty union.
        iou = intersection * 1.0 / (union + 1e-5)
        iou = iou.mean()

        self.log("val/accuracy", accuracy)
        self.log("val/iou", iou)

        torch.cuda.empty_cache()

        return {
            "val/loss": out["loss_logits"],
            "val/accuracy": accuracy,
            "val/iou": iou,
        }

    def on_validation_epoch_end(self):
        pass