# Hosting-page residue kept for provenance: "ccolas's picture", "Upload 174 files", commit 93c029f.
import torch; torch.manual_seed(0)
import torch.utils
from torch.utils.data import DataLoader
import torch.distributions
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
import json
import pandas as pd
import numpy as np
import os
from src.cocktails.representation_learning.vae_model import get_vae_model
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
from resource import getrusage
from resource import RUSAGE_SELF
import gc
# Run a full garbage collection (generation 2) once, when the module is loaded.
gc.collect(2)
# Name of the torch device to use: the GPU when torch reports one, the CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_params():
    """Build the experiment configuration, create its output folder and return it.

    The function reads the cocktail CSV, assembles the hyper-parameter
    dictionary, derives the ingredient representation size from a dummy
    'water' ingredient, asks ``compute_expe_name_and_save_path`` to create the
    experiment directory, writes ``params.json`` there and finally lets
    ``complete_params`` add the remaining entries.

    Returns:
        dict: the completed parameter dictionary.
    """
    data = pd.read_csv(COCKTAILS_CSV_DATA)
    max_ingredients, ingredient_set, _, _ = get_max_n_ingredients(data)
    num_ingredients = len(ingredient_set)

    # Ingredient feature names: second word of each 'custom' representation key, minus 'volume'.
    ing_keys = [key.split(' ')[1] for key in get_bunch_of_rep_keys()['custom']]
    ing_keys.remove('volume')

    # One one-hot row per ingredient type, types taken in sorted order.
    ingredient_types = sorted(set(ingredient_profiles['type']))
    category_encodings = dict(zip(ingredient_types, np.eye(len(ingredient_types))))

    # Auxiliary heads (insertion order matters: it is reflected in the experiment name).
    auxiliaries_dict = {
        'categories': dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
        'glasses': dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
        'prep_type': dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['category']))),
        'cocktail_reps': dict(weight=0, type='regression', final_activ=None, dim_output=13),
        'volume': dict(weight=0, type='regression', final_activ='relu', dim_output=1),
        'taste_reps': dict(weight=0, type='regression', final_activ='relu', dim_output=2),
        'ingredients_presence': dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients),
    }

    params = {
        'trial_id': 'test',
        'save_path': EXPERIMENT_PATH + "/deepset_vae/",
        'nb_epochs': 2000,
        'print_every': 50,
        'plot_every': 100,
        'batch_size': 64,
        'lr': 0.001,
        'dropout': 0.,
        'nb_epoch_switch_beta': 600,
        'latent_dim': 10,
        'beta_vae': 0.2,
        'ing_keys': ing_keys,
        'nb_ingredients': num_ingredients,
        'hidden_dims_ingredients': [128],
        'hidden_dims_cocktail': [32],
        'hidden_dims_decoder': [32],
        'agg': 'mean',
        'activation': 'relu',
        'auxiliaries_dict': auxiliaries_dict,
        'category_encodings': category_encodings,
    }

    # Alternative configuration kept for reference:
    # params = dict(trial_id='test',
    #               save_path=EXPERIMENT_PATH + "/deepset_vae/",
    #               nb_epochs=1000,
    #               print_every=50,
    #               plot_every=100,
    #               batch_size=64,
    #               lr=0.001,
    #               dropout=0.,
    #               nb_epoch_switch_beta=500,
    #               latent_dim=64,
    #               beta_vae=0.3,
    #               ing_keys=ing_keys,
    #               nb_ingredients=len(ingredient_set),
    #               hidden_dims_ingredients=[128],
    #               hidden_dims_cocktail=[128, 128],
    #               hidden_dims_decoder=[128, 128],
    #               agg='mean',
    #               activation='mish',
    #               auxiliaries_dict=dict(categories=dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
    #                                     glasses=dict(weight=0.03, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
    #                                     prep_type=dict(weight=0.02, type='classif', final_activ=None, dim_output=len(set(data['category']))),
    #                                     cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13),
    #                                     volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1),
    #                                     taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2),
    #                                     ingredients_presence=dict(weight=1.5, type='multiclassif', final_activ=None, dim_output=num_ingredients)),
    #               category_encodings=category_encodings
    #               )

    # Encode a unit quantity of water to learn how large one ingredient representation is.
    unit_max_quantities = {ingredient: 1 for ingredient in ingredient_set}
    water_rep, indexes_to_normalize = get_representation_from_ingredient(
        ingredients=['water'],
        quantities=[1],
        max_q_per_ing=unit_max_quantities,
        index=0,
        params=params,
    )
    dim_rep_ingredient = water_rep.size

    params['indexes_ing_to_normalize'] = indexes_to_normalize
    params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
    params['input_dim'] = dim_rep_ingredient
    params['dim_rep_ingredient'] = dim_rep_ingredient

    params = compute_expe_name_and_save_path(params)

    # The encodings are dropped before dumping; complete_params puts them back afterwards.
    params.pop('category_encodings')
    with open(params['save_path'] + 'params.json', 'w') as f:
        json.dump(params, f)

    return complete_params(params)
