|
import math |
|
from collections import defaultdict |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import craftsman |
|
from craftsman.utils.typing import * |
|
|
|
|
|
def dot(x, y):
    """Return the inner product of ``x`` and ``y`` over their last dimension.

    The reduced dimension is kept with size 1, so the result broadcasts back
    against the inputs.
    """
    elementwise = x * y
    return elementwise.sum(dim=-1, keepdim=True)
|
|
|
|
|
def reflect(x, n):
    """Return ``2 * <x, n> * n - x`` computed over the last dimension.

    This mirrors ``x`` about the direction ``n``. The formula does not
    normalize ``n``, so ``n`` is presumably expected to be unit length
    (NOTE(review): confirm at call sites).
    """
    scaled_normal = 2 * dot(x, n) * n
    return scaled_normal - x
|
|
|
|
|
# A value range accepted by ``scale_tensor``: either a (low, high) pair of
# scalars, or a tensor of shape (2, D) whose index 0 holds the lower bounds and
# index 1 the upper bounds, one per channel of the last dimension.
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
|
|
|
|
|
def scale_tensor(
    dat: Num[Tensor, "... D"],
    inp_scale: Optional[ValidScale],
    tgt_scale: Optional[ValidScale],
) -> Num[Tensor, "... D"]:
    """Linearly remap ``dat`` from the range ``inp_scale`` to ``tgt_scale``.

    Each scale is read as ``scale[0]`` (lower bound) and ``scale[1]`` (upper
    bound). Values equal to ``inp_scale[0]`` map to ``tgt_scale[0]``, and values
    equal to ``inp_scale[1]`` map to ``tgt_scale[1]``. A scale may be a pair of
    scalars, or a tensor of shape (2, D) that applies per channel along the
    last dimension of ``dat``.

    Args:
        dat: Values to remap; the last dimension is the channel dimension.
        inp_scale: Source range, or None for the unit range (0, 1).
        tgt_scale: Target range, or None for the unit range (0, 1).

    Returns:
        A new tensor holding ``dat`` remapped to the target range.

    Raises:
        AssertionError: If a tensor-valued scale's last dimension differs from
            ``dat``'s last dimension.
    """
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    # Per-channel scales must match dat's channel count. Otherwise broadcasting
    # (e.g. D=1 against D=3) could silently produce a wrongly shaped result.
    for scale in (inp_scale, tgt_scale):
        if isinstance(scale, Tensor):
            assert dat.shape[-1] == scale.shape[-1], (
                f"scale has {scale.shape[-1]} channels but dat has {dat.shape[-1]}"
            )
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat
|
|
|
|
|
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    """Call ``func`` on consecutive batch slices and concatenate the results.

    Every ``torch.Tensor`` in ``args``/``kwargs`` is sliced along dim 0 into
    pieces of at most ``chunk_size`` rows; other arguments are passed through
    unchanged. The batch size comes from the first tensor found, checking
    positional arguments first.

    ``func`` may return a tensor, a tuple/list of tensors, a dict of tensors,
    or None. Chunk results are concatenated along dim 0. An entry that is None
    in every chunk stays None. Chunks that return None are skipped. When
    gradients are disabled, collected tensors are detached.

    The result container follows the last non-None chunk:
    - a tensor for tensor outputs;
    - the same tuple/list type (namedtuples included) for sequences;
    - a plain dict for mapping outputs.

    Args:
        func: Function called once per chunk.
        chunk_size: Maximum rows per chunk. If ``<= 0``, ``func`` is called
            once on the unsliced inputs and its result returned as-is.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The merged output, or None if every chunk returned None.

    Raises:
        AssertionError: If ``chunk_size > 0`` and no tensor argument is given.
        TypeError: If ``func`` returns an unsupported type, or an entry mixes
            tensors and non-tensor values across chunks.
    """
    if chunk_size <= 0:
        return func(*args, **kwargs)

    B = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."

    def _take(value, start):
        # Slice tensors along the batch dimension; pass everything else through.
        if isinstance(value, torch.Tensor):
            return value[start : start + chunk_size]
        return value

    out = defaultdict(list)
    out_type = None  # concrete type of the last non-None chunk result
    out_kind = None  # normalized kind: "tensor", "sequence" or "dict"
    chunk_length = 0

    # max(1, B): func is still called once when the batch is empty.
    for start in range(0, max(1, B), chunk_size):
        out_chunk = func(
            *[_take(arg, start) for arg in args],
            **{k: _take(arg, start) for k, arg in kwargs.items()},
        )
        if out_chunk is None:
            continue

        out_type = type(out_chunk)
        # isinstance rather than exact-type checks, so that subclasses
        # (nn.Parameter, namedtuples, OrderedDict) are merged instead of the
        # function silently returning None.
        if isinstance(out_chunk, torch.Tensor):
            out_kind = "tensor"
            entries = {0: out_chunk}
        elif isinstance(out_chunk, (tuple, list)):
            out_kind = "sequence"
            chunk_length = len(out_chunk)
            entries = dict(enumerate(out_chunk))
        elif isinstance(out_chunk, dict):
            out_kind = "dict"
            entries = out_chunk
        else:
            raise TypeError(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )

        detach = not torch.is_grad_enabled()
        for k, v in entries.items():
            # Entries may be None (handled during merging), so only tensors
            # are detached.
            if detach and isinstance(v, torch.Tensor):
                v = v.detach()
            out[k].append(v)

    if out_kind is None:
        return None

    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all(vv is None for vv in v):
            out_merged[k] = None
        elif all(isinstance(vv, torch.Tensor) for vv in v):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    if out_kind == "tensor":
        return out_merged[0]
    if out_kind == "dict":
        return out_merged
    values = [out_merged[i] for i in range(chunk_length)]
    if hasattr(out_type, "_fields"):
        # namedtuple constructors take their fields positionally.
        return out_type(*values)
    return out_type(values)
