Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import Optional, Dict, Tuple | |
import numpy as np | |
import torch as th | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import random | |
from dva.mvp.extensions.mvpraymarch.mvpraymarch import mvpraymarch | |
from dva.mvp.extensions.utils.utils import compute_raydirs | |
import logging | |
logger = logging.getLogger(__name__) | |
def convert_camera_parameters(Rt, K):
    """Repackage batched extrinsics / intrinsics as a camera-parameter dict.

    Args:
        Rt: batched extrinsics. Only ``Rt[:, :3, :3]`` (rotation) and
            ``Rt[:, :3, 3]`` (translation) are read.
        K: batched intrinsics. Only the first two rows are read.

    Returns:
        A dict with keys
            ``campos``: ``-R^T t`` for every batch element, shape [B, 3]
                (the camera center, assuming ``Rt`` maps world -> camera
                coordinates -- TODO confirm against callers),
            ``camrot``: the rotation block ``R``, shape [B, 3, 3],
            ``focal``: the top-left 2x2 block of ``K``,
            ``princpt``: ``K[:, :2, 2]``.
    """
    rotation = Rt[:, :3, :3]
    # Column vector [B, 3, 1] so it can go through a batched matmul.
    translation = Rt[:, :3, 3].unsqueeze(2)
    center = -rotation.permute(0, 2, 1).bmm(translation).squeeze(2)
    return {
        "campos": center,
        "camrot": rotation,
        "focal": K[:, :2, :2],
        "princpt": K[:, :2, 2],
    }
def subsample_pixel_coords(
    pixel_coords: th.Tensor, batch_size: int, ray_subsample_factor: int = 4
):
    """Pick a randomly shifted, strided subset of the pixel grid per sample.

    For every batch element an independent integer offset ``(x0, y0)`` is
    drawn uniformly from ``[0, ray_subsample_factor)`` and every
    ``ray_subsample_factor``-th pixel starting at that offset is kept.

    Args:
        pixel_coords: pixel grid indexed as ``[row, column, ...]``; the first
            two dimensions are treated as height and width.
        batch_size: number of independently shifted grids to produce.
        ray_subsample_factor: stride between kept pixels along both axes.

    Returns:
        Tensor of shape ``[batch_size, H // factor, W // factor, ...]``.
    """
    H, W = pixel_coords.shape[:2]
    step = ray_subsample_factor
    SW = W // step
    SH = H // step

    # NOTE: th.randint's upper bound is exclusive, so `step` (not `step - 1`)
    # is required for the offset `step - 1` to be reachable. The largest
    # index touched is (step - 1) + (SW - 1) * step <= W - 1, so every slice
    # stays in bounds and all slices share the same shape.
    offsets = th.randint(0, step, size=(batch_size, 2)).tolist()

    all_coords = [
        pixel_coords[y0 : y0 + step * SH : step, x0 : x0 + step * SW : step, :]
        for x0, y0 in offsets
    ]
    return th.stack(all_coords, dim=0)
def resize_pixel_coords(
    pixel_coords: th.Tensor, batch_size: int, ray_subsample_factor: int = 4
):
    """Deterministically subsample the pixel grid with a centered stride.

    Keeps every ``ray_subsample_factor``-th pixel along both axes, starting
    at offset ``ray_subsample_factor // 2``, and repeats that same grid for
    every batch element.

    Args:
        pixel_coords: pixel grid indexed as ``[row, column, ...]``; the first
            two dimensions are treated as height and width.
        batch_size: number of copies to return (0 yields an empty batch).
        ray_subsample_factor: stride between kept pixels along both axes.

    Returns:
        A freshly allocated tensor of shape
        ``[batch_size, H // factor, W // factor, ...]``.
    """
    H, W = pixel_coords.shape[:2]
    step = ray_subsample_factor
    SW = W // step
    SH = H // step
    x0 = y0 = step // 2

    # The slice does not depend on the batch index, so take it once.
    window = pixel_coords[y0 : y0 + step * SH : step, x0 : x0 + step * SW : step, :]

    # `repeat` always copies, so the result never aliases `pixel_coords`.
    return window[None].repeat(batch_size, *([1] * window.dim()))