def complete_params(params):
    """Add the raw data, the cocktail representations and the category encodings to ``params``.

    Args:
        params: parameter dictionary; it is updated in place.

    Returns:
        dict: the same ``params`` object with the keys ``'cocktail_reps'``,
        ``'raw_data'`` and ``'category_encodings'`` set.
    """
    raw_data = pd.read_csv(COCKTAILS_CSV_DATA)
    cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)

    # One one-hot row per ingredient type, types taken in sorted order.
    ingredient_types = sorted(set(ingredient_profiles['type']))
    one_hot_rows = np.eye(len(ingredient_types))

    params.update(
        cocktail_reps=cocktail_reps,
        raw_data=raw_data,
        category_encodings=dict(zip(ingredient_types, one_hot_rows)),
    )
    return params
def _classification_metrics(logits, ground_truth):
    """Return ``(accuracy, row-normalised confusion matrix)`` for one classification head."""
    _, n_options = logits.shape
    predicted = logits.argmax(dim=1).detach().cpu().numpy()
    true = ground_truth.int().detach().cpu().numpy()

    # confusion[t, p] counts the samples of true class t predicted as class p.
    confusion = np.zeros([n_options, n_options])
    np.add.at(confusion, (true, predicted), 1)

    # Turn each non-empty row into frequencies; rows of unseen classes stay at zero.
    row_totals = confusion.sum(axis=1, keepdims=True)
    np.divide(confusion, row_totals, out=confusion, where=row_totals != 0)

    return np.mean(predicted == true), confusion


def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data):
    """Compute loss, accuracy and extra metrics for every auxiliary prediction head.

    Args:
        loss_functions: mapping from auxiliary name to its loss module; supported
            modules are ``nn.CrossEntropyLoss``, ``nn.BCEWithLogitsLoss`` and
            ``nn.MSELoss`` (or subclasses of them).
        auxiliaries: mapping from auxiliary name to its ground-truth tensor.
        auxiliaries_str: auxiliary names, in the same order as ``outputs``.
        outputs: predicted tensors, ``outputs[i]`` belonging to ``auxiliaries_str[i]``.
            The sequence is not modified.
        data: object whose ``dataset`` exposes ``std_ing_quantities`` and
            ``mean_ing_quantities``; only read for the BCE head.

    Returns:
        tuple: ``(losses, accuracies, other_metrics)``, three dicts keyed by
        auxiliary name (``other_metrics`` uses ``<name>_confusion``,
        ``<name>_false_positive`` and ``<name>_false_negative``). Regression
        heads get ``nan`` as accuracy.

    Raises:
        ValueError: if a loss module is of an unsupported kind, or if a BCE loss
            is attached to an auxiliary other than ``'ingredients_presence'``.
    """
    losses = dict()
    accuracies = dict()
    other_metrics = dict()
    for i_k, k in enumerate(auxiliaries_str):
        loss_fn = loss_functions[k]
        output = outputs[i_k]
        if k == 'volume':
            # Flatten into a local so the caller's list of outputs is left untouched.
            output = output.flatten()
        ground_truth = auxiliaries[k]
        is_bce = isinstance(loss_fn, nn.BCEWithLogitsLoss)

        # compute loss, casting to the dtypes each loss module accepts
        if ground_truth.dtype == torch.float64:
            loss = loss_fn(output, ground_truth.float())
        elif ground_truth.dtype == torch.int64:
            # BCE needs float targets, the other losses get integer (class index) targets.
            target = ground_truth.float() if is_bce else ground_truth.long()
            loss = loss_fn(output.float(), target)
        else:
            loss = loss_fn(output, ground_truth)
        losses[k] = loss.float()

        # compute accuracies
        if isinstance(loss_fn, nn.CrossEntropyLoss):
            accuracies[k], other_metrics[k + '_confusion'] = _classification_metrics(output, ground_truth)
        elif is_bce:
            if k != 'ingredients_presence':
                raise ValueError(f"BCE loss is only supported for 'ingredients_presence', got '{k}'")
            # NOTE(review): the head output is un-normalised with the ingredient-quantity
            # statistics before thresholding at 0 -- confirm this is the intended decision rule.
            outputs_rescaled = (output.detach().cpu().numpy() * data.dataset.std_ing_quantities
                                + data.dataset.mean_ing_quantities)
            predicted_presence = (outputs_rescaled > 0).astype(bool)
            presence = ground_truth.detach().cpu().numpy().astype(bool)
            other_metrics[k + '_false_positive'] = np.mean(np.logical_and(predicted_presence, ~presence))
            other_metrics[k + '_false_negative'] = np.mean(np.logical_and(~predicted_presence, presence))
            accuracies[k] = np.mean(predicted_presence == presence)  # accuracy for multi class labeling
        elif isinstance(loss_fn, nn.MSELoss):
            accuracies[k] = np.nan
        else:
            raise ValueError(f"Unsupported loss function for auxiliary '{k}': {loss_fn}")
    return losses, accuracies, other_metrics
