import math
import random
import time
from collections.abc import MutableMapping
from functools import wraps

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal


class eval_mode:
    """Context manager that puts the given models into eval mode on entry and
    restores their previous training states on exit."""

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        return False


def set_seed_everywhere(seed):
    """Seed torch (CPU and CUDA), numpy, and the stdlib RNG."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def soft_update_params(net, target_net, tau):
    """Polyak-average the parameters of `net` into `target_net`."""
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)


def hard_update_params(net, target_net):
    """Copy the parameters of `net` into `target_net`."""
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(param.data)


def weight_init(m):
    """Custom weight init for Conv2D and Linear layers."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data, gain)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)


class Until:
    """Callable that is True while `step` is below a budget expressed in
    environment steps, corrected for action repeat. A budget of None never
    expires."""

    def __init__(self, until, action_repeat=1):
        self._until = until
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._until is None:
            return True
        until = self._until // self._action_repeat
        return step < until


class Every:
    """Callable that is True every `every` environment steps, corrected for
    action repeat. A period of None never fires."""

    def __init__(self, every, action_repeat=1):
        self._every = every
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._every is None:
            return False
        every = self._every // self._action_repeat
        return step % every == 0


class Timer:
    """Tracks elapsed time since creation and since the last `reset` call."""

    def __init__(self):
        self._start_time = time.time()
        self._last_time = time.time()

    def reset(self):
        elapsed_time = time.time() - self._last_time
        self._last_time = time.time()
        total_time = time.time() - self._start_time
        return elapsed_time, total_time

    def total_time(self):
        return time.time() - self._start_time


class TruncatedNormal(pyd.Normal):
    """Normal distribution whose samples are clamped to [low, high] using a
    straight-through gradient estimator, so `sample` stays differentiable
    with respect to `loc` and `scale` (it behaves like `rsample`)."""

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low = low
        self.high = high
        self.eps = eps

    def _clamp(self, x):
        clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
        # Straight-through estimator: the forward pass uses the clamped value,
        # while gradients flow through the unclamped x.
        x = x - x.detach() + clamped_x.detach()
        return x

    def sample(self, sample_shape=torch.Size(), stddev_clip=None):
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        eps *= self.scale
        if stddev_clip is not None:
            # Clamp the scaled noise before adding it to the mean.
            eps = torch.clamp(eps, -stddev_clip, stddev_clip)
        x = self.loc + eps
        return self._clamp(x)
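# --- Illustrative usage (not part of the original module) --------------------
# A minimal sketch of how an actor (e.g., DrQ-v2-style) might draw bounded
# exploration actions from TruncatedNormal; the action dimension, noise scale,
# and clip value below are assumptions chosen for illustration.
def _truncated_normal_example():
    loc = torch.zeros(4)                   # mean of a 4-dim action (assumed)
    scale = torch.full((4,), 0.2)          # assumed exploration stddev
    dist = TruncatedNormal(loc, scale)
    action = dist.sample(stddev_clip=0.3)  # scaled noise clamped to [-0.3, 0.3]
    # Samples always land inside the truncation bounds [-1, 1].
    assert action.ge(-1.0).all() and action.le(1.0).all()
    return action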
class TanhTransform(pyd.transforms.Transform):
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp y away from the boundary here, as that may degrade
        # the performance of certain algorithms; rely on `cache_size=1`
        # instead to avoid evaluating atanh at the boundary.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable; see details at
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal distribution squashed through tanh, with support (-1, 1)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu


def retry(func):
    """Decorator that retries a function for up to 1000 attempts on OSError or
    PermissionError, sleeping 0.1s between attempts."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        attempts = 0
        max_attempts = 1000
        while attempts < max_attempts:
            try:
                return func(*args, **kwargs)
            except (OSError, PermissionError):
                attempts += 1
                time.sleep(0.1)
        raise OSError("Retry failed")

    return wrapper


def flatten_dict(dictionary, parent_key='', separator='_'):
    """Flatten a nested mapping into a single-level dict whose keys join the
    nesting path with `separator`."""
    items = []
    for key in dictionary.keys():
        try:
            value = dictionary[key]
        except Exception:
            value = '??? '  # placeholder for unreadable values
        new_key = parent_key + separator + key if parent_key else key
        if isinstance(value, MutableMapping):
            items.extend(flatten_dict(value, new_key, separator=separator).items())
        else:
            items.append((new_key, value))
    return dict(items)


def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherical linear interpolation.

    Args:
        t (float/np.ndarray): Interpolation value between 0.0 and 1.0.
        v0 (np.ndarray or torch.Tensor): Starting vector.
        v1 (np.ndarray or torch.Tensor): Final vector.
        DOT_THRESHOLD (float): Threshold above which the two vectors are
            considered collinear. Altering this is not recommended.

    Returns:
        v2 (np.ndarray or torch.Tensor): Interpolation between v0 and v1.
    """
    is_tensor = False
    if not isinstance(v0, np.ndarray):
        is_tensor = True
        v0 = v0.detach().cpu().numpy()
    if not isinstance(v1, np.ndarray):
        is_tensor = True
        v1 = v1.detach().cpu().numpy()
    if len(v0.shape) == 1:
        v0 = v0.reshape(1, -1)
    if len(v1.shape) == 1:
        v1 = v1.reshape(1, -1)
    # Copy the vectors to reuse them later
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)
    # Normalize the vectors to get the directions and angles
    v0 = v0 / np.linalg.norm(v0, axis=-1, keepdims=True)
    v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    # Dot product of the normalized vectors
    dot = np.sum(v0 * v1, axis=-1)
    # If the absolute dot product is almost 1, the vectors are ~collinear,
    # so linear interpolation should be used instead
    if (np.abs(dot) > DOT_THRESHOLD).any():
        raise NotImplementedError('lerp not implemented')
        # return lerp(t, v0_copy, v1_copy)
    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    # Angle at timestep t
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)
    # Finish the slerp algorithm
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    v2 = s0.reshape(-1, 1) * v0_copy + s1.reshape(-1, 1) * v1_copy
    if is_tensor:
        # Note: tensor inputs are returned on CUDA; this assumes a GPU is
        # available.
        return torch.from_numpy(v2).to("cuda")
    return v2
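# --- Illustrative usage (not part of the original module) --------------------
# A minimal, self-contained demo of `flatten_dict` and `slerp`; the config
# values and vectors below are arbitrary examples chosen for illustration.
if __name__ == '__main__':
    # flatten_dict: nested config -> flat keys joined by the separator.
    cfg = {'agent': {'lr': 1e-4, 'critic': {'tau': 0.01}}}
    print(flatten_dict(cfg))  # {'agent_lr': 0.0001, 'agent_critic_tau': 0.01}

    # slerp: interpolate halfway between two orthogonal unit vectors.
    a = np.eye(8)[0:1]  # e0, shape (1, 8)
    b = np.eye(8)[1:2]  # e1, shape (1, 8)
    mid = slerp(0.5, a, b)
    print(mid[0, :2])  # ~[0.7071, 0.7071]: equal weight on both directions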