# File size: 6,050 bytes (revision 0f079b2)
import math
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import craftsman
from craftsman.utils.typing import *
def dot(x, y):
    """Return the dot product of ``x`` and ``y`` along the last dimension.

    The elementwise product is summed over the last axis, and that axis is
    kept with size 1 (``keepdim=True``).
    """
    elementwise = x * y
    return torch.sum(elementwise, -1, keepdim=True)
def reflect(x, n):
    """Return ``2 * dot(x, n) * n - x``, the reflection of ``x`` about ``n``.

    ``n`` is used as given; no normalisation is performed here.
    """
    projection = dot(x, n)
    return 2 * projection * n - x
# A value range: either a tuple of two floats or a tensor annotated as "2 D".
# `scale_tensor` reads index 0 as the lower bound and index 1 as the upper bound.
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    """Linearly remap ``dat`` from the ``inp_scale`` range to the ``tgt_scale`` range.

    Args:
        dat: Values to rescale.
        inp_scale: Source range; index 0 is the lower bound and index 1 the
            upper bound. ``None`` means ``(0, 1)``.
        tgt_scale: Target range, same layout. ``None`` means ``(0, 1)``. When
            it is a tensor, its last dimension must match that of ``dat``.

    Returns:
        ``dat`` normalised by the source range and mapped onto the target range.
    """
    src = (0, 1) if inp_scale is None else inp_scale
    dst = (0, 1) if tgt_scale is None else tgt_scale
    if isinstance(dst, Tensor):
        assert dat.shape[-1] == dst.shape[-1]

    # First bring the data to the unit range, then stretch it to the target.
    src_lo, src_hi = src[0], src[1]
    unit = (dat - src_lo) / (src_hi - src_lo)
    dst_lo, dst_hi = dst[0], dst[1]
    return unit * (dst_hi - dst_lo) + dst_lo
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    """Call ``func`` on slices of its tensor arguments and merge the results.

    Every positional or keyword argument that is a ``torch.Tensor`` is sliced
    along dim 0 in steps of ``chunk_size``; all other arguments are passed
    through unchanged to every call. The batch size is read from the first
    tensor found among the arguments.

    Args:
        func: Callable returning ``None``, a tensor, a tuple/list, or a dict.
            Tuple/list/dict entries may be tensors or ``None``.
        chunk_size: Slice length along dim 0. If ``<= 0``, ``func`` is called
            once on the unmodified arguments and its result returned as is.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        ``None`` if every call returned ``None``; otherwise a value of the same
        kind as the last non-``None`` chunk result (tensor, tuple, list or
        dict) whose tensor entries are concatenated along dim 0. Entries that
        were ``None`` in every chunk stay ``None``.

    Raises:
        AssertionError: If no tensor argument is present.
        TypeError: If ``func`` returns an unsupported type, or if an entry
            mixes tensors with non-tensors across chunks.
    """
    if chunk_size <= 0:
        return func(*args, **kwargs)

    B = None
    for arg in (*args, *kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."

    out = defaultdict(list)
    # Kind of the last non-None chunk result: "tensor", "seq" or "dict".
    out_kind = None
    seq_type = tuple
    chunk_length = 0
    # max(1, B) to support B == 0
    for start in range(0, max(1, B), chunk_size):
        stop = start + chunk_size
        out_chunk = func(
            *[
                arg[start:stop] if isinstance(arg, torch.Tensor) else arg
                for arg in args
            ],
            **{
                k: arg[start:stop] if isinstance(arg, torch.Tensor) else arg
                for k, arg in kwargs.items()
            },
        )
        if out_chunk is None:
            continue

        # Normalise every supported return kind to a dict so that the
        # accumulation and merge logic below is shared.
        if isinstance(out_chunk, torch.Tensor):
            out_kind = "tensor"
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, (tuple, list)):
            out_kind = "seq"
            seq_type = list if isinstance(out_chunk, list) else tuple
            chunk_length = len(out_chunk)
            out_chunk = dict(enumerate(out_chunk))
        elif isinstance(out_chunk, dict):
            out_kind = "dict"
        else:
            # Raise instead of terminating the interpreter from library code.
            raise TypeError(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )

        grad_enabled = torch.is_grad_enabled()
        for k, v in out_chunk.items():
            # Only tensors can be detached; None entries are allowed and are
            # handled at merge time.
            if not grad_enabled and isinstance(v, torch.Tensor):
                v = v.detach()
            out[k].append(v)

    if out_kind is None:
        return None

    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all(vv is None for vv in v):
            # allow None in return value
            out_merged[k] = None
        elif all(isinstance(vv, torch.Tensor) for vv in v):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    if out_kind == "tensor":
        return out_merged[0]
    if out_kind == "seq":
        return seq_type([out_merged[i] for i in range(chunk_length)])
    return out_merged
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
    is always created on the CPU.

    Args:
        shape: Shape of the tensor to create; ``shape[0]`` is the batch size.
        generator: A single generator, or a list with one generator per batch
            element. A list of length 1 is treated like a single generator.
        device: Target device (a ``torch.device`` or a device string).
            Defaults to CPU.
        dtype: Data type forwarded to ``torch.randn``.
        layout: Tensor layout; defaults to ``torch.strided``.

    Returns:
        A tensor of standard-normal samples with the requested shape, moved to
        ``device``.

    Raises:
        ValueError: If a CUDA generator is passed for a non-CUDA device.
    """
    # Imported locally: this module does not define a module-level logger.
    import logging

    # device on which tensor is created defaults to device
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    # Accept device strings as well, so the `.type` accesses below are safe.
    if isinstance(device, str):
        device = torch.device(device)
    elif device is None:
        device = torch.device("cpu")

    if generator is not None:
        first_generator = generator[0] if isinstance(generator, list) else generator
        gen_device_type = first_generator.device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare the device *type*; the log is skipped for mps targets.
            if device.type != "mps":
                logging.getLogger(__name__).info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # One sample per generator; tuple() keeps this working for list shapes.
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_depth: int,
    indexing: str = "ij"
):
    """Build a dense, regularly spaced grid of 3D points inside a bounding box.

    Each axis is sampled at ``2 ** octree_depth + 1`` evenly spaced float32
    positions between the corresponding entries of ``bbox_min`` and
    ``bbox_max`` (both endpoints included).

    Args:
        bbox_min: Lower corner of the box; entries 0, 1, 2 are used.
        bbox_max: Upper corner of the box; entries 0, 1, 2 are used.
        octree_depth: Exponent that sets the number of cells per axis.
        indexing: Forwarded to ``np.meshgrid``.

    Returns:
        A tuple ``(xyz, grid_size, length)``: the grid points flattened to
        shape ``(-1, 3)``, the per-axis point counts as a list of three ints,
        and ``bbox_max - bbox_min``.
    """
    length = bbox_max - bbox_min
    num_cells = np.exp2(octree_depth)
    # Points per axis: one more than the number of cells.
    resolution = int(num_cells) + 1

    axes = [
        np.linspace(bbox_min[axis], bbox_max[axis], resolution, dtype=np.float32)
        for axis in range(3)
    ]
    xs, ys, zs = np.meshgrid(*axes, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1).reshape(-1, 3)
    grid_size = [resolution, resolution, resolution]

    return xyz, grid_size, length