class RayMarcher(nn.Module):
    """Renders volumetric primitives to an image by calling ``mvpraymarch``.

    The module keeps a non-persistent buffer ``base_pixel_coords`` of shape
    ``[image_height, image_width, 2]`` holding ``(x, y)`` pixel coordinates,
    from which rays are generated in :meth:`forward`.
    """

    def __init__(
        self,
        image_height,
        image_width,
        volradius,
        fadescale=8.0,
        fadeexp=8.0,
        dt=1.0,
        ray_subsample_factor=1,
        accum=2,
        termthresh=0.99,
        blocksize=None,
        with_t_img=True,
        chlast=False,
        assets=None,
    ):
        """
        Args:
            image_height: height of the rendered image in pixels.
            image_width: width of the rendered image in pixels.
            volradius: scene radius; primitive positions and the step size
                are divided by it before raymarching.
            fadescale: forwarded to ``mvpraymarch`` as ``fadescale``.
            fadeexp: forwarded to ``mvpraymarch`` as ``fadeexp``.
            dt: step size before division by ``volradius``.
            ray_subsample_factor: default pixel stride used by ``forward``
                when no per-call value is given.
            accum: stored on the module; not read by this class.
            termthresh: forwarded to ``mvpraymarch`` as ``termthresh``.
            blocksize: stored on the module (``(8, 16)`` when ``None``);
                not read by this class.
            with_t_img: stored on the module; not read by this class.
            chlast: stored on the module; NOTE ``forward`` always passes
                ``chlast=True`` to ``mvpraymarch`` regardless of this value.
            assets: accepted for interface compatibility; unused.
        """
        super().__init__()
        # TODO: add config?
        self.image_height = image_height
        self.image_width = image_width
        self.volradius = volradius
        self.dt = dt

        self.fadescale = fadescale
        self.fadeexp = fadeexp

        # NOTE: this seems to not work for other configs?
        if blocksize is None:
            blocksize = (8, 16)

        self.blocksize = blocksize
        self.with_t_img = with_t_img
        self.chlast = chlast

        self.accum = accum
        self.termthresh = termthresh

        # Non-persistent: the grid is derived from the image size, so it is
        # kept out of the state dict.
        self.register_buffer(
            "base_pixel_coords",
            self._make_pixel_coords(self.image_height, self.image_width),
            persistent=False,
        )
        self.fixed_bvh_cache = {-1: (th.empty(0), th.empty(0), th.empty(0))}
        self.ray_subsample_factor = ray_subsample_factor

    @staticmethod
    def _make_pixel_coords(height, width, device=None):
        """Return a float32 ``[height, width, 2]`` grid of ``(x, y)`` coords."""
        ys, xs = th.meshgrid(
            th.arange(height, dtype=th.float32, device=device),
            th.arange(width, dtype=th.float32, device=device),
            indexing="ij",
        )
        return th.stack((xs, ys), dim=-1)

    def _set_pix_coords(self):
        """Rebuild ``base_pixel_coords`` for the current image size."""
        dev = self.base_pixel_coords.device
        self.base_pixel_coords = self._make_pixel_coords(
            self.image_height, self.image_width, device=dev
        )

    def resize(self, h: int, w: int):
        """Change the rendered image size to ``h`` x ``w`` pixels."""
        self.image_height = h
        self.image_width = w

        self._set_pix_coords()

    def forward(
        self,
        prim_rgba: th.Tensor,
        prim_pos: th.Tensor,
        prim_rot: th.Tensor,
        prim_scale: th.Tensor,
        K: th.Tensor,
        RT: th.Tensor,
        ray_subsample_factor: Optional[int] = None,
    ):
        """
        Args:
            prim_rgba: primitive payload [B, K, 4, S, S, S],
                K - # of primitives, S - primitive size
            prim_pos: locations [B, K, 3]
            prim_rot: rotations [B, K, 3, 3]
            prim_scale: scales [B, K, 3]
            K: intrinsics [B, 3, 3]
            RT: extrinsics [B, 3, 4]
            ray_subsample_factor: pixel stride for this call; falls back to
                the value given at construction when ``None``. With a
                stride > 1 the grid is randomly shifted in training mode and
                fixed (centered) otherwise.
        Returns:
            a dict of tensors:
                "rgba_image": output of ``mvpraymarch`` permuted with
                    ``(0, 3, 1, 2)`` (presumably [B, C, H, W] -- TODO confirm
                    against ``mvpraymarch``),
                "pixel_coords": the pixel coordinates the rays were cast
                    through, [B, H', W', 2]
        """
        # TODO: maybe we can re-use mvpraymarcher?
        B = prim_rgba.shape[0]

        # TODO: this should return focal 2x2?
        camera = convert_camera_parameters(RT, K)
        camera = {k: v.contiguous() for k, v in camera.items()}

        # Step size and primitive positions are expressed in units of
        # `volradius`.
        dt = self.dt / self.volradius

        if ray_subsample_factor is None:
            ray_subsample_factor = self.ray_subsample_factor

        if ray_subsample_factor > 1 and self.training:
            pixel_coords = subsample_pixel_coords(
                self.base_pixel_coords, int(B), ray_subsample_factor
            )
        elif ray_subsample_factor > 1:
            pixel_coords = resize_pixel_coords(
                self.base_pixel_coords,
                int(B),
                ray_subsample_factor,
            )
        else:
            pixel_coords = (
                self.base_pixel_coords[None].expand(B, -1, -1, -1).contiguous()
            )

        prim_pos = prim_pos / self.volradius

        # Only the diagonal (fx, fy) of the 2x2 focal block is passed on.
        focal = th.diagonal(camera["focal"], dim1=1, dim2=2).contiguous()

        # TODO: port this?
        raypos, raydir, tminmax = compute_raydirs(
            viewpos=camera["campos"],
            viewrot=camera["camrot"],
            focal=focal,
            princpt=camera["princpt"],
            pixelcoords=pixel_coords,
            volradius=self.volradius,
        )

        rgba = mvpraymarch(
            raypos,
            raydir,
            stepsize=dt,
            tminmax=tminmax,
            algo=0,
            # move the channel dimension last: [B, K, S, S, S, 4]
            template=prim_rgba.permute(0, 1, 3, 4, 5, 2).contiguous(),
            warp=None,
            termthresh=self.termthresh,
            primtransf=(prim_pos, prim_rot, prim_scale),
            fadescale=self.fadescale,
            fadeexp=self.fadeexp,
            usebvh="fixedorder",
            chlast=True,
        )

        rgba = rgba.permute(0, 3, 1, 2)

        preds = {
            "rgba_image": rgba,
            "pixel_coords": pixel_coords,
        }

        return preds