def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat):
    """Add the mean absolute quantity errors, split by ingredient presence, to the metrics.

    Both the true and the reconstructed quantities are un-normalised with
    ``data.dataset.std_ing_quantities`` / ``mean_ing_quantities``. An ingredient
    counts as present when its un-normalised true quantity is strictly positive.
    For every sample the absolute error is averaged over present and over absent
    ingredients, and those per-sample averages are averaged over the batch.

    Args:
        aux_other_metrics: metrics dict, updated in place.
        data: object whose ``dataset`` carries the normalisation statistics.
        ingredient_quantities: tensor of normalised true quantities.
        x_hat: tensor of normalised reconstructed quantities.

    Returns:
        dict: ``aux_other_metrics`` with ``'ing_q_abs_loss_when_present'`` and
        ``'ing_q_abs_loss_when_absent'`` set.
    """
    true_q = ingredient_quantities.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
    present = true_q > 0
    predicted_q = x_hat.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
    errors = np.abs(true_q - predicted_q)

    n_samples = ingredient_quantities.shape[0]
    error_when_present = [np.mean(errors[row][present[row]]) for row in range(n_samples)]
    error_when_absent = [np.mean(errors[row][~present[row]]) for row in range(n_samples)]

    aux_other_metrics['ing_q_abs_loss_when_present'] = np.mean(error_when_present)
    aux_other_metrics['ing_q_abs_loss_when_absent'] = np.mean(error_when_absent)
    return aux_other_metrics
