import time

import torch
from torch import nn
import gpytorch

from .prior import Batch
from .utils import get_batch_to_dataloader
from ..utils import default_device


# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


def get_model(x, y, hyperparameters):
    # Instantiate an exact GP and fix its noise, outputscale and lengthscale to the given
    # hyperparameters; nothing is fitted here.
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1.e-9))
    model = ExactGPModel(x, y, likelihood)
    model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters["noise"]
    model.covar_module.outputscale = torch.ones_like(model.covar_module.outputscale) * hyperparameters["outputscale"]
    model.covar_module.base_kernel.lengthscale = torch.ones_like(model.covar_module.base_kernel.lengthscale) * \
        hyperparameters["lengthscale"]
    return model, likelihood
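

# `get_batch` draws a batch of synthetic regression tasks: inputs x are sampled according to
# `hyperparameters['sampling']` (or placed equidistantly / fixed via `fix_x`), and targets y are
# sampled from the GP prior defined above with the given noise, outputscale and lengthscale.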
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
              equidistant_x=False, fix_x=None, **kwargs):
    if isinstance(hyperparameters, (tuple, list)):
        hyperparameters = {"noise": hyperparameters[0],
                           "outputscale": hyperparameters[1],
                           "lengthscale": hyperparameters[2],
                           "is_binary_classification": hyperparameters[3],
                           # "num_features_used": hyperparameters[4],
                           "normalize_by_used_features": hyperparameters[5],
                           "order_y": hyperparameters[6],
                           "sampling": hyperparameters[7],
                           }
    elif hyperparameters is None:
        hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}

    if 'verbose' in hyperparameters and hyperparameters['verbose']:
        print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale'],
               "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size,
               'sampling': hyperparameters['sampling']})

    observation_noise = hyperparameters.get("observation_noise", True)

    # hyperparameters = {k: hyperparameters[k]() if callable(hyperparameters[k]) else hyperparameters[k]
    #                    for k in hyperparameters.keys()}

    assert not (equidistant_x and (fix_x is not None))

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
        if equidistant_x:
            assert num_features == 1
            x = torch.linspace(0, 1., seq_len, device=device).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1)
        elif fix_x is not None:
            assert fix_x.shape == (seq_len, num_features)
            x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        else:
            if hyperparameters.get('sampling', 'uniform') == 'uniform':
                x = torch.rand(batch_size, seq_len, num_features, device=device)
            elif hyperparameters.get('sampling', 'uniform') == 'normal':
                x = torch.randn(batch_size, seq_len, num_features, device=device)
            elif isinstance(hyperparameters['sampling'], str) and hyperparameters['sampling'].startswith('uniform_'):
                left_border, right_border = [float(v) for v in hyperparameters['sampling'][len('uniform_'):].split('_')]
                x = torch.rand(batch_size, seq_len, num_features, device=device) * (right_border - left_border) + left_border
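            # 'clustered_<dist_std>_<local_dist_std>_<base_likelihood>': the first point of each sequence is
            # drawn from N(0, dist_std^2); every further point is placed near the origin with probability
            # `base_likelihood` and otherwise near a randomly chosen earlier point, producing clustered inputs.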
            elif isinstance(hyperparameters['sampling'], str) and hyperparameters['sampling'].startswith('clustered_'):
                dist_std, local_dist_std, base_likelihood = [float(v) for v in hyperparameters['sampling'][len('clustered_'):].split('_')]

                def first_sample(dist):
                    return dist().unsqueeze(0)

                def append_sample(samples, dist, local_dist, base_likelihood):
                    if samples is None:
                        return first_sample(dist)
                    num_samples, batch_size, num_features = samples.shape
                    use_base = torch.rand(batch_size) < base_likelihood
                    sample_mean = torch.where(use_base[:, None].repeat(1, num_features),
                                              torch.zeros(batch_size, num_features),
                                              samples[torch.randint(num_samples, (batch_size,)),
                                                      torch.arange(batch_size), :])
                    return torch.cat((samples, (local_dist() + sample_mean).unsqueeze(0)), 0)

                def create_sample(num_samples, dist, local_dist, base_likelihood):
                    samples = None
                    for i in range(num_samples):
                        samples = append_sample(samples, dist, local_dist, base_likelihood)
                    return samples[torch.randperm(num_samples)]

                x = create_sample(seq_len, lambda: torch.randn(batch_size, num_features) * dist_std,
                                  lambda: torch.rand(batch_size, num_features) * local_dist_std, base_likelihood)\
                    .transpose(0, 1).to(device)
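            # 'gmix_<blob_width>_<n_centers_max>_<stddev>': inputs come from a Gaussian mixture; a random number
            # of cluster centers (at least 1, fewer than `n_centers_max`) is drawn uniformly in
            # [-blob_width/2, blob_width/2], and each point is an assigned center plus N(0, stddev^2) noise.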
            elif isinstance(hyperparameters['sampling'], str) and hyperparameters['sampling'].startswith('gmix_'):
                blob_width, n_centers_max, stddev = [float(v) for v in hyperparameters['sampling'][len('gmix_'):].split('_')]
                n_centers_max = int(n_centers_max)

                def get_x(batch_size, n_samples, num_features, blob_width, n_centers_max, stddev, device):
                    n_centers = torch.randint(1, n_centers_max, tuple(), device=device)
                    centers = torch.rand((batch_size, n_centers, num_features), device=device) * blob_width - blob_width / 2
                    center_assignments = torch.randint(n_centers, (batch_size, n_samples,), device=device)
                    noise = torch.randn((batch_size, n_samples, num_features), device=device) * stddev
                    # centers: (b, m, f), center_assignments: (b, n)
                    return centers.gather(1, center_assignments[..., None].repeat(1, 1, num_features)) + noise

                x = get_x(batch_size, seq_len, num_features, blob_width, n_centers_max, stddev, device)
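            # 'grep_<stddev>': half a sequence of Gaussian inputs is drawn and then repeated, so every x value
            # appears twice within a sequence, in shuffled order.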
            elif isinstance(hyperparameters['sampling'], str) and hyperparameters['sampling'].startswith('grep_'):
                stddev, = [float(v) for v in hyperparameters['sampling'][len('grep_'):].split('_')]
                x = torch.randn(batch_size, seq_len // 2, num_features, device=device) * stddev
                x = x.repeat(1, 2, 1)
                x = x[:, torch.randperm(x.shape[1]), :]
            else:
                x = torch.randn(batch_size, seq_len, num_features, device=device) * hyperparameters.get('sampling', 1.)
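
        # x now has shape (batch_size, seq_len, num_features); put a GP prior over it and sample targets.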
        model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
        model.to(device)
        # trained_model = ExactGPModel(train_x, train_y, likelihood).cuda()
        # trained_model.eval()
        successful_sample = False
        while not successful_sample:
            try:
                with gpytorch.settings.prior_mode(True):
                    model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
                    model.to(device)
                    d = model(x)
                    if observation_noise:
                        # both the model input and the loss target are noisy GP samples
                        target_sample = sample = likelihood(d).sample().transpose(0, 1)
                    else:
                        target_sample = d.sample().transpose(0, 1)  # this will be the target for the loss
                        sample = likelihood(target_sample).sample()  # this will be the input to the Transformer
                    successful_sample = True
            except RuntimeError:  # This can happen when torch.linalg.eigh fails. Restarting with a new init resolves it.
                print('GP sampling unsuccessful, retrying...')
                print(x)
                print(hyperparameters)
                if bool(torch.any(torch.isnan(x)).detach().cpu().numpy()):
                    print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale'],
                           "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size})

    if hyperparameters.get('improvement_classification', False):
        # turn the tail of the sequence into binary targets: did y exceed the best value seen in the
        # first `single_eval_pos` points?
        single_eval_pos = kwargs['single_eval_pos']
        max_so_far = sample[:single_eval_pos].max(0).values
        sample[single_eval_pos:] = (sample > max_so_far).float()[single_eval_pos:]

    return Batch(x=x.transpose(0, 1), y=sample, target_y=target_sample)
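
# Illustrative usage (values are just examples, not taken from this file):
#   batch = get_batch(batch_size=8, seq_len=100, num_features=1,
#                     hyperparameters={'noise': 1e-4, 'outputscale': 1., 'lengthscale': .6})
# `batch.x` has shape (seq_len, batch_size, num_features); `batch.y` and `batch.target_y`
# have shape (seq_len, batch_size).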


DataLoader = get_batch_to_dataloader(get_batch)


def get_model_on_device(x, y, hyperparameters, device):
    model, likelihood = get_model(x, y, hyperparameters)
    model.to(device)
    return model, likelihood
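

# `evaluate` fits nothing: for each prefix length t it conditions an exact GP (with the given, fixed
# hyperparameters) on the first t points of every sequence and scores its prediction for point t,
# either by negative log-likelihood or, with use_mse=True, by squared error.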
def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters={}, get_model_on_device=get_model_on_device,
             device=default_device, step_size=1, start_pos=0):
    start_time = time.time()
    losses_after_t = [.0] if start_pos == 0 else []
    all_losses_after_t = []

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))), \
            gpytorch.settings.fast_pred_var(False):
        for t in range(max(start_pos, 1), len(x), step_size):
            loss_sum = 0.
            model, likelihood = get_model_on_device(x[:t].transpose(0, 1), y[:t].transpose(0, 1), hyperparameters, device)
            model.eval()
            # print([t.shape for t in model.train_inputs])
            # print(x[:t].transpose(0, 1).shape, x[t].unsqueeze(1).shape, y[:t].transpose(0, 1).shape)
            f = model(x[t].unsqueeze(1))
            l = likelihood(f)
            means = l.mean.squeeze()
            varis = l.covariance_matrix.squeeze()
            # print(l.variance.squeeze(), l.mean.squeeze(), y[t])

            assert len(means.shape) == len(varis.shape) == 1
            assert len(means) == len(varis) == x.shape[1]

            if use_mse:
                c = nn.MSELoss(reduction='none')
                ls = c(means, y[t])
            else:
                ls = -l.log_prob(y[t].unsqueeze(1))

            losses_after_t.append(ls.mean())
            all_losses_after_t.append(ls.flatten())
    return torch.stack(all_losses_after_t).to('cpu'), torch.tensor(losses_after_t).to('cpu'), time.time() - start_time


if __name__ == '__main__':
    hps = {'noise': .1, 'outputscale': .1, 'lengthscale': .1}
    for redo_idx in range(1):
        print(
            evaluate(*get_batch(1000, 10, hyperparameters=hps, num_features=10), use_mse=False, hyperparameters=hps))