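"""
Input warping for `get_batch`-style batch generators.

Provides a `get_batch` wrapper that applies a random input warping, by default
a Kumaraswamy CDF transform, to the batches produced by another `get_batch`
function, plus an `exp`-based warping alternative.
"""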
import torch
from torch.distributions import Kumaraswamy

def exp_in_prev_range(x, factor):
    """Exponentially warp `x` along dim 0 while preserving each column's original [min, max] range."""
    mini, maxi = x.min(0)[0], x.max(0)[0]
    expx = (factor * x).exp()
    # rescale the exponentiated values to [0, 1], then back to the original range
    expx_01 = (expx - expx.min(0)[0]) / (expx.max(0)[0] - expx.min(0)[0])
    return expx_01 * (maxi - mini) + mini
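
# Worked example for `exp_in_prev_range` (values rounded): with
# x = torch.linspace(0., 1., 5).unsqueeze(-1) and factor = 2., the endpoints
# 0. and 1. map to themselves, while the midpoint 0.5 maps to
# (e^1 - e^0) / (e^2 - e^0) ≈ 0.27: the warp is convex and pushes interior
# points towards the lower end of the original range.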

def get_batch(*args, hyperparameters, get_batch, **kwargs):
    """
    This `get_batch` can be used to wrap another `get_batch` and apply a random input
    warping, by default a Kumaraswamy CDF transform, to the inputs.
    The inputs have to lie in [0, 1] for the warping to be well-defined; set
    `input_warping_norm` to normalize them first if needed. See the usage sketch
    at the end of this file.
    """
    returns = get_batch(*args, hyperparameters=hyperparameters, **kwargs)
    # which warping to apply: 'kumar' (Kumaraswamy CDF/ICDF), 'exp', or 'none'
    input_warping_type = hyperparameters.get('input_warping_type', 'kumar')
    # controls what part of the batch ('x', 'y' or 'xy') to apply the warping to
    input_warping_groups = hyperparameters.get('input_warping_groups', 'x')
    # whether to norm inputs between 0 and 1 before warping, as warping is only possible in that range
    input_warping_norm = hyperparameters.get('input_warping_norm', False)
    # whether to apply the Kumaraswamy inverse CDF instead of the CDF
    use_icdf = hyperparameters.get('input_warping_use_icdf', False)

    def norm_to_0_1(x):
        """Min-max normalize `x` along dim 0; returns the normalized tensor and a function undoing the normalization."""
        eps = .00001
        maxima = torch.max(x, 0)[0]
        minima = torch.min(x, 0)[0]
        normed_x = (x - minima) / (maxima - minima + eps)

        def denorm(normed_x):
            return normed_x * (maxima - minima + eps) + minima

        return normed_x, denorm

    def warp_input(x):
        if input_warping_norm:
            x, denorm = norm_to_0_1(x)
        if input_warping_type == 'kumar':
            # a single `input_warping_c_std` sets the std for both concentrations
            if 'input_warping_c_std' in hyperparameters:
                assert 'input_warping_c0_std' not in hyperparameters and 'input_warping_c1_std' not in hyperparameters
                hyperparameters['input_warping_c0_std'] = hyperparameters['input_warping_c_std']
                hyperparameters['input_warping_c1_std'] = hyperparameters['input_warping_c_std']
            # sample log-normal concentrations per feature; with `input_warping_in_range`
            # set, rejection-sample until all concentrations lie in (0, 10)
            num_tries = 0
            while True:
                c1 = (torch.randn(*x.shape[1:], device=x.device) * hyperparameters.get('input_warping_c1_std', .75)).exp()
                c0 = (torch.randn(*x.shape[1:], device=x.device) * hyperparameters.get('input_warping_c0_std', .75)).exp()
                if not hyperparameters.get('input_warping_in_range', False):
                    break
                if (c1 < 10).all() and (c1 > 0).all() and (c0 < 10).all() and (c0 > 0).all():
                    break
                num_tries += 1
                if num_tries == 100:  # warn once if rejection sampling keeps failing
                    print('It seems that the input warping is not working.')
            if c1_v := hyperparameters.get('fix_input_warping_c1', False):
                c1[:] = c1_v
            if c0_v := hyperparameters.get('fix_input_warping_c0', False):
                c0[:] = c0_v
            if hyperparameters.get('verbose', False):
                print(f'c1: {c1}, c0: {c0}')
            # Kumaraswamy CDF on [0, 1]: F(x) = 1 - (1 - x^c1)^c0
            k = Kumaraswamy(concentration1=c1, concentration0=c0)
            x_transformed = k.icdf(x) if use_icdf else k.cdf(x)
        elif input_warping_type == 'exp':
            # warp a random subset of features, each with a random strength and direction
            transform_likelihood = hyperparameters.get('input_warping_transform_likelihood', 0.2)
            to_be_transformed = torch.rand_like(x[0, 0]) < transform_likelihood
            transform_factors = torch.rand_like(x[0, 0]) * hyperparameters.get('input_warping_transform_factor', 1.)
            log_direction = torch.rand_like(x[0, 0]) < 0.5
            exp_x = exp_in_prev_range(x, transform_factors)
            # mirrored variant: warping 1 - x and flipping back bends the curve the other way
            minus_exp_x = 1. - exp_in_prev_range(1. - x, transform_factors)
            exp_x = torch.where(log_direction, exp_x, minus_exp_x)
            x_transformed = torch.where(to_be_transformed[None, None, :], exp_x, x)
        elif input_warping_type is None or input_warping_type == 'none':
            x_transformed = x
        else:
            raise ValueError(f"Unknown input_warping_type: {input_warping_type}")
        if input_warping_norm:
            x_transformed = denorm(x_transformed)
        return x_transformed

    # note: every call to `warp_input` samples fresh warping parameters
    if 'x' in input_warping_groups:
        returns.x = warp_input(returns.x)
    if 'y' in input_warping_groups:
        returns.y = warp_input(returns.y)
        returns.target_y = warp_input(returns.target_y)
    return returns
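

# Minimal usage sketch, not part of the original module logic. Assumptions:
# the wrapped `get_batch` returns an object with `.x`, `.y` and `.target_y`
# tensors shaped (seq_len, batch_size, num_features) resp. (seq_len, batch_size);
# `base_get_batch` below is a made-up stand-in for a real prior's batch sampler.
if __name__ == '__main__':
    from types import SimpleNamespace

    def base_get_batch(seq_len, batch_size, num_features, hyperparameters, **kwargs):
        x = torch.rand(seq_len, batch_size, num_features)  # inputs already in [0, 1]
        y = x.mean(-1)
        return SimpleNamespace(x=x, y=y, target_y=y.clone())

    hps = {'input_warping_type': 'kumar', 'input_warping_groups': 'x'}
    batch = get_batch(seq_len=100, batch_size=8, num_features=2,
                      hyperparameters=hps, get_batch=base_get_batch)
    # warped inputs stay in [0, 1], since the Kumaraswamy CDF maps [0, 1] onto itself
    print(batch.x.shape, float(batch.x.min()), float(batch.x.max()))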