"""Camera pose and projection-matrix helpers built on PyTorch.

Cameras orbit the world origin in a Z-up frame and always look at the origin.
All matrices are batched with a leading ``num_views`` dimension.
"""

import math
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import BoolTensor, FloatTensor

# NOTE: ``trimesh`` and ``PIL.Image`` are not referenced by the code visible in
# this module. They are imported best-effort so the camera helpers stay usable
# in environments where those optional packages are not installed.
try:
    import trimesh
except ImportError:  # pragma: no cover - depends on the environment
    trimesh = None
try:
    from PIL import Image
except ImportError:  # pragma: no cover - depends on the environment
    Image = None

LIST_TYPE = Union[list, np.ndarray, torch.Tensor]


def list_to_pt(
    x: LIST_TYPE, dtype: Optional[torch.dtype] = None, device: Optional[str] = None
) -> torch.Tensor:
    """Convert a list / array / tensor (or scalar) to a tensor.

    Args:
        x: Values to convert. Tensors are converted with ``Tensor.to``; anything
            else is handed to ``torch.tensor``.
        dtype: Target dtype, or ``None`` to keep / infer it.
        device: Target device, or ``None`` to keep the tensor's device (tensor
            input) or use the default device (other input).

    Returns:
        A tensor holding the values of ``x`` with the requested dtype/device.
    """
    if isinstance(x, torch.Tensor):
        # The device must be forwarded too; otherwise a tensor input would stay
        # on its current device while list inputs land on ``device``.
        return x.to(device=device, dtype=dtype)
    return torch.tensor(x, dtype=dtype, device=device)


def get_c2w(
    elevation_deg: LIST_TYPE,
    distance: LIST_TYPE,
    azimuth_deg: Optional[LIST_TYPE],
    num_views: Optional[int] = 1,
    device: Optional[str] = None,
) -> torch.FloatTensor:
    """Build camera-to-world matrices for cameras orbiting the origin.

    Each camera sits at the given spherical coordinates (Z is up) and looks at
    the origin. The rotation columns are ``right``, ``up`` and ``-lookat``, so
    the camera's local -Z axis points towards the origin.

    Args:
        elevation_deg: Elevation angle(s) in degrees, measured from the XY plane.
        distance: Distance(s) of the camera from the origin.
        azimuth_deg: Azimuth angle(s) in degrees, measured in the XY plane from
            +X. If ``None``, ``num_views`` azimuths are spread evenly over
            ``[0, 360)``.
        num_views: Number of evenly spaced views; only used when
            ``azimuth_deg`` is ``None``.
        device: Device on which the inputs are placed.

    Returns:
        Tensor of shape ``(N, 4, 4)`` where ``N`` is the broadcast length of the
        three inputs.

    Note:
        ``elevation_deg``, ``distance`` and ``azimuth_deg`` are broadcast
        against each other, so a scalar or length-1 value may be combined with
        per-view values. The basis is ill-defined when the viewing direction is
        exactly parallel to the world up axis (+Z), because the cross product
        used to derive ``right`` vanishes.
    """
    if azimuth_deg is None:
        assert (
            num_views is not None
        ), "num_views must be provided if azimuth_deg is None."
        # Drop the last sample: 360 degrees duplicates 0 degrees.
        azimuth_deg = torch.linspace(
            0, 360, num_views + 1, dtype=torch.float32, device=device
        )[:-1]
    else:
        azimuth_deg = torch.atleast_1d(
            list_to_pt(azimuth_deg, dtype=torch.float32, device=device)
        )
    elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device)
    camera_distances = list_to_pt(distance, dtype=torch.float32, device=device)

    elevation = elevation_deg * math.pi / 180
    azimuth = azimuth_deg * math.pi / 180
    # Bring all three inputs to a common shape. Without this the z component
    # (which has no azimuth factor) could differ in shape from x and y and make
    # ``torch.stack`` fail.
    elevation, azimuth, camera_distances = torch.broadcast_tensors(
        elevation, azimuth, camera_distances
    )

    camera_positions = torch.stack(
        [
            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
            camera_distances * torch.sin(elevation),
        ],
        dim=-1,
    )
    center = torch.zeros_like(camera_positions)
    # World up (+Z), created with the same shape/dtype/device as the positions
    # so it can never disagree with them.
    up = torch.zeros_like(camera_positions)
    up[..., 2] = 1.0

    lookat = F.normalize(center - camera_positions, dim=-1)
    right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
    # Re-orthogonalise ``up`` against the actual viewing direction.
    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)

    c2w3x4 = torch.cat(
        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
        dim=-1,
    )
    # Append the homogeneous row [0, 0, 0, 1].
    c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
    c2w[:, 3, 3] = 1.0
    return c2w