def run_epoch(opt, train, model, data, loss_functions, weights, params):
    """Run one pass over ``data``, training the model or only evaluating it.

    Args:
        opt: optimizer; only stepped when ``train`` is true.
        train: when true the model is put in train mode and updated after every
            batch; otherwise it is put in eval mode and gradients are disabled.
        model: model exposing ``forward_direct`` and ``get_auxiliary``.
        data: iterable of batches; each batch ``d`` is indexed as ``d[0]``
            (only its first dimension, the batch size, is used), ``d[2]``
            (ingredient quantities), ``d[4]`` (dict of auxiliary targets) and
            ``d[-1]`` (mask of samples with a valid taste target).
        loss_functions: mapping from auxiliary name to loss module.
        weights: mapping from auxiliary name to its weight in the global loss.
        params: experiment parameters; reads ``'auxiliaries_dict'``,
            ``'beta_vae'``, ``'latent_dim'`` and ``'nb_ingredients'``.

    Returns:
        tuple: ``(model, losses, accuracies, other_metrics)`` where the three
        dicts hold the per-batch values averaged over the pass.
    """
    if train:
        model.train()
    else:
        model.eval()
    aux_keys = list(params['auxiliaries_dict'].keys())

    # prepare logging of losses
    losses = dict(kld_loss=[],
                  mse_loss=[],
                  vae_loss=[],
                  volume_loss=[],
                  global_loss=[])
    accuracies = dict()
    other_metrics = dict()
    for aux in aux_keys:
        losses[aux] = []
        accuracies[aux] = []

    if train:
        opt.zero_grad()

    # Autograd bookkeeping is only needed when we are going to call backward().
    with torch.set_grad_enabled(bool(train)):
        for d in data:
            batch_size = d[0].shape[0]
            ingredient_quantities = d[2]
            auxiliaries = d[4]
            for k in auxiliaries.keys():
                if auxiliaries[k].dtype == torch.float64:
                    auxiliaries[k] = auxiliaries[k].float()
            taste_valid = d[-1]

            x_hat, z, mean, log_var, outputs, auxiliaries_str = model.forward_direct(ingredient_quantities.float())

            # get auxiliary losses and accuracies
            aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(
                loss_functions, auxiliaries, auxiliaries_str, outputs, data)

            # compute vae loss; the KL term is scaled by beta and by latent_dim / nb_ingredients
            mse_loss = ((ingredient_quantities - x_hat) ** 2).mean().float()
            kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1)).float()
            vae_loss = mse_loss + params['beta_vae'] * (params['latent_dim'] / params['nb_ingredients']) * kld_loss

            # The total-volume loss is currently disabled and logged as a constant zero.
            # volume_loss = ((ingredient_quantities.sum(dim=1) - x_hat.sum(dim=1)) ** 2).mean().float()
            volume_loss = torch.FloatTensor([0])

            aux_other_metrics = compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat)

            # The taste loss is recomputed on the samples flagged as valid only.
            indexes_taste_valid = np.argwhere(taste_valid.detach().cpu().numpy()).flatten()
            if indexes_taste_valid.size > 0:
                outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps')
                gt = auxiliaries['taste_reps'][indexes_taste_valid]
                # factor on the loss: 1 when 30% of the batch is valid (per the original note, the
                # ratio in the actual dataset); it decreases with fewer valid samples, increases with more
                factor_loss = indexes_taste_valid.size / (0.3 * batch_size)
                aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float()
            else:
                aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([])
                aux_accuracies['taste_reps'] = 0

            # aggregate losses: vae + volume + weighted auxiliaries
            global_loss = torch.sum(torch.cat(
                [torch.atleast_1d(vae_loss), torch.atleast_1d(volume_loss)]
                + [torch.atleast_1d(aux_losses[k] * weights[k]) for k in aux_keys]))

            if train:
                global_loss.backward()
                opt.step()
                opt.zero_grad()

            # logging
            losses['global_loss'].append(float(global_loss))
            losses['mse_loss'].append(float(mse_loss))
            losses['vae_loss'].append(float(vae_loss))
            losses['volume_loss'].append(float(volume_loss))
            losses['kld_loss'].append(float(kld_loss))
            for k in aux_keys:
                losses[k].append(float(aux_losses[k]))
                accuracies[k].append(float(aux_accuracies[k]))
            for k, value in aux_other_metrics.items():
                other_metrics.setdefault(k, []).append(value)

    # average the per-batch logs over the pass
    for k in losses.keys():
        losses[k] = np.mean(losses[k])
    for k in accuracies.keys():
        accuracies[k] = np.mean(accuracies[k])
    for k in other_metrics.keys():
        other_metrics[k] = np.mean(other_metrics[k], axis=0)
    return model, losses, accuracies, other_metrics
def prepare_data_and_loss(params):
    """Create the train/test data loaders and one loss module per auxiliary head.

    Args:
        params: experiment parameters; reads ``'batch_size'`` and
            ``'auxiliaries_dict'`` (each entry providing ``'type'`` and ``'weight'``).

    Returns:
        tuple: ``(loss_functions, train_data_loader, test_data_loader, weights)``
        where ``loss_functions`` and ``weights`` are keyed by auxiliary name.

    Raises:
        ValueError: for an unknown auxiliary type, or a ``'classif'`` auxiliary
            that has no class weights on the training dataset.
    """
    train_data = MyDataset(split='train', params=params)
    test_data = MyDataset(split='test', params=params)
    train_data_loader, test_data_loader = (
        DataLoader(dataset, batch_size=params['batch_size'], shuffle=True)
        for dataset in (train_data, test_data)
    )

    # Training-set attribute that holds the class weights of each classification head.
    class_weight_attribute = {
        'glasses': 'glasses_weights',
        'prep_type': 'prep_types_weights',
        'categories': 'categories_weights',
    }

    loss_functions, weights = {}, {}
    for name in sorted(params['auxiliaries_dict'].keys()):
        spec = params['auxiliaries_dict'][name]
        head_type = spec['type']
        if head_type == 'classif':
            if name not in class_weight_attribute:
                raise ValueError
            classif_weights = getattr(train_data, class_weight_attribute[name])
            loss_functions[name] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
        elif head_type == 'multiclassif':
            loss_functions[name] = nn.BCEWithLogitsLoss()
        elif head_type == 'regression':
            loss_functions[name] = nn.MSELoss()
        else:
            raise ValueError
        weights[name] = spec['weight']
    return loss_functions, train_data_loader, test_data_loader, weights