|
|
|
|
|
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """Create a tensor of standard-normal samples on ``device`` with ``dtype``.

    With a list of generators, each batch element (row ``i`` along dim 0) is
    drawn from ``generator[i]``, so elements can be seeded individually. If
    the generators live on the CPU while ``device`` is not a CPU device, the
    samples are drawn on the CPU and then moved to ``device``.

    Args:
        shape: Output shape, as a tuple or a list.
        generator: None, a single ``torch.Generator``, or a list with one
            generator per batch element. A one-element list is treated as a
            single generator.
        device: Target device, as a ``torch.device`` or a string. The result is
            moved there; the final ``.to()`` targets CPU when this is None.
        dtype: Output dtype; None uses torch's default.
        layout: Output layout; defaults to ``torch.strided``.

    Returns:
        A tensor of shape ``shape`` on ``device``.

    Raises:
        ValueError: If a CUDA generator is combined with a non-CUDA ``device``.
    """
    rand_device = device
    layout = layout or torch.strided
    # Normalize strings such as "cuda" so that ``device.type`` below works.
    device = torch.device("cpu") if device is None else torch.device(device)

    if generator is not None:
        first_gen = generator[0] if isinstance(generator, list) else generator
        gen_device_type = first_gen.device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            # Draw where the generator lives, then move to the target device.
            rand_device = "cpu"
            if device.type != "mps":
                import logging

                logging.getLogger(__name__).info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # One generator per batch element: draw each row separately.
        # tuple() lets list-typed shapes work as well as tuples.
        row_shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(row_shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(shape[0])
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
|
|
|
|
|
def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_depth: int,
    indexing: str = "ij",
):
    """Sample a regular 3D lattice of float32 points spanning a bounding box.

    Each axis gets ``int(2 ** octree_depth) + 1`` evenly spaced samples,
    including both box corners.

    Args:
        bbox_min: Lower box corner; entries 0, 1, 2 are used for x, y, z.
        bbox_max: Upper box corner, indexed the same way.
        octree_depth: Subdivision level; controls the samples per axis.
        indexing: Forwarded to ``np.meshgrid`` (e.g. "ij" or "xy").

    Returns:
        A tuple ``(xyz, grid_size, length)``:
        - ``xyz``: the lattice points flattened to shape (N, 3);
        - ``grid_size``: the per-axis sample count repeated three times, as a list;
        - ``length``: ``bbox_max - bbox_min``.
    """
    extent = bbox_max - bbox_min
    samples_per_axis = int(np.exp2(octree_depth)) + 1
    axes = [
        np.linspace(bbox_min[d], bbox_max[d], samples_per_axis, dtype=np.float32)
        for d in range(3)
    ]
    coords = np.meshgrid(*axes, indexing=indexing)
    xyz = np.stack(coords, axis=-1).reshape(-1, 3)
    grid_size = [samples_per_axis] * 3
    return xyz, grid_size, extent