def generate_colored_boxes(template, prim_rot, alpha=10000.0, seed=123456):
    """Fill every primitive with a random, flat-shaded solid color.

    Each primitive gets one color drawn from a seeded generator. The color
    is modulated per voxel by a simple shading term computed from a fixed
    light direction and the normal of the nearest box face, and the alpha
    channel is set to a constant.

    Args:
        template: primitive payload; indexed as
            ``[batch, primitive, channel, depth, height, width]`` with RGB in
            channels 0..2 and alpha in channel 3. It is not modified.
        prim_rot: unused; kept for interface compatibility.
        alpha: value written to the alpha channel of every voxel.
        seed: seed of the color generator. A private ``RandomState`` is
            used, so NumPy's global random state is left untouched while the
            colors match what ``np.random.seed(seed)`` would have produced.

    Returns:
        A clone of ``template`` with RGB and alpha overwritten.
    """
    batch = template.shape[0]
    device = template.device
    output = template.clone()

    # Fixed light direction, identical for all batch elements.
    lightdir = -3 * th.ones([batch, 3], dtype=th.float32, device=device)
    lightdir = lightdir / th.norm(lightdir, p=2, dim=1, keepdim=True)

    depth, height, width = template.shape[-3:]
    zz, yy, xx = th.meshgrid(
        th.linspace(-1.0, 1.0, depth, device=device),
        th.linspace(-1.0, 1.0, height, device=device),
        th.linspace(-1.0, 1.0, width, device=device),
        indexing="ij",
    )

    # Per-voxel normal: the signed axis with the largest absolute coordinate
    # (ties contribute to several axes and are normalized below).
    abs_x, abs_y, abs_z = xx.abs(), yy.abs(), zz.abs()
    zeros = th.zeros_like(xx)
    primnormalx = th.where((abs_x >= abs_y) & (abs_x >= abs_z), th.sign(xx), zeros)
    primnormaly = th.where((abs_y >= abs_x) & (abs_y >= abs_z), th.sign(yy), zeros)
    primnormalz = th.where((abs_z >= abs_x) & (abs_z >= abs_y), th.sign(zz), zeros)

    primnormal = th.stack([primnormalx, -primnormaly, -primnormalz], dim=-1)
    # The exact center voxel of an odd-sized grid has an all-zero normal;
    # clamp the norm so it stays zero instead of becoming NaN (0 / 0).
    normal_len = th.sqrt(th.sum(primnormal**2, dim=-1, keepdim=True))
    primnormal = primnormal / normal_len.clamp(min=1e-8)

    # Shading factor, [batch, 1, D, H, W]; it is the same for every
    # primitive, so it is computed once.
    # get light direction in local coordinate system?
    mult = th.sum(
        lightdir[:, None, None, None, :] * primnormal[None], dim=-1
    )[:, None, :, :, :].clamp(min=0.2)
    shade = 1.4 * mult

    # One random color per primitive; rows are drawn in primitive order
    # (r, g, b), matching successive scalar `rand()` calls.
    rng = np.random.RandomState(seed)
    colors = th.as_tensor(
        rng.rand(template.size(1), 3) * 255.0, dtype=output.dtype, device=device
    )

    output[:, :, 3, :, :, :] = alpha
    output[:, :, :3, :, :, :] = (
        colors[None, :, :, None, None, None] * shade[:, None]
    )

    return output