def print_losses(train, losses, accuracies, other_metrics):
    """Print a summary of one epoch's losses, accuracies and scalar metrics.

    Args:
        train: selects the 'Train' or 'Eval' heading.
        losses: mapping holding the five core losses and one loss per auxiliary.
        accuracies: mapping from auxiliary name to accuracy.
        other_metrics: extra metrics; entries whose name contains 'confusion'
            are not printed.
    """
    heading = 'Train' if train else 'Eval'
    print(f'\t{heading} logs:')

    for name in ('global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss'):
        print(f'\t\t{name} - Loss: {losses[name]:.2f}')

    for name in sorted(accuracies):
        print(f'\t\t{name} (aux) - Loss: {losses[name]:.2f}, Acc: {accuracies[name]:.2f}')

    for name in sorted(other_metrics):
        if 'confusion' in name:
            continue
        print(f'\t\t{name} - {other_metrics[name]:.2f}')
def run_experiment(params, verbose=True):
    """Train the VAE described by ``params`` and return the model after the last epoch.

    The model is evaluated once before training, then trained and evaluated
    every epoch. Whenever the evaluation global loss improves, the weights are
    saved to ``params['save_path'] + 'checkpoint_best.save'``. Curves are
    re-plotted every ``params['plot_every']`` epochs.

    Args:
        params: experiment parameters; ``'filter_decoder_output'`` is added to it.
        verbose: when true, print logs every ``params['print_every']`` epochs
            and each time a new best model is saved.

    Returns:
        The model as it is after the final epoch (not the best checkpoint).
    """
    loss_functions, train_data_loader, test_data_loader, weights = prepare_data_and_loss(params)
    params['filter_decoder_output'] = train_data_loader.dataset.filter_decoder_output

    # Positional arguments of get_vae_model, in the order it expects them.
    model_arg_names = ("input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation",
                       "hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim",
                       "agg", "dropout", "auxiliaries_dict", "filter_decoder_output")
    model = get_vae_model(*[params[name] for name in model_arg_names])
    opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])

    def one_pass(current_model, train):
        """Run a training pass on the train loader or an evaluation pass on the test loader."""
        loader = train_data_loader if train else test_data_loader
        return run_epoch(opt=opt, train=train, model=current_model, data=loader,
                         loss_functions=loss_functions, weights=weights, params=params)

    def record(log, epoch_losses, epoch_accuracies, epoch_other_metrics):
        """Append one epoch's results to a log."""
        log['losses'].append(epoch_losses)
        log['accuracies'].append(epoch_accuracies)
        log['other_metrics'].append(epoch_other_metrics)

    train_log = dict(losses=[], accuracies=[], other_metrics=[])
    eval_log = dict(losses=[], accuracies=[], other_metrics=[])
    best_loss = np.inf

    # Baseline evaluation of the untrained model; the eval log is therefore one entry longer.
    model, eval_losses, eval_accuracies, eval_other_metrics = one_pass(model, train=False)
    record(eval_log, eval_losses, eval_accuracies, eval_other_metrics)
    if verbose:
        print('\n--------\nEpoch #0')
        print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)

    for epoch in range(params['nb_epochs']):
        epoch_number = epoch + 1
        log_now = verbose and epoch_number % params['print_every'] == 0
        if log_now:
            print(f'\n--------\nEpoch #{epoch_number}')

        model, train_losses, train_accuracies, train_other_metrics = one_pass(model, train=True)
        if log_now:
            print_losses(train=True, accuracies=train_accuracies, losses=train_losses,
                         other_metrics=train_other_metrics)

        model, eval_losses, eval_accuracies, eval_other_metrics = one_pass(model, train=False)
        if log_now:
            print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses,
                         other_metrics=eval_other_metrics)

        if eval_losses['global_loss'] < best_loss:
            best_loss = eval_losses['global_loss']
            if verbose:
                print(f'Saving new best model with loss {best_loss:.2f}')
            torch.save(model.state_dict(), params['save_path'] + 'checkpoint_best.save')

        # log
        record(train_log, train_losses, train_accuracies, train_other_metrics)
        record(eval_log, eval_losses, eval_accuracies, eval_other_metrics)

        # if epoch == params['nb_epoch_switch_beta']:
        #     params['beta_vae'] = 2.5
        #     params['auxiliaries_dict']['prep_type']['weight'] /= 10
        #     params['auxiliaries_dict']['glasses']['weight'] /= 10

        if epoch_number % params['plot_every'] == 0:
            plot_results(train_log['losses'], train_log['accuracies'], train_log['other_metrics'],
                         eval_log['losses'], eval_log['accuracies'], eval_log['other_metrics'],
                         params['plot_path'], weights)
    return model
