# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
The ray sampler is a module that takes in camera matrices and resolution and
generates batches of rays.
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
"""

import random
from pdb import set_trace as st  # noqa: F401  (kept: may be imported from here)

import torch

HUGE_NUMBER = 1e10
TINY_NUMBER = 1e-6  # float32 only has 7 decimal digits precision


######################################################################################
# wrapper to simplify the use of nerfnet
######################################################################################


# https://github.com/Kai-46/nerfplusplus/blob/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/ddp_model.py#L16
def depth2pts_outside(ray_o, ray_d, depth):
    '''
    Convert inverse-distance samples along rays into (unit-sphere point,
    inverse distance) 4-vectors, following NeRF++ (see link above).

    ray_o, ray_d: [..., 3] ray origins and directions
    depth: [...]; inverse of distance to sphere origin, expected in [0, 1]

    Returns:
        pts: [..., 4]; unit-sphere point concatenated with `depth`
        depth_real: [...]; conventional distance along the ray

    NOTE: d2 is only real when the ray's closest approach to the origin lies
    inside the unit sphere (e.g. origins inside the sphere); otherwise the sqrt
    produces NaN. The rotation axis is also undefined (NaN) when ray_o is
    parallel to the sphere hit point, e.g. for rays through the origin.
    '''
    # note: d1 becomes negative if this mid point is behind camera
    d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
    p_mid = ray_o + d1.unsqueeze(-1) * ray_d
    p_mid_norm = torch.norm(p_mid, dim=-1)
    ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
    # distance (in units of ray_d) from p_mid to the unit-sphere intersection
    d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos
    p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d

    rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
    rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
    phi = torch.asin(p_mid_norm)
    theta = torch.asin(p_mid_norm * depth)  # depth is inside [0, 1]
    rot_angle = (phi - theta).unsqueeze(-1)  # [..., 1]

    # now rotate p_sphere
    # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    p_sphere_new = p_sphere * torch.cos(rot_angle) + \
        torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
        rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. - torch.cos(rot_angle))
    p_sphere_new = p_sphere_new / torch.norm(
        p_sphere_new, dim=-1, keepdim=True)
    pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)

    # now calculate conventional depth
    depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1
    return pts, depth_real


class RaySampler(torch.nn.Module):
    """Generate per-pixel camera rays (origins and unit directions) in world space."""

    def __init__(self):
        super().__init__()
        self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None

    def create_patch_uv(self,
                        patch_resolution,
                        resolution,
                        cam2world_matrix,
                        fg_bbox=None):
        """Sample one square patch per camera and return its pixel-centre UVs.

        Args:
            patch_resolution: side length of the square patch, in pixels.
            resolution: side length of the full square image, in pixels.
            cam2world_matrix: tensor whose batch size (shape[0]) sets the number
                of patches and whose device is used for the UVs.
            fg_bbox: optional tensor of boxes, one per row. Columns 0-1 are used
                as the top/left extent and columns 2-3 as the bottom/right
                extent; the box enclosing all rows is used for every patch.
                Integer or float values are accepted (floats are truncated).
                NOTE(review): column layout inferred from usage -- confirm
                against callers.

        Returns:
            all_uv: (B, patch_resolution**2, 2) pixel-centre coordinates in
                (0, 1), ordered (x, y) = (column, row), row-major in the patch.
            ray_bboxes: list of B tuples (top, left, height, width) in pixels.
        """
        assert patch_resolution <= resolution
        device = cam2world_matrix.device
        half = patch_resolution // 2

        # A patch centre `mid` gives start = mid - (patch_resolution - half) and
        # end = mid + half. Requiring 0 <= start and end <= resolution yields the
        # range below; the lower bound must use the ceiling of
        # patch_resolution / 2, otherwise odd patch sizes can start at -1.
        mid_low, mid_high = patch_resolution - half, resolution - half

        def clamp_mid(value):
            # Restrict a candidate centre to [mid_low, mid_high]. int() turns
            # 0-dim tensors (integer or float) into ints random.randint accepts.
            return min(max(int(value), mid_low), mid_high)

        def sample_patch_end():
            # Draw the exclusive end index uniformly from
            # [patch_resolution, resolution + patch_resolution] and clip it to
            # [patch_resolution, resolution]. The clipping piles the excess
            # probability onto end == resolution, i.e. patches touching the
            # last rows/columns are oversampled.
            patch_end = random.randint(patch_resolution,
                                       resolution + patch_resolution)
            return min(max(patch_end, patch_resolution), resolution)

        def sample_patch_uv():
            if fg_bbox is not None and random.random() > 0.025:
                # Foreground branch, taken with probability ~0.975 when boxes
                # are given: pick the patch centre inside the enclosing box.
                top_min, left_min = fg_bbox[:, :2].min(dim=0).values
                height_max, width_max = fg_bbox[:, 2:].max(dim=0).values
                h_end = random.randint(clamp_mid(top_min),
                                       clamp_mid(height_max)) + half
                w_end = random.randint(clamp_mid(left_min),
                                       clamp_mid(width_max)) + half
            else:
                # No boxes, or the ~0.025 background draw: sample anywhere,
                # oversampling the image boundary.
                h_end = sample_patch_end()
                w_end = sample_patch_end()

            h_start = h_end - patch_resolution
            w_start = w_end - patch_resolution
            assert h_start >= 0 and w_start >= 0, (h_start, w_start)

            rows = torch.arange(h_start, h_end, dtype=torch.float32, device=device)
            cols = torch.arange(w_start, w_end, dtype=torch.float32, device=device)
            uv = torch.stack(torch.meshgrid(rows, cols, indexing='ij')) * (
                1. / resolution) + (0.5 / resolution)
            uv = uv.flip(0).reshape(2, -1).transpose(1, 0)  # (row, col) -> (x, y)
            # top: int, left: int, height: int, width: int
            return uv, (h_start, w_start, patch_resolution, patch_resolution)

        all_uv = []
        ray_bboxes = []
        for _ in range(cam2world_matrix.shape[0]):
            uv, bbox = sample_patch_uv()
            all_uv.append(uv)
            ray_bboxes.append(bbox)
        return torch.stack(all_uv, 0), ray_bboxes  # (B, patch_res**2, 2), list

    def create_uv(self, resolution, cam2world_matrix):
        """Return pixel-centre UVs for the full square image, one copy per camera.

        Returns:
            (B, resolution**2, 2) coordinates in (0, 1) ordered (x, y) =
            (column, row), pixels enumerated row-major; B is
            cam2world_matrix.shape[0].
        """
        uv = torch.stack(
            torch.meshgrid(torch.arange(resolution,
                                        dtype=torch.float32,
                                        device=cam2world_matrix.device),
                           torch.arange(resolution,
                                        dtype=torch.float32,
                                        device=cam2world_matrix.device),
                           indexing='ij')) * (1. / resolution) + (0.5 / resolution)
        # meshgrid yields (row, col); flip to (x, y) = (col, row), then flatten.
        uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
        uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
        return uv

    def _uv_to_rays(self, cam2world_matrix, intrinsics, uv):
        """Lift normalized pixel coordinates to world-space rays.

        Args:
            cam2world_matrix: (N, 4, 4) camera-to-world transforms.
            intrinsics: (N, 3, 3); fx, fy, cx, cy and skew K[0, 1] are read, in
                the same normalized units as `uv`.
            uv: (N, M, 2) pixel coordinates ordered (x, y).

        Returns:
            ray_origins: (N, M, 3) camera centres repeated for every ray.
            ray_dirs: (N, M, 3) unit-length directions in world space.
        """
        N = cam2world_matrix.shape[0]
        cam_locs_world = cam2world_matrix[:, :3, 3]
        fx = intrinsics[:, 0, 0].unsqueeze(-1)
        fy = intrinsics[:, 1, 1].unsqueeze(-1)
        cx = intrinsics[:, 0, 2].unsqueeze(-1)
        cy = intrinsics[:, 1, 2].unsqueeze(-1)
        sk = intrinsics[:, 0, 1].unsqueeze(-1)

        x_cam = uv[:, :, 0].reshape(N, -1)
        y_cam = uv[:, :, 1].reshape(N, -1)
        z_cam = torch.ones((N, x_cam.shape[1]), device=cam2world_matrix.device)

        # Closed-form inverse of K = [[fx, sk, cx], [0, fy, cy], [0, 0, 1]]
        # applied to (x, y, 1), i.e. basically torch.inverse(intrinsics).
        x_lift = (x_cam - cx + cy * sk / fy - sk * y_cam / fy) / fx * z_cam
        y_lift = (y_cam - cy) / fy * z_cam

        # Homogeneous camera-space points, transformed to world space.
        cam_rel_points = torch.stack(
            (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
        world_rel_points = torch.bmm(cam2world_matrix,
                                     cam_rel_points.permute(0, 2, 1)).permute(
                                         0, 2, 1)[:, :, :3]

        ray_dirs = world_rel_points - cam_locs_world[:, None, :]
        ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
        ray_origins = cam_locs_world.unsqueeze(1).repeat(
            1, ray_dirs.shape[1], 1)
        return ray_origins, ray_dirs

    def forward(self, cam2world_matrix, intrinsics, resolution, fg_mask=None):
        """
        Create batches of rays and return origins and directions.

        cam2world_matrix: (N, 4, 4)
        intrinsics: (N, 3, 3), in the same normalized units as the (0, 1) UVs
        resolution: int
        fg_mask: unused; accepted for interface compatibility

        ray_origins: (N, M, 3)
        ray_dirs: (N, M, 3), unit length
        third return value: always None

        where M = resolution**2.
        """
        uv = self.create_uv(resolution, cam2world_matrix)
        ray_origins, ray_dirs = self._uv_to_rays(cam2world_matrix, intrinsics,
                                                 uv)
        return ray_origins, ray_dirs, None


class PatchRaySampler(RaySampler):

    def forward(self,
                cam2world_matrix,
                intrinsics,
                patch_resolution,
                resolution,
                fg_bbox=None):
        """
        Create batches of rays for one random square patch per camera.

        cam2world_matrix: (N, 4, 4)
        intrinsics: (N, 3, 3), in the same normalized units as the (0, 1) UVs
        patch_resolution: int, patch side length in pixels
        resolution: int, full image side length in pixels
        fg_bbox: optional per-camera boxes; row idx is used for camera idx
            (see RaySampler.create_patch_uv for the column layout)

        ray_origins: (N, M, 3)
        ray_dirs: (N, M, 3), unit length
        ray_bboxes: list of N tuples (top, left, height, width)

        where M = patch_resolution**2.
        """
        all_uv_list = []
        ray_bboxes = []
        for idx in range(cam2world_matrix.shape[0]):
            # Sample each camera separately so it only sees its own box.
            uv, bboxes = self.create_patch_uv(
                patch_resolution, resolution, cam2world_matrix[idx:idx + 1],
                fg_bbox[idx:idx + 1] if fg_bbox is not None else None)
            all_uv_list.append(uv)
            ray_bboxes.extend(bboxes)
        all_uv = torch.cat(all_uv_list, 0)  # (N, M, 2)

        ray_origins, ray_dirs = self._uv_to_rays(cam2world_matrix, intrinsics,
                                                 all_uv)
        return ray_origins, ray_dirs, ray_bboxes