Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision.transforms import v2 | |
from einops import rearrange | |
from freesplatter.models.transformer import Transformer | |
from freesplatter.utils.infer_util import instantiate_from_config | |
from freesplatter.utils.recon_util import estimate_focal, fast_pnp | |
# Real spherical-harmonic basis value for degree 0: Y_0^0 = 1 / (2 * sqrt(pi)).
C0 = 0.28209479177387814


def RGB2SH(rgb):
    """Convert RGB colour values to degree-0 (DC) spherical-harmonic coefficients.

    A value of 0.5 maps to 0. The offset is divided by ``C0`` so that
    evaluating the DC term as ``C0 * sh + 0.5`` recovers the input colour.
    Works element-wise on Python floats and on tensors.
    """
    shifted = rgb - 0.5
    return shifted / C0
class FreeSplatterModel(nn.Module):
    """Predict per-pixel Gaussians from N views and use them for camera estimation.

    The transformer maps the input views to one Gaussian per pixel. Those
    Gaussians can be rendered, and their predicted 3D points are reused to
    estimate focal lengths and camera-to-world poses.
    """

    def __init__(
        self,
        transformer_config=None,
        renderer_config=None,
        use_2dgs=False,
        sh_residual=False,
    ):
        super().__init__()
        # Number of SH colour values per Gaussian: (degree + 1)^2 coefficients x 3 channels.
        # NOTE(review): renderer_config is dereferenced here, so its None default is
        # not actually usable; callers must pass a config with `sh_degree`.
        self.sh_dim = (renderer_config.sh_degree + 1) ** 2 * 3
        self.sh_residual = sh_residual
        self.use_2dgs = use_2dgs

        self.transformer = instantiate_from_config(transformer_config)

        # Import only the rasteriser backend that was requested (2DGS or 3DGS).
        if use_2dgs:
            from .renderer_2dgs.gaussian_renderer import GaussianRenderer
        else:
            from .renderer.gaussian_renderer import GaussianRenderer
        self.gs_renderer = GaussianRenderer(renderer_config=renderer_config)

        # Principal point (x, y) used by focal estimation. persistent=False keeps it
        # out of the state_dict.
        self.register_buffer('pp', torch.tensor([256, 256], dtype=torch.float32), persistent=False)

    def forward_gaussians(self, images, **kwargs):
        """Predict one Gaussian per input pixel.

        Args:
            images: B x N x 3 x H x W tensor of input views.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            B x (N*H*W) x C tensor of flattened Gaussian parameters.
        """
        gaussians = self.transformer(images)  # B x N x H x W x C
        if self.sh_residual:
            # Add the input colours (converted to DC SH coefficients) onto channels
            # 3:6, so the transformer's output there acts as a residual on top of them.
            offset = torch.zeros_like(gaussians)
            offset[..., 3:6] = RGB2SH(rearrange(images, 'b n c h w -> b n h w c'))
            gaussians = gaussians + offset
        return rearrange(gaussians, 'b n h w c -> b (n h w) c')

    def forward_renderer(self, gaussians, c2ws, fxfycxcy, **kwargs):
        """Render Gaussians from a set of cameras.

        Args:
            gaussians: B x K x 14 Gaussian parameters.
            c2ws: B x N x 4 x 4 camera-to-world matrices.
            fxfycxcy: B x N x 4 intrinsics (fx, fy, cx, cy).
            **kwargs: forwarded unchanged to the renderer.

        Returns:
            Whatever ``self.gs_renderer.render`` returns.
        """
        # The renderer takes intrinsics before extrinsics.
        return self.gs_renderer.render(gaussians, fxfycxcy, c2ws, **kwargs)

    def estimate_focals(
        self,
        images,
        masks=None,
        use_first_focal=False,
    ):
        """Estimate one shared focal length for N input images.

        For each reference view, the views are rotated so that the reference
        comes first. Gaussians are predicted for the rotated sequence, and the
        3D points of the first view's pixels are passed to ``estimate_focal``.
        The per-reference estimates are then averaged.

        Args:
            images: N x 3 x H x W tensor; H must equal W.
            masks: optional N x 1 x H x W tensor; mask i accompanies reference view i.
            use_first_focal: if True, only view 0 is used as a reference.

        Returns:
            Length-N tensor in which every entry is the mean focal estimate.
        """
        assert images.ndim == 4
        N, _, H, W = images.shape
        assert H == W, "Non-square images are not supported."
        # NOTE(review): the principal point is the fixed (256, 256) buffer, not the
        # image centre (W / 2, H / 2); the two coincide only for 512 x 512 inputs.
        pp = self.pp.to(images)

        ref_views = range(N)[:1] if use_first_focal else range(N)
        per_view = []
        for ref in ref_views:
            # Rotate the sequence so view `ref` is first, then take its points.
            rolled = torch.cat([images[ref:], images[:ref]], dim=0)
            gaussians = self.forward_gaussians(rolled.unsqueeze(0))  # 1 x (N*H*W) x C
            ref_points = rearrange(gaussians[0, :H * W, :3], '(h w) c -> h w c', h=H, w=W)
            ref_mask = None if masks is None else masks[ref]
            per_view.append(estimate_focal(ref_points, pp=pp, mask=ref_mask))

        mean_focal = torch.stack(per_view).to(images).mean()
        return mean_focal.reshape(1).repeat(N)

    def estimate_poses(
        self,
        images,
        gaussians=None,
        masks=None,
        focals=None,
        use_first_focal=True,
        opacity_threshold=5e-2,
        pnp_iter=20,
    ):
        """Estimate camera-to-world poses for N input images via per-view PnP.

        Args:
            images: N x 3 x H x W tensor; H must equal W.
            gaussians: optional precomputed Gaussians, K x C or 1 x K x C with
                K == N*H*W. They are predicted from ``images`` when omitted.
            masks: optional N x 1 x H x W validity masks. When omitted, pixels
                whose sigmoid opacity exceeds ``opacity_threshold`` are used.
            focals: optional length-N tensor. It is estimated with
                ``estimate_focals`` when omitted.
            use_first_focal: forwarded to ``estimate_focals``.
            opacity_threshold: opacity cut-off used when ``masks`` is None.
            pnp_iter: iteration count forwarded to ``fast_pnp`` as ``niter_PnP``.

        Returns:
            Tuple ``(c2ws, focals)``: c2ws is N x 4 x 4 with the dtype and device
            of ``images``; focals is the tensor used for PnP.
        """
        assert images.ndim == 4
        N, _, H, W = images.shape
        assert H == W, "Non-square images are not supported."

        if gaussians is not None:
            if gaussians.ndim == 2:
                gaussians = gaussians.unsqueeze(0)
            assert gaussians.shape[1] == N * H * W
        else:
            gaussians = self.forward_gaussians(images.unsqueeze(0))  # 1 x (N*H*W) x C

        # Channel layout read here: xyz in 0:3, opacity logit at index 3 + sh_dim.
        points = gaussians[..., :3].reshape(N, H, W, 3)
        opacity_logits = gaussians[..., 3 + self.sh_dim].reshape(N, H, W)
        opacities = torch.sigmoid(opacity_logits)

        if focals is None:
            focals = self.estimate_focals(images, masks=masks, use_first_focal=use_first_focal)

        poses = []
        for idx in range(N):
            pts3d = points[idx].float().detach().cpu().numpy()
            if masks is None:
                # Without explicit masks, keep pixels whose predicted opacity is high enough.
                valid = (opacities[idx] > opacity_threshold).detach().cpu().numpy()
            else:
                valid = masks[idx].reshape(H, W).bool().detach().cpu().numpy()
            _, c2w = fast_pnp(pts3d, valid, focal=focals[idx].item(), niter_PnP=pnp_iter)
            poses.append(torch.from_numpy(c2w))

        return torch.stack(poses, dim=0).to(images), focals