def _save_line_plot(title, steps, history, keys, keep, out_file, ylim=None):
    """Plot one curve per kept key of ``history`` against ``steps`` and save the figure."""
    plt.figure()
    plt.title(title)
    for k in keys:
        if keep(k):
            plt.plot(steps, [entry[k] for entry in history], label=k)
    plt.legend()
    if ylim is not None:
        plt.ylim(ylim)
    plt.savefig(out_file, dpi=200)
    plt.close(plt.gcf())


def _save_confusion_plots(latest_metrics, keys, plot_path, prefix):
    """Save one heat-map per key containing 'confusion', read from ``latest_metrics``."""
    for k in keys:
        if 'confusion' in k:
            plt.figure()
            plt.title(k)
            plt.ylabel('True')
            plt.xlabel('Predicted')
            plt.imshow(latest_metrics[k], vmin=0, vmax=1)
            plt.colorbar()
            plt.savefig(plot_path + f'{prefix}_{k}.png', dpi=200)
            plt.close(plt.gcf())


def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
                 all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights):
    """Save the training and evaluation curves and confusion matrices as PNG files.

    The x axis has one step per evaluation entry; training curves are drawn
    against ``steps[1:]`` and evaluation curves against all steps. Auxiliary
    losses and accuracies whose weight is 0 are not drawn.

    Args:
        all_train_losses, all_train_accuracies, all_train_other_metrics:
            per-epoch dicts logged during training.
        all_eval_losses, all_eval_accuracies, all_eval_other_metrics:
            per-epoch dicts logged during evaluation.
        plot_path: prefix prepended to every output file name.
        weights: mapping from auxiliary name to its loss weight.
    """
    steps = np.arange(len(all_eval_accuracies))
    # The keys to draw are always taken from the first training entry.
    loss_keys = sorted(all_train_losses[0].keys())
    acc_keys = sorted(all_train_accuracies[0].keys())
    metrics_keys = sorted(all_train_other_metrics[0].keys())

    def loss_is_drawn(k):
        return k not in weights.keys() or weights[k] != 0

    def accuracy_is_drawn(k):
        return weights[k] != 0

    def is_presence_metric(k):
        return 'confusion' not in k and 'presence' in k

    def is_quantity_metric(k):
        return 'confusion' not in k and 'presence' not in k

    runs = (
        ('train', 'Train', steps[1:], all_train_losses, all_train_accuracies, all_train_other_metrics),
        ('eval', 'Eval', steps, all_eval_losses, all_eval_accuracies, all_eval_other_metrics),
    )
    for prefix, label, xs, run_losses, run_accuracies, run_metrics in runs:
        _save_line_plot(f'{label} losses', xs, run_losses, loss_keys, loss_is_drawn,
                        plot_path + f'{prefix}_losses.png', ylim=[0, 4])
        _save_line_plot(f'{label} accuracies', xs, run_accuracies, acc_keys, accuracy_is_drawn,
                        plot_path + f'{prefix}_acc.png', ylim=[0, 1])
        _save_line_plot(f'{label} other metrics', xs, run_metrics, metrics_keys, is_presence_metric,
                        plot_path + f'{prefix}_ing_presence_errors.png', ylim=[0, 1])
        _save_line_plot(f'{label} other metrics', xs, run_metrics, metrics_keys, is_quantity_metric,
                        plot_path + f'{prefix}_ing_q_error.png')

    # Confusion matrices of the latest epoch only, evaluation first.
    _save_confusion_plots(all_eval_other_metrics[-1], metrics_keys, plot_path, 'eval')
    _save_confusion_plots(all_train_other_metrics[-1], metrics_keys, plot_path, 'train')

    plt.close('all')
