Spaces:
Runtime error
Runtime error
import gc | |
from collections import defaultdict | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Function | |
from torch.cuda.amp import custom_bwd, custom_fwd | |
import tinycudann as tcnn | |
def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs):
    """Apply ``func`` to ``args`` in slices along dim 0 and concatenate the results.

    Every positional ``torch.Tensor`` in ``args`` is sliced to
    ``[i:i + chunk_size]`` along its first dimension. Non-tensor positional
    args and all ``kwargs`` are passed through unchanged. The batch size is
    taken from the first tensor found in ``args``.

    Args:
        func: Callable returning one of the following. Chunks that return
            ``None`` are skipped.
            - a ``torch.Tensor``
            - a tuple/list of tensors
            - a dict of tensors
            - ``None``
        chunk_size: Number of rows passed to ``func`` per call; must be positive.
        move_to_cpu: If True, each chunk's outputs are moved to CPU before
            concatenation.
        *args: Positional arguments for ``func``; tensors among them are sliced.
        **kwargs: Keyword arguments forwarded to ``func`` unchanged.

    Returns:
        The per-chunk outputs concatenated along dim 0, in the same kind of
        container ``func`` returned:
        - a tensor for a tensor;
        - the same tuple/list type for a tuple/list;
        - a dict for a dict.
        Returns ``None`` if every call returned ``None`` or the batch is empty.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        TypeError: If ``args`` contains no tensor, or ``func`` returns an
            unsupported type.
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}.')
    B = next((arg.shape[0] for arg in args if isinstance(arg, torch.Tensor)), None)
    if B is None:
        raise TypeError('chunk_batch needs at least one torch.Tensor in its positional args to infer the batch size.')

    out = defaultdict(list)
    out_type = None
    chunk_length = 0
    for i in range(0, B, chunk_size):
        chunk_args = [arg[i:i + chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args]
        out_chunk = func(*chunk_args, **kwargs)
        if out_chunk is None:
            continue
        out_type = type(out_chunk)
        # Normalize every supported return kind to a dict so accumulation is uniform.
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, (tuple, list)):
            chunk_length = len(out_chunk)
            out_chunk = dict(enumerate(out_chunk))
        elif not isinstance(out_chunk, dict):
            raise TypeError(f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}.')
        for k, v in out_chunk.items():
            # Detach only when autograd is off; with grad enabled the chunks stay
            # attached so gradients can flow through the final torch.cat.
            v = v if torch.is_grad_enabled() else v.detach()
            v = v.cpu() if move_to_cpu else v
            out[k].append(v)

    if out_type is None:
        return None
    out = {k: torch.cat(v, dim=0) for k, v in out.items()}
    if issubclass(out_type, torch.Tensor):
        return out[0]
    if issubclass(out_type, (tuple, list)):
        items = [out[i] for i in range(chunk_length)]
        if hasattr(out_type, '_fields'):
            # namedtuple constructors take the fields positionally.
            return out_type(*items)
        return out_type(items)
    return out
class _TruncExp(Function): # pylint: disable=abstract-method | |
# Implementation from torch-ngp: | |
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py | |
def forward(ctx, x): # pylint: disable=arguments-differ | |
ctx.save_for_backward(x) | |
return torch.exp(x) | |
def backward(ctx, g): # pylint: disable=arguments-differ | |
x = ctx.saved_tensors[0] | |
return g * torch.exp(torch.clamp(x, max=15)) | |
trunc_exp = _TruncExp.apply | |
def get_activation(name):
    """Return the activation callable described by ``name``.

    ``name`` is matched case-insensitively:
      * ``None`` or ``'none'``  -> identity
      * ``'scale<f>'``          -> clamp to [0, f], then divide by f
      * ``'clamp<f>'``          -> clamp to [0, f]
      * ``'mul<f>'``            -> multiply by f
      * ``'lin2srgb'``          -> linear-to-sRGB transfer curve, clamped to [0, 1]
      * ``'trunc_exp'``         -> ``trunc_exp`` (exp with a clamped-input backward)
      * ``'+<f>'`` / ``'-<f>'`` -> add the signed constant parsed from ``name``
      * ``'sigmoid'`` / ``'tanh'``
      * anything else           -> ``getattr(torch.nn.functional, name)``

    All numeric parts of ``name`` are parsed here, once. A malformed spec
    therefore fails immediately instead of on the first call.

    Raises:
        ValueError: If a numeric part of ``name`` is not a valid float.
        AttributeError: If ``name`` matches no known spec and is not an
            attribute of ``torch.nn.functional``.
    """
    if name is None:
        return lambda x: x
    name = name.lower()
    if name == 'none':
        return lambda x: x
    elif name.startswith('scale'):
        scale_factor = float(name[5:])
        return lambda x: x.clamp(0., scale_factor) / scale_factor
    elif name.startswith('clamp'):
        clamp_max = float(name[5:])
        return lambda x: x.clamp(0., clamp_max)
    elif name.startswith('mul'):
        mul_factor = float(name[3:])
        return lambda x: x * mul_factor
    elif name == 'lin2srgb':
        return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.)
    elif name == 'trunc_exp':
        return trunc_exp
    elif name.startswith(('+', '-')):
        # Parse the signed offset once rather than on every forward call.
        offset = float(name)
        return lambda x: x + offset
    elif name == 'sigmoid':
        return torch.sigmoid
    elif name == 'tanh':
        return torch.tanh
    else:
        return getattr(F, name)
def dot(x, y):
    """Inner product of ``x`` and ``y`` over the last dimension.

    The reduced dimension is kept with size 1, so the result broadcasts back
    against the inputs.
    """
    return (x * y).sum(dim=-1, keepdim=True)


def reflect(x, n):
    """Reflect ``x`` about the direction ``n``, computing ``2 * <x, n> * n - x``.

    NOTE(review): this is a true mirror reflection only when ``n`` has unit
    length. Presumably callers pass normalized vectors; confirm.
    """
    twice_projection = 2 * dot(x, n)
    return twice_projection * n - x
def scale_anything(dat, inp_scale, tgt_scale):
    """Linearly map ``dat`` from the range ``inp_scale`` onto the range ``tgt_scale``.

    Args:
        dat: Values to rescale (anything supporting ``min``/``max`` and
            arithmetic, e.g. a tensor).
        inp_scale: Source range ``(lo, hi)``. If ``None``, the range
            ``(dat.min(), dat.max())`` is used. The bounds may be scalars or
            tensors that broadcast against ``dat``.
        tgt_scale: Target range ``(lo, hi)``.

    Returns:
        ``(dat - inp_lo) / (inp_hi - inp_lo) * (tgt_hi - tgt_lo) + tgt_lo``.
        Where the input span is zero (e.g. constant ``dat`` with
        ``inp_scale=None``), the span is replaced by 1. This avoids 0/0 = NaN,
        so values equal to ``inp_lo`` land on ``tgt_lo``.
    """
    if inp_scale is None:
        inp_scale = [dat.min(), dat.max()]
    lo, hi = inp_scale[0], inp_scale[1]
    span = hi - lo
    # Adding the zero-mask turns zero spans into 1 (elementwise for tensor
    # bounds) and leaves non-zero spans untouched; no data-dependent branch.
    safe_span = span + (span == 0)
    dat = (dat - lo) / safe_span
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat
def cleanup():
    """Free memory: run Python GC, empty PyTorch's CUDA cache, then release tiny-cuda-nn temporaries.

    The steps run in this fixed order. Collecting garbage first presumably
    lets tensors held only by unreachable objects be returned to PyTorch's
    allocator before its cache is emptied.
    """
    release_steps = (
        lambda: gc.collect(),
        lambda: torch.cuda.empty_cache(),
        lambda: tcnn.free_temporary_memory(),
    )
    for release in release_steps:
        release()