bluestyle97's picture
Upload 147 files
184193d verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import v2
from einops import rearrange
from freesplatter.models.transformer import Transformer
from freesplatter.utils.infer_util import instantiate_from_config
from freesplatter.utils.recon_util import estimate_focal, fast_pnp
# Scale factor between RGB colour and the SH coefficient used below.
# Numerically equal to 1 / (2 * sqrt(pi)).
C0 = 0.28209479177387814


def RGB2SH(rgb):
    """Convert RGB values to SH coefficients via ``(rgb - 0.5) / C0``.

    rgb: a scalar or tensor-like object supporting ``-`` and ``/`` with floats.
    Returns an object of the same kind as ``rgb``.
    """
    centered = rgb - 0.5
    return centered / C0
class FreeSplatterModel(nn.Module):
    """Predict Gaussians from images and render / calibrate with them.

    The transformer output is treated as one feature vector per input pixel.
    Channel layout as used by this class: ``[0:3]`` are 3D positions,
    ``[3:3 + sh_dim]`` are SH colour coefficients and ``[3 + sh_dim]`` is the
    opacity logit. Any remaining channels are consumed by the renderer.
    """

    def __init__(
        self,
        transformer_config=None,
        renderer_config=None,
        use_2dgs=False,
        sh_residual=False,
    ):
        """
        transformer_config: config passed to ``instantiate_from_config``
        renderer_config: renderer config; must expose ``sh_degree``
        use_2dgs: pick the renderer from ``renderer_2dgs`` instead of ``renderer``
        sh_residual: add the SH-encoded input colours to the predicted SH
        """
        super().__init__()

        # (degree + 1)^2 SH coefficients for each of the 3 colour channels.
        coeffs_per_channel = (renderer_config.sh_degree + 1) ** 2
        self.sh_dim = coeffs_per_channel * 3
        self.sh_residual = sh_residual
        self.use_2dgs = use_2dgs

        self.transformer = instantiate_from_config(transformer_config)

        # Both renderer flavours export the same class name; import lazily so
        # only the selected backend gets loaded.
        if use_2dgs:
            from .renderer_2dgs.gaussian_renderer import GaussianRenderer
        else:
            from .renderer.gaussian_renderer import GaussianRenderer
        self.gs_renderer = GaussianRenderer(renderer_config=renderer_config)

        # Fixed principal point used for focal estimation; not saved in
        # the state dict.
        principal_point = torch.tensor([256, 256], dtype=torch.float32)
        self.register_buffer('pp', principal_point, persistent=False)

    def forward_gaussians(self, images, **kwargs):
        """
        Run the transformer and flatten its per-pixel output.

        images: B x N x 3 x H x W
        returns: B x (N x H x W) x C
        """
        per_pixel = self.transformer(images)  # B x N x H x W x C

        if self.sh_residual:
            # Channels 3:6 receive the SH encoding of the input colours, so the
            # network only has to predict a residual on top of them.
            colour_sh = RGB2SH(rearrange(images, 'b n c h w -> b n h w c'))
            offset = torch.zeros_like(per_pixel)
            offset[..., 3:6] = colour_sh
            per_pixel = per_pixel + offset

        return rearrange(per_pixel, 'b n h w c -> b (n h w) c')

    def forward_renderer(self, gaussians, c2ws, fxfycxcy, **kwargs):
        """
        Render the Gaussians with the configured renderer.

        gaussians: B x K x 14
        c2ws: B x N x 4 x 4
        fxfycxcy: B x N x 4

        Note that the renderer receives the intrinsics before the poses.
        Extra keyword arguments are forwarded unchanged.
        """
        return self.gs_renderer.render(gaussians, fxfycxcy, c2ws, **kwargs)

    @torch.inference_mode()
    def estimate_focals(
        self,
        images,
        masks=None,
        use_first_focal=False,
    ):
        """
        Estimate a single shared focal length for N input images.

        images: N x 3 x H x W
        masks: N x 1 x H x W
        use_first_focal: only use the estimate obtained from view 0

        returns: tensor of N identical values (mean of the per-view estimates)
        """
        assert images.ndim == 4
        N, _, H, W = images.shape
        assert H == W, "Non-square images are not supported."

        # NOTE: the principal point comes from the fixed `pp` buffer (256, 256);
        # it is not derived from the image size (W / 2, H / 2).
        principal_point = self.pp.to(images)

        # With `use_first_focal` only view 0 is processed (if it exists).
        num_refs = min(N, 1) if use_first_focal else N

        estimates = []
        for ref in range(num_refs):
            # Rotate the sequence so that view `ref` comes first; its points
            # are then the first H*W Gaussians of the prediction.
            reordered = torch.cat([images[ref:], images[:ref]], dim=0)
            predicted = self.forward_gaussians(reordered.unsqueeze(0))  # 1 x (N x H x W) x 14
            ref_points = rearrange(predicted[0, :H*W, :3], '(h w) c -> h w c', h=H, w=W)
            ref_mask = None if masks is None else masks[ref]
            estimates.append(estimate_focal(ref_points, pp=principal_point, mask=ref_mask))

        shared_focal = torch.stack(estimates).to(images).mean()
        return shared_focal.reshape(1).repeat(N)

    @torch.inference_mode()
    def estimate_poses(
        self,
        images,
        gaussians=None,
        masks=None,
        focals=None,
        use_first_focal=True,
        opacity_threshold=5e-2,
        pnp_iter=20,
    ):
        """
        Estimate the camera poses of N input images with PnP.

        images: N x 3 x H x W
        gaussians: K x 14 or 1 x K x 14 (predicted from `images` when omitted)
        masks: N x 1 x H x W (when omitted, thresholded opacities are used)
        focals: N (estimated via `estimate_focals` when omitted)
        opacity_threshold: opacity cut-off for the fallback mask
        pnp_iter: iteration count handed to `fast_pnp`

        returns: (c2ws, focals), with c2ws stacked along dim 0 (one per view)
        """
        assert images.ndim == 4
        N, _, H, W = images.shape
        assert H == W, "Non-square images are not supported."

        # Use caller-provided Gaussians when available, otherwise predict them.
        if gaussians is not None:
            if gaussians.ndim == 2:
                gaussians = gaussians.unsqueeze(0)
            assert gaussians.shape[1] == N * H * W
        else:
            gaussians = self.forward_gaussians(images.unsqueeze(0))  # 1 x (N x H x W) x 14

        xyz = gaussians[..., :3]
        points = xyz.reshape(1, N, H, W, 3).squeeze(0)  # N x H x W x 3

        # The opacity logit sits right after the position and SH channels.
        opacity_logits = gaussians[..., 3+self.sh_dim]
        opacities = torch.sigmoid(opacity_logits.reshape(1, N, H, W).squeeze(0))  # N x H x W

        if focals is None:
            focals = self.estimate_focals(images, masks=masks, use_first_focal=use_first_focal)

        # Solve PnP independently for every view.
        poses = []
        for view in range(N):
            view_points = points[view].float().detach().cpu().numpy()

            if masks is not None:
                valid = masks[view].reshape(H, W).bool()
            else:
                # No masks given: fall back to the Gaussian opacities.
                valid = opacities[view] > opacity_threshold
            valid = valid.detach().cpu().numpy()

            view_focal = focals[view].item()
            _, pose = fast_pnp(view_points, valid, focal=view_focal, niter_PnP=pnp_iter)
            poses.append(torch.from_numpy(pose))

        c2ws = torch.stack(poses, dim=0).to(images)
        return c2ws, focals