def get_projection_matrix(
    fovy_deg: LIST_TYPE,
    aspect_wh: float = 1.0,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
) -> torch.FloatTensor:
    """Build batched perspective projection matrices.

    The layout follows the usual OpenGL clip-space perspective matrix, except
    that the Y scale is negated (i.e. the vertical axis is flipped).

    Args:
        fovy_deg: Vertical field(s) of view in degrees; one matrix per entry.
        aspect_wh: Width / height aspect ratio.
        near: Near clipping distance.
        far: Far clipping distance.
        device: Device on which the matrices are created.

    Returns:
        Tensor of shape ``(B, 4, 4)`` with ``B = len(fovy_deg)``.
    """
    # ``atleast_1d`` lets a scalar field of view act as a batch of one.
    fovy_deg = torch.atleast_1d(
        list_to_pt(fovy_deg, dtype=torch.float32, device=device)
    )
    batch_size = fovy_deg.shape[0]
    fovy = fovy_deg * math.pi / 180
    tan_half_fovy = torch.tan(fovy / 2)

    projection_matrix = torch.zeros(
        batch_size, 4, 4, dtype=torch.float32, device=device
    )
    projection_matrix[:, 0, 0] = 1 / (aspect_wh * tan_half_fovy)
    projection_matrix[:, 1, 1] = -1 / tan_half_fovy
    projection_matrix[:, 2, 2] = -(far + near) / (far - near)
    projection_matrix[:, 2, 3] = -2 * far * near / (far - near)
    projection_matrix[:, 3, 2] = -1
    return projection_matrix


def get_orthogonal_projection_matrix(
    batch_size: int,
    left: float,
    right: float,
    bottom: float,
    top: float,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
) -> torch.FloatTensor:
    """Build batched orthographic projection matrices.

    As in :func:`get_projection_matrix`, the Y scale is negated relative to the
    usual OpenGL orthographic matrix.

    Args:
        batch_size: Number of (identical) matrices to return.
        left: Left bound of the view volume.
        right: Right bound of the view volume.
        bottom: Bottom bound of the view volume.
        top: Top bound of the view volume.
        near: Near clipping distance.
        far: Far clipping distance.
        device: Device on which the matrices are created.

    Returns:
        Tensor of shape ``(batch_size, 4, 4)``.
    """
    projection_matrix = torch.zeros(
        batch_size, 4, 4, dtype=torch.float32, device=device
    )
    projection_matrix[:, 0, 0] = 2 / (right - left)
    projection_matrix[:, 1, 1] = -2 / (top - bottom)
    projection_matrix[:, 2, 2] = -2 / (far - near)
    projection_matrix[:, 0, 3] = -(right + left) / (right - left)
    projection_matrix[:, 1, 3] = -(top + bottom) / (top - bottom)
    projection_matrix[:, 2, 3] = -(far + near) / (far - near)
    projection_matrix[:, 3, 3] = 1
    return projection_matrix


@dataclass
class Camera:
    """A batch of cameras described by their transform matrices.

    ``c2w`` and ``cam_pos`` may be ``None`` when the camera was built from a
    world-to-camera matrix only; ``w2c``, ``proj_mtx`` and ``mvp_mtx`` are
    always present.
    """

    # Camera-to-world matrices, or None if unknown.
    c2w: Optional[torch.FloatTensor]
    # World-to-camera matrices.
    w2c: torch.FloatTensor
    # Projection matrices.
    proj_mtx: torch.FloatTensor
    # Combined ``proj_mtx @ w2c`` matrices.
    mvp_mtx: torch.FloatTensor
    # Camera positions in world space, or None if unknown.
    cam_pos: Optional[torch.FloatTensor]

    def __getitem__(self, index):
        """Return a sub-batch as a new ``Camera``.

        Args:
            index: An ``int`` (negative values count from the end) or a
                ``slice``. An ``int`` selects a batch of size one, i.e. the
                leading dimension is kept.

        Raises:
            IndexError: If an integer index is out of range. This is also what
                terminates ``for cam in camera`` loops.
            NotImplementedError: For any other index type.
        """
        if isinstance(index, int):
            count = len(self)
            if not -count <= index < count:
                raise IndexError(
                    f"Camera index {index} out of range for {count} view(s)"
                )
            # Normalise negatives; slice(-1, 0) would otherwise be empty.
            index %= count
            sl = slice(index, index + 1)
        elif isinstance(index, slice):
            sl = index
        else:
            raise NotImplementedError(
                f"Unsupported index type: {type(index).__name__}"
            )

        return Camera(
            c2w=self.c2w[sl] if self.c2w is not None else None,
            w2c=self.w2c[sl],
            proj_mtx=self.proj_mtx[sl],
            mvp_mtx=self.mvp_mtx[sl],
            cam_pos=self.cam_pos[sl] if self.cam_pos is not None else None,
        )

    def to(self, device: Optional[str] = None) -> "Camera":
        """Move every stored tensor to ``device`` in place and return ``self``."""
        if self.c2w is not None:
            self.c2w = self.c2w.to(device)
        self.w2c = self.w2c.to(device)
        self.proj_mtx = self.proj_mtx.to(device)
        self.mvp_mtx = self.mvp_mtx.to(device)
        if self.cam_pos is not None:
            self.cam_pos = self.cam_pos.to(device)
        return self

    def __len__(self):
        """Return the number of views in the batch."""
        # ``w2c`` is always set, whereas ``c2w`` may legitimately be None.
        return self.w2c.shape[0]