def get_model(model_path):
    """Rebuild a trained model from an experiment directory.

    The directory must contain ``params.json``, ``checkpoint_best.save`` and the
    text files ``mean_ing_quantities.txt``, ``std_ing_quantities.txt`` and
    ``min_when_present_ing_quantities.txt``.

    Args:
        model_path: path of the experiment directory, with or without a
            trailing separator.

    Returns:
        tuple: ``(model, filter_decoder_output, params)`` -- the model in eval
        mode with the best checkpoint loaded, the function that un-normalises
        and thresholds decoder outputs, and the loaded parameters.
    """
    with open(os.path.join(model_path, 'params.json'), 'r') as f:
        params = json.load(f)
    params['save_path'] = model_path

    mean_ing_quantities = np.loadtxt(os.path.join(model_path, 'mean_ing_quantities.txt'))
    std_ing_quantities = np.loadtxt(os.path.join(model_path, 'std_ing_quantities.txt'))
    min_when_present_ing_quantities = np.loadtxt(os.path.join(model_path, 'min_when_present_ing_quantities.txt'))

    def filter_decoder_output(output):
        """Un-normalise a decoder output and zero the quantities below their per-ingredient minimum."""
        output = output.detach().cpu().numpy()
        output_unnormalized = output * std_ing_quantities + mean_ing_quantities
        # The per-ingredient minimum broadcasts over a single sample or a batch of samples.
        output_unnormalized[output_unnormalized < min_when_present_ing_quantities] = 0
        return output_unnormalized.copy()

    params['filter_decoder_output'] = filter_decoder_output

    # Positional arguments of get_vae_model, in the order it expects them.
    model_arg_names = ("input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation",
                       "hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim",
                       "agg", "dropout", "auxiliaries_dict", "filter_decoder_output")
    model = get_vae_model(*[params[name] for name in model_arg_names])

    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
    model_chkpt = os.path.join(model_path, "checkpoint_best.save")
    model.load_state_dict(torch.load(model_chkpt, map_location='cpu'))
    model.eval()
    return model, filter_decoder_output, params
def compute_expe_name_and_save_path(params):
    """Derive a unique experiment directory from the hyper-parameters and create it.

    The directory name concatenates the trial id with the main hyper-parameters
    and the list of auxiliary weights, followed by ``_<counter>`` where the
    counter is the first non-negative integer whose path does not exist yet.
    A ``plots/`` sub-directory is created as well.

    Args:
        params: parameter dictionary; ``'save_path'`` is read as the parent
            prefix and overwritten with the new directory (ending with ``/``),
            and ``'plot_path'`` is added.

    Returns:
        dict: the same ``params`` object, updated.
    """
    # Joining handles an empty auxiliaries_dict correctly (yields '[]').
    weights_str = '[' + ', '.join(f'{spec["weight"]}' for spec in params['auxiliaries_dict'].values()) + ']'

    # (tag, value) pairs appended to the trial id, in this order.
    name_parts = [
        ('lr', params["lr"]),
        ('betavae', params["beta_vae"]),
        ('bs', params["batch_size"]),
        ('latentdim', params["latent_dim"]),
        ('hding', params["hidden_dims_ingredients"]),
        ('hdcocktail', params["hidden_dims_cocktail"]),
        ('hddecoder', params["hidden_dims_decoder"]),
        ('agg', params["agg"]),
        ('activ', params["activation"]),
        ('w', weights_str),
    ]
    save_path = params['save_path'] + params["trial_id"]
    save_path += ''.join(f'_{tag}{value}' for tag, value in name_parts)

    # Pick the first unused numeric suffix so that earlier runs are never overwritten.
    counter = 0
    while os.path.exists(save_path + f"_{counter}"):
        counter += 1
    save_path = save_path + f"_{counter}" + '/'

    params["save_path"] = save_path
    os.makedirs(save_path)
    os.makedirs(save_path + 'plots/')
    params['plot_path'] = save_path + 'plots/'
    print(f'logging to {save_path}')
    return params
if __name__ == '__main__':
    # Script entry point: build the configuration, then train with it.
    run_experiment(get_params())