# Rosen_PFN4BO.py (PFNEngineeringConstrainedBO)
import math
import warnings

import numpy as np
import torch
from sklearn.preprocessing import PowerTransformer, StandardScaler

from pfns4bo import bar_distribution
from pfns4bo.scripts.acquisition_functions import TransformerBOMethod

warnings.filterwarnings('ignore')

device = torch.device("cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
@torch.enable_grad()
def Rosen_PFN(model_name,
              trained_X,
              trained_Y,
              X_pen,
              trasform_type,
              what_do_you_want
              ):
    """Score candidate points X_pen with a pretrained PFN surrogate.

    trasform_type: 'std' (StandardScaler) or 'power' (Yeo-Johnson) scaling of the targets.
    what_do_you_want: 'mean', 'ei', 'ucb', 'variance', 'mode', or 'ts'.
    """
    PFN = TransformerBOMethod(torch.load(model_name).requires_grad_(False), device=device)
    dim = trained_X.shape[1]
    x_given = trained_X
    x_eval = X_pen
    x_predict = torch.cat([x_given, x_eval], dim=0)
    # The context points are fed once as training data and again as query positions,
    # followed by the candidates; single_eval_pos below splits context from queries.
    x_full_feed = torch.cat([x_given, x_given, x_eval], dim=0).unsqueeze(1)
    if trasform_type == 'std':
        pt = StandardScaler()
        pt.fit(trained_Y)
        PT_trained_Y = pt.transform(trained_Y)
        trained_Y = torch.as_tensor(PT_trained_Y, dtype=torch.float32).reshape(trained_Y.shape)
    elif trasform_type == 'power':
        pt = PowerTransformer(method="yeo-johnson")
        pt.fit(trained_Y.detach().numpy())
        # The targets are transformed with the safer helper below; the directly
        # fitted pt is kept for mapping predictions back to the original scale.
        trained_Y, _ = general_power_transform(trained_Y,
                                               trained_Y,
                                               .0,
                                               less_safe=False)
y_given = trained_Y
y_given = y_given.reshape(-1)
y_full_feed = y_given.unsqueeze(1)
criterion: bar_distribution.BarDistribution = PFN.model.criterion
style = None
logits = PFN.model(
(style,
x_full_feed.repeat_interleave(dim=1, repeats=y_full_feed.shape[1]),
y_full_feed.repeat(1,x_full_feed.shape[1])),
single_eval_pos=len(x_given)
)
    # Out-of-place log() (rather than log_()) keeps autograd intact under @torch.enable_grad().
    logits = logits.softmax(-1).log()
logits_given = logits[:len(x_given)]
logits_eval = logits[len(x_given):]
best_f = torch.max(y_given)
    if what_do_you_want == 'mean':
        output = criterion.mean(logits_eval)
        if trasform_type == 'std':
            # Undo the StandardScaler directly: x * scale + mean.
            return torch_std_inverse_transform(output.clone(),
                                               torch.from_numpy(pt.scale_),
                                               torch.from_numpy(pt.mean_))
        elif trasform_type == 'power':
            XX = output.clone()
            if pt.standardize:
                # PowerTransformer standardizes after the power map, so undo that first.
                scale = torch.from_numpy(pt._scaler.scale_)
                std_mean = torch.from_numpy(pt._scaler.mean_)
                XX = torch_std_inverse_transform(XX, scale, std_mean)
            # Then undo Yeo-Johnson (one lambda per column; a single objective column is assumed).
            for lmbda in pt.lambdas_:
                with np.errstate(invalid="ignore"):  # hide NaN warnings
                    XX = torch_power_inverse_transform(XX, lmbda)
            return XX
elif what_do_you_want == 'ei':
output = criterion.ei(logits_eval, best_f)
elif what_do_you_want == 'ucb':
acq_function = criterion.ucb
ucb_rest_prob = .05
if ucb_rest_prob is not None:
acq_function = lambda *args: criterion.ucb(*args, rest_prob=ucb_rest_prob)
output = acq_ensembling(acq_function(logits_eval, best_f))
elif what_do_you_want == 'variance':
output = criterion.variance(logits_eval)
elif what_do_you_want == 'mode':
output = criterion.mode(logits_eval)
    elif what_do_you_want == 'ts':
        # Posterior mean mapped back to the original scale, plus predictive variance.
        XX = criterion.mean(logits_eval).clone()
        if trasform_type == 'std':
            XX = torch_std_inverse_transform(XX,
                                             torch.from_numpy(pt.scale_),
                                             torch.from_numpy(pt.mean_))
        elif trasform_type == 'power':
            if pt.standardize:
                XX = torch_std_inverse_transform(XX,
                                                 torch.from_numpy(pt._scaler.scale_),
                                                 torch.from_numpy(pt._scaler.mean_))
            for lmbda in pt.lambdas_:
                with np.errstate(invalid="ignore"):  # hide NaN warnings
                    XX = torch_power_inverse_transform(XX, lmbda)
        var = criterion.variance(logits_eval)
        return XX, var
return output
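

# Illustrative usage sketch (assumption: not part of the original file, and
# "model.pt" is a hypothetical checkpoint path; substitute a real PFNs4BO model
# file). It shows one plausible way to call Rosen_PFN to pick the next point by EI.
def _example_rosen_pfn():
    trained_X = torch.rand(10, 3, dtype=dtype)   # observed inputs, assumed in [0, 1]^d
    trained_Y = torch.rand(10, 1, dtype=dtype)   # observed objective values
    X_pen = torch.rand(200, 3, dtype=dtype)      # candidate points to score
    ei = Rosen_PFN("model.pt",                   # hypothetical checkpoint path
                   trained_X, trained_Y, X_pen,
                   trasform_type='power',
                   what_do_you_want='ei')
    return X_pen[ei.argmax()]                    # candidate with the highest EI
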
def Rosen_PFN_Parallel(model_name,
                       trained_X,
                       trained_Y,
                       GX,
                       X_pen,
                       trasform_type,
                       what_do_you_want
                       ):
    """Constrained variant: the PFN jointly scores the objective and each
    constraint column of GX (convention: g(x) <= 0 is feasible)."""
    PFN = TransformerBOMethod(torch.load(model_name), device=device)
with torch.no_grad():
dim = trained_X.shape[1]
x_given = trained_X
x_eval = X_pen
x_predict = torch.cat([x_given, x_eval], dim=0)
x_full_feed = torch.cat([x_given, x_given, x_eval], dim=0).unsqueeze(1)
y_given = trained_Y
y_given = y_given.reshape(-1)
######################################################################
# Objective Power Transform
y_given, pt_y = general_power_transform(y_given.unsqueeze(1),
y_given.unsqueeze(1),
.0,
less_safe=False)
y_given = y_given.squeeze(1)
######################################################################
    ######################################################################
    # Constraint Power Transform
    # Constraints are negated so that larger means "more feasible", and the
    # feasibility threshold (zero) is mapped through the same transform.
    GX = -GX
    GX_t, pt_GX = general_power_transform(GX, GX, .0, less_safe=False)
    G_thres, _ = general_power_transform(GX,
                                         torch.zeros((1, GX.shape[1])).to(GX.device),
                                         .0,
                                         less_safe=False)
    GX = GX_t
    ######################################################################
y_full_feed = y_given.unsqueeze(1)
criterion: bar_distribution.BarDistribution = PFN.model.criterion
style = None
logits = PFN.model(
(style,
x_full_feed.repeat_interleave(dim=1, repeats=y_full_feed.shape[1]+GX.shape[1]),
torch.cat([y_full_feed, GX], dim=1).unsqueeze(2) ),
single_eval_pos=len(x_given)
)
logits = logits.softmax(-1).log_()
logits_given = logits[:len(x_given)]
logits_eval = logits[len(x_given):]
best_f = torch.max(y_given)
objective_given = logits_given[:,0,:].unsqueeze(1)
objective_eval = logits_eval[:,0,:].unsqueeze(1)
constraint_given = logits_given[:,1:,:]
constraint_eval = logits_eval[:,1:,:]
    if what_do_you_want == 'mean':
        obj_output = criterion.mean(objective_eval)
        con_output = criterion.mean(constraint_eval)
        # Return both pieces; falling through to the final `return output`
        # would raise a NameError since `output` is never set on this path.
        return obj_output, con_output
    elif what_do_you_want == 'ei':
        # Constrained-EI ingredients: EI on the objective and, per constraint,
        # the probability of clearing the transformed feasibility threshold.
        tau = torch.max(y_given)
        objective_acq_value = acq_ensembling(criterion.ei(objective_eval, tau))
        constraints_acq_value = torch.stack(
            [acq_ensembling(criterion.pi(constraint_eval[:, jj, :].unsqueeze(1),
                                         G_thres[0, jj].item()))
             for jj in range(constraint_eval.shape[1])],
            dim=1)
        return objective_acq_value, constraints_acq_value
elif what_do_you_want == 'variance':
output = criterion.variance(logits_eval)
elif what_do_you_want == 'mode':
output = criterion.mode(logits_eval)
    elif what_do_you_want == 'cts':
        # Posterior means/variances for objective and constraints, mapped back to
        # the original scales (including undoing the sign flip on GX).
        obj_mnn = criterion.mean(objective_eval)
        obj_mnn = torch.from_numpy(pt_y.inverse_transform(obj_mnn))
        con_mnn = criterion.mean(constraint_eval)
        con_mnn = torch.from_numpy(-pt_GX.inverse_transform(con_mnn))
        obj_varr = criterion.variance(objective_eval)
        con_varr = criterion.variance(constraint_eval)
        return obj_mnn, obj_varr, con_mnn, con_varr
return output
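

# Illustrative constrained-EI sketch (assumption: not part of the original file;
# "model.pt" is a hypothetical checkpoint path). A common way to combine the two
# returned pieces is EI(objective) times the product of per-constraint
# satisfaction probabilities.
def _example_rosen_pfn_parallel():
    trained_X = torch.rand(10, 3, dtype=dtype)
    trained_Y = torch.rand(10, 1, dtype=dtype)
    GX = torch.rand(10, 2, dtype=dtype) - 0.5    # two constraints; g(x) <= 0 is feasible
    X_pen = torch.rand(200, 3, dtype=dtype)
    obj_ei, con_pi = Rosen_PFN_Parallel("model.pt", trained_X, trained_Y,
                                        GX, X_pen, 'power', 'ei')
    cei = obj_ei * con_pi.prod(dim=1)            # constrained expected improvement
    return X_pen[cei.argmax()]
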
def acq_ensembling(acq_values): # (points, ensemble dim)
return acq_values.max(1).values
def torch_std_inverse_transform(X, scale, mean):
    # Undoes standardization in place (X * scale + mean) and returns X.
    X *= scale
    X += mean
    return X
def torch_power_inverse_transform(x, lmbda):
    # Torch port of the Yeo-Johnson inverse (cf. sklearn's PowerTransformer).
    out = torch.zeros_like(x)
    pos = x >= 0
# when x >= 0
if abs(lmbda) < np.spacing(1.0):
out[pos] = torch.exp(x[pos])-1
else: # lmbda != 0
out[pos] = torch.pow(x[pos] * lmbda + 1, 1 / lmbda) - 1
# when x < 0
if abs(lmbda - 2) > np.spacing(1.0):
out[~pos] = 1 - torch.pow(-(2 - lmbda) * x[~pos] + 1, 1 / (2 - lmbda))
else: # lmbda == 2
out[~pos] = 1 - torch.exp(-x[~pos])
return out
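

# Sanity-check sketch (assumption: not part of the original file): with
# standardize=False, sklearn's forward Yeo-Johnson followed by the torch
# inverse above should round-trip to the original data.
def _check_yeo_johnson_round_trip():
    y = np.random.RandomState(0).randn(50, 1)
    pt = PowerTransformer(method="yeo-johnson", standardize=False).fit(y)
    y_t = torch.from_numpy(pt.transform(y))
    y_back = torch_power_inverse_transform(y_t, pt.lambdas_[0])
    assert torch.allclose(y_back, torch.from_numpy(y), atol=1e-6)
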
################################################################################
## PFN defined functions
################################################################################
def log01(x, eps=.0000001, input_between_zero_and_one=False):
logx = torch.log(x + eps)
if input_between_zero_and_one:
return (logx - math.log(eps)) / (math.log(1 + eps) - math.log(eps))
return (logx - logx.min(0)[0]) / (logx.max(0)[0] - logx.min(0)[0])
def log01_batch(x, eps=.0000001, input_between_zero_and_one=False):
x = x.repeat(1, x.shape[-1] + 1, 1)
for b in range(x.shape[-1]):
x[:, b, b] = log01(x[:, b, b], eps=eps, input_between_zero_and_one=input_between_zero_and_one)
return x
def lognormed_batch(x, eval_pos, eps=.0000001):
x = x.repeat(1, x.shape[-1] + 1, 1)
for b in range(x.shape[-1]):
logx = torch.log(x[:, b, b]+eps)
x[:, b, b] = (logx - logx[:eval_pos].mean(0))/logx[:eval_pos].std(0)
return x
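

# Shape note (illustrative, not from the original file): for x of shape
# (N, 1, D), log01_batch and lognormed_batch return (N, D+1, D); copy b has
# feature b rescaled and the last copy is left untouched, e.g.
# log01_batch(torch.rand(8, 1, 3)).shape == (8, 4, 3).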
def _rank_transform(x_train, x):
assert len(x_train.shape) == len(x.shape) == 1
relative_to = torch.cat((torch.zeros_like(x_train[:1]),x_train.unique(sorted=True,), torch.ones_like(x_train[-1:])),-1)
higher_comparison = (relative_to < x[...,None]).sum(-1).clamp(min=1)
pos_inside_interval = (x - relative_to[higher_comparison-1])/(relative_to[higher_comparison] - relative_to[higher_comparison-1])
x_transformed = higher_comparison - 1 + pos_inside_interval
return x_transformed/(len(relative_to)-1.)
def rank_transform(x_train, x):
assert x.shape[1] == x_train.shape[1], f"{x.shape=} and {x_train.shape=}"
# make sure everything is between 0 and 1
assert (x_train >= 0.).all() and (x_train <= 1.).all(), f"{x_train=}"
assert (x >= 0.).all() and (x <= 1.).all(), f"{x=}"
return_x = x.clone()
for feature_dim in range(x.shape[1]):
return_x[:, feature_dim] = _rank_transform(x_train[:, feature_dim], x[:, feature_dim])
return return_x
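

# Worked example (assumption: not part of the original file): rank_transform
# maps each feature onto an empirical-CDF-like scale built from the training
# points, interpolating linearly between sorted unique training values.
def _example_rank_transform():
    x_train = torch.tensor([[0.1], [0.5], [0.9]])
    x = torch.tensor([[0.5], [0.7]])
    return rank_transform(x_train, x)  # tensor([[0.5000], [0.6250]])
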
def general_power_transform(x_train, x_apply, eps, less_safe=False):
    """Fit a PowerTransformer on x_train and apply it to x_apply.

    eps > 0 selects box-cox on (x + eps); eps == 0 selects yeo-johnson, which
    also handles non-positive inputs. Returns (transformed, fitted transformer).
    """
    if eps > 0:
        try:
            pt = PowerTransformer(method='box-cox')
            pt.fit(x_train.cpu() + eps)
            x_out = torch.tensor(pt.transform(x_apply.cpu() + eps), dtype=x_apply.dtype, device=x_apply.device)
        except Exception as e:
            print('WARNING: box-cox transform failed, falling back to mean-centering:', e)
            x_out = x_apply - x_train.mean(0)
    else:
        pt = PowerTransformer(method='yeo-johnson')
        if not less_safe and (x_train.std() > 1_000 or x_train.mean().abs() > 1_000):
            # Normalize very large targets before fitting to keep the lambda search stable.
            x_apply = (x_apply - x_train.mean(0)) / x_train.std(0)
            x_train = (x_train - x_train.mean(0)) / x_train.std(0)
        try:
            pt.fit(x_train.cpu().double())
        except Exception:
            # Fitting can fail on degenerate data; standardize (or just center) and
            # retry. The statistics are cached before x_train is overwritten so that
            # x_apply is shifted by the *original* training statistics.
            mean = x_train.mean(0)
            if less_safe:
                std = x_train.std(0)
                x_train = (x_train - mean) / std
                x_apply = (x_apply - mean) / std
            else:
                x_train = x_train - mean
                x_apply = x_apply - mean
            pt.fit(x_train.cpu().double())
        x_out = torch.tensor(pt.transform(x_apply.cpu()), dtype=x_apply.dtype, device=x_apply.device)
if torch.isnan(x_out).any() or torch.isinf(x_out).any():
print('WARNING: power transform failed')
print(f"{x_train=} and {x_apply=}")
x_out = x_apply - x_train.mean(0)
return x_out, pt
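

# Illustrative round trip (assumption: not part of the original file): the
# returned transformer can map PFN outputs back to the original scale, as the
# 'cts' branch of Rosen_PFN_Parallel does via pt.inverse_transform.
def _example_general_power_transform():
    y = torch.randn(20, 1).double() * 100.0
    y_t, pt = general_power_transform(y, y, .0, less_safe=False)
    y_back = torch.from_numpy(pt.inverse_transform(y_t.numpy()))
    assert torch.allclose(y_back, y, atol=1e-4)
    return y_t, pt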