def get_camera(
    elevation_deg: Optional[LIST_TYPE] = None,
    distance: Optional[LIST_TYPE] = None,
    fovy_deg: Optional[LIST_TYPE] = None,
    azimuth_deg: Optional[LIST_TYPE] = None,
    num_views: Optional[int] = 1,
    c2w: Optional[torch.FloatTensor] = None,
    w2c: Optional[torch.FloatTensor] = None,
    proj_mtx: Optional[torch.FloatTensor] = None,
    aspect_wh: float = 1.0,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
):
    """Build a perspective :class:`Camera`.

    The pose is taken from, in order of precedence: ``w2c``, ``c2w``, or the
    orbit parameters (``elevation_deg``, ``distance``, ``azimuth_deg`` /
    ``num_views``). When ``w2c`` is given, the returned camera has
    ``c2w=None`` and ``cam_pos=None`` (any ``c2w`` argument is ignored).

    Args:
        elevation_deg: Elevation(s) in degrees; see :func:`get_c2w`.
        distance: Camera distance(s); see :func:`get_c2w`.
        fovy_deg: Vertical field(s) of view in degrees; required unless
            ``proj_mtx`` is given.
        azimuth_deg: Azimuth(s) in degrees; see :func:`get_c2w`.
        num_views: Number of evenly spaced views when ``azimuth_deg`` is None.
        c2w: Precomputed camera-to-world matrices.
        w2c: Precomputed world-to-camera matrices.
        proj_mtx: Precomputed projection matrices.
        aspect_wh: Width / height aspect ratio for the default projection.
        near: Near clipping distance for the default projection.
        far: Far clipping distance for the default projection.
        device: Device for newly created tensors.

    Returns:
        The assembled :class:`Camera`.

    Raises:
        ValueError: If the arguments do not determine a pose or a projection.
    """
    if w2c is None:
        if c2w is None:
            if elevation_deg is None or distance is None:
                raise ValueError(
                    "elevation_deg and distance are required when neither "
                    "c2w nor w2c is given."
                )
            c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
        camera_positions = c2w[:, :3, 3]
        w2c = torch.linalg.inv(c2w)
    else:
        camera_positions = None
        c2w = None

    if proj_mtx is None:
        if fovy_deg is None:
            raise ValueError("fovy_deg is required when proj_mtx is not given.")
        proj_mtx = get_projection_matrix(
            fovy_deg,
            aspect_wh=aspect_wh,
            near=near,
            far=far,
            # Default to the pose's device so the matmul below cannot mix devices.
            device=device if device is not None else w2c.device,
        )

    mvp_mtx = proj_mtx @ w2c
    return Camera(
        c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions
    )


def get_orthogonal_camera(
    elevation_deg: LIST_TYPE,
    distance: LIST_TYPE,
    left: float,
    right: float,
    bottom: float,
    top: float,
    azimuth_deg: Optional[LIST_TYPE] = None,
    num_views: Optional[int] = 1,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
):
    """Build an orthographic :class:`Camera` orbiting the origin.

    Args:
        elevation_deg: Elevation(s) in degrees; see :func:`get_c2w`.
        distance: Camera distance(s); see :func:`get_c2w`.
        left: Left bound of the view volume.
        right: Right bound of the view volume.
        bottom: Bottom bound of the view volume.
        top: Top bound of the view volume.
        azimuth_deg: Azimuth(s) in degrees; see :func:`get_c2w`.
        num_views: Number of evenly spaced views when ``azimuth_deg`` is None.
        near: Near clipping distance.
        far: Far clipping distance.
        device: Device for the created tensors.

    Returns:
        The assembled :class:`Camera`, with one projection matrix per view.
    """
    c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
    camera_positions = c2w[:, :3, 3]
    w2c = torch.linalg.inv(c2w)
    proj_mtx = get_orthogonal_projection_matrix(
        batch_size=c2w.shape[0],
        left=left,
        right=right,
        bottom=bottom,
        top=top,
        near=near,
        far=far,
        device=device,
    )
    mvp_mtx = proj_mtx @ w2c
    return Camera(
        c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions
    )