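# NOTE: hosting-page status text ("Spaces: Runtime error / Runtime error") was
# captured here together with the file; it is not part of the module.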
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import contextlib
import functools
import re
import warnings

import numpy as np
import torch
#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.
_constant_cache = dict()


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding the constant `value`.

    Tensors are memoized in the module-level `_constant_cache`, keyed on the
    value's contents together with every construction option, so repeated
    requests for the same constant return the very same tensor object.

    Args:
        value: Anything `np.asarray` accepts.
        shape: Optional iterable of ints; the value is broadcast against a
            tensor of this shape.
        dtype: Target torch dtype. Defaults to `torch.get_default_dtype()`.
        device: Target device (`torch.device`, string or index). Defaults to
            the CPU.
        memory_format: Target memory format. Defaults to
            `torch.contiguous_format`.

    Returns:
        The cached `torch.Tensor`. It is shared between callers, so it must
        not be modified in place.
    """
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    # Normalize so that e.g. 'cpu' and torch.device('cpu') hit the same
    # cache entry instead of creating duplicate tensors.
    device = torch.device('cpu') if device is None else torch.device(device)
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        # Copy the array so the cached tensor never aliases caller memory.
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor
#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.
try:
    nan_to_num = torch.nan_to_num  # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):  # pylint: disable=redefined-builtin
        """Fallback used when `torch.nan_to_num` is not available.

        NaNs become 0 and infinities are clamped to `posinf` / `neginf`,
        which default to the largest / smallest finite value of the input
        dtype. Only `nan == 0` is supported by this fallback.
        """
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        # nansum over a freshly added singleton axis treats each NaN as 0 and
        # leaves every other element untouched; clamp then bounds the Infs.
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
#----------------------------------------------------------------------------
# Symbolic assert.
# Pick whichever assert entry point this PyTorch version provides; the
# fallback name is only looked up when the first one is missing.
try:
    symbolic_assert = torch._assert  # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert  # 1.7.0
#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().
class suppress_tracer_warnings(warnings.catch_warnings):
    """Context manager that silences `torch.jit.TracerWarning` inside its body.

    Builds on `warnings.catch_warnings`, so the previous warning filters are
    restored when the block exits.
    """

    def __enter__(self):
        super().__enter__()
        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
        return self
#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().
def assert_shape(tensor, ref_shape):
    """Check that `tensor.shape` matches `ref_shape`, dimension by dimension.

    Args:
        tensor: Tensor whose shape is checked.
        ref_shape: Sequence with one entry per dimension. `None` accepts any
            size; an int or a `torch.Tensor` is compared against the actual
            size.

    Raises:
        AssertionError: If the number of dimensions or any plain-int size
            differs. When either side of a comparison is a tensor, the check
            is delegated to `symbolic_assert` instead.
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings():  # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings():  # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().
def profiled_function(fn):
    """Decorate `fn` so each call runs inside a profiler `record_function` range.

    The range is labelled with `fn.__name__`. The wrapper keeps the wrapped
    function's metadata (name, docstring, module, ...) via `functools.wraps`.

    Args:
        fn: Callable to wrap; it must have a `__name__`.

    Returns:
        A wrapper that forwards all arguments to `fn` and returns its result.
    """
    # Read the name up front so a callable without __name__ fails at
    # decoration time rather than on first call.
    name = fn.__name__

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(name):
            return fn(*args, **kwargs)

    return decorator
#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler that yields dataset indices forever, optionally shuffled.

    The index stream is shared round-robin between `num_replicas` consumers:
    this instance yields only the positions whose running counter is
    congruent to `rank` modulo `num_replicas`.

    Args:
        dataset: Object supporting `len()`; must be non-empty.
        rank: Index of this consumer, `0 <= rank < num_replicas`.
        num_replicas: Total number of consumers sharing the stream.
        shuffle: If true, start from a seeded random permutation and keep
            re-shuffling while iterating.
        seed: Seed for the `np.random.RandomState` used when shuffling.
        window_size: Fraction of the dataset (0..1) used as the look-back
            window for the ongoing swaps.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        """Yield indices endlessly; this generator never terminates."""
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
            if window >= 2:
                # Swap the current slot with a random slot at most
                # `window - 1` positions behind it (wrapping around).
                j = (i - rnd.randint(window)) % order.size
                order[i], order[j] = order[j], order[i]
            idx += 1
#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.
def params_and_buffers(module):
    """Return a list of all parameters of `module` followed by all its buffers."""
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())
def named_params_and_buffers(module):
    """Return `(name, tensor)` pairs for all parameters, then all buffers, of `module`."""
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())
@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy same-named parameters and buffers from `src_module` into `dst_module`.

    Values are copied in place; each destination tensor keeps its own
    `requires_grad` flag. The copy runs under `torch.no_grad()` because an
    in-place write to a leaf parameter that requires grad is rejected by
    autograd otherwise.

    Args:
        src_module: Module to read tensors from.
        dst_module: Module whose tensors are overwritten.
        require_all: If true, every tensor of `dst_module` must have a
            same-named counterpart in `src_module`.

    Raises:
        AssertionError: If an argument is not a `torch.nn.Module`, or if
            `require_all` is set and a destination tensor has no source.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.
@contextlib.contextmanager
def ddp_sync(module, sync):
    """Context manager that optionally runs its body under `module.no_sync()`.

    Args:
        module: Any `torch.nn.Module`.
        sync: If true, or if `module` is not a `DistributedDataParallel`
            instance, the body runs unchanged; otherwise it runs inside
            `module.no_sync()`.
    """
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield