import torch

from .prior import Batch
from ..utils import default_device

# Cache of already-loaded models, keyed by their 'group:index' checkpoint name.
loaded_models = {}


def get_model(model_name, device):
    if model_name not in loaded_models:
        import submitit
        group, index = model_name.split(':')
        ex = submitit.get_executor()
        model = ex.get_group(group)[int(index)].results()[0][2]
        model.to(device)
        loaded_models[model_name] = model
    return loaded_models[model_name]
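

# A minimal usage sketch for `get_model`; the 'group:index' name is hypothetical, and
# this assumes the submitit wrapper used here exposes get_executor()/get_group():
#
#     level_0 = get_model('fantasize_run:0', 'cuda:0')  # job 0 of group 'fantasize_run'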


def get_batch(batch_size, seq_len, num_features, get_batch, model, single_eval_pos, epoch,
              device=default_device, hyperparameters=None, **kwargs):
    """
    Wrap another `get_batch` so that targets can be replaced by model-fantasized outcomes.

    Important assumptions:
    `hyperparameters` must contain 'inf_batch_size', 'max_level', 'sample_only_one_level',
    'eval_seq_len' and 'epochs_per_level'.

    You can train a new model, based on an old one, that samples from a single level only:
    specify `level_0_model` as a 'group:index' string and that model is loaded from its
    checkpoint and used to fantasize targets.

    :param batch_size: number of datasets in the batch
    :param seq_len: number of points per dataset
    :param num_features: number of input features
    :param get_batch: the underlying prior's `get_batch` function to wrap
    :param model: the model currently being trained (used to fantasize outcomes)
    :param single_eval_pos: first position whose targets are evaluated (context/eval split)
    :param epoch: current training epoch, 0-based
    :param device: device to create tensors on
    :param hyperparameters: dict of hyperparameters, see the assumptions above
    :param kwargs: passed through to the wrapped `get_batch`
    :return: a `Batch` whose `style` holds the per-dataset level
    """
    if level_0_model := hyperparameters.get('level_0_model', None):
        assert hyperparameters['sample_only_one_level'], "level_0_model only makes sense if you sample only one level"
        assert hyperparameters['max_level'] == 1, "level_0_model only makes sense if you sample only one level"
        level_0_model = get_model(level_0_model, device)
        model = level_0_model

    # The level describes how many fantasized steps are possible; it is 0 for the first epochs.
    epochs_per_level = hyperparameters['epochs_per_level']
    share_predict_mean_distribution = hyperparameters.get('share_predict_mean_distribution', 0.)
    use_mean_prediction = share_predict_mean_distribution or \
        (model.decoder_dict_once is not None and 'mean_prediction' in model.decoder_dict_once)
    num_evals = seq_len - single_eval_pos
    level = min(min(epoch // epochs_per_level, hyperparameters['max_level']), num_evals - 1)
    if level_0_model:
        level = 1
    eval_seq_len = hyperparameters['eval_seq_len']
    add_seq_len = 0 if use_mean_prediction else eval_seq_len
    long_seq_len = seq_len + add_seq_len
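
    # Sample a per-dataset level ('style'): with a level-0 model everything is level 1;
    # with 'sample_only_one_level' the whole batch shares one uniformly drawn level;
    # otherwise levels are drawn i.i.d. and sorted so equal levels end up adjacent.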
    if level_0_model:
        styles = torch.ones(batch_size, 1, device=device, dtype=torch.long)
    elif hyperparameters['sample_only_one_level']:
        styles = torch.randint(level + 1, (1, 1), device=device).repeat(batch_size, 1)  # styles are sorted :)
    else:
        styles = torch.randint(level + 1, (batch_size, 1), device=device).sort(0).values  # styles are sorted :)
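
    # For a trailing share of each level's epochs, optionally flag the batch for
    # 'mean prediction' training (see `predict_mean_distribution` below).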
    predict_mean_distribution = None
    if share_predict_mean_distribution:
        max_used_level = max(styles)
        # The code below assumes epochs are 0-based!
        share_of_training = epoch / epochs_per_level
        predict_mean_distribution = (share_of_training >= (max_used_level + 1. - share_predict_mean_distribution)) \
            and (max_used_level < hyperparameters['max_level'])
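    # Worked example: with epochs_per_level=100, share_predict_mean_distribution=0.1 and
    # max_used_level=0, the flag is set for epochs 90..99 (0.9 <= epoch/100 < 1.0),
    # i.e. just before level 1 would be unlocked.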

    x, y, targets = [], [], []
    for considered_level in range(level + 1):
        num_elements = (styles == considered_level).sum()
        if not num_elements:
            continue
        returns: Batch = get_batch(batch_size=num_elements, seq_len=long_seq_len,
                                   num_features=num_features, device=device,
                                   hyperparameters=hyperparameters, model=model,
                                   single_eval_pos=single_eval_pos, epoch=epoch,
                                   **kwargs)
        levels_x, levels_y, levels_targets = returns.x, returns.y, returns.target_y
        assert not returns.other_filled_attributes(), f"Unexpected filled attributes: {returns.other_filled_attributes()}"
        assert levels_y is levels_targets
        levels_targets = levels_targets.clone()
        if len(levels_y.shape) == 2:
            levels_y = levels_y.unsqueeze(2)
            levels_targets = levels_targets.unsqueeze(2)
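
        # For levels > 0, fantasize one step ahead: tile the context `num_evals` times, place
        # each candidate evaluation point at position `single_eval_pos`, and have the (frozen)
        # model predict its outcome. That prediction then replaces the original target.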
        if considered_level > 0:
            feed_x = levels_x[:single_eval_pos + 1 + add_seq_len].repeat(1, num_evals, 1)
            feed_x[single_eval_pos, :] = levels_x[single_eval_pos:seq_len].reshape(-1, *levels_x.shape[2:])
            if not use_mean_prediction:
                feed_x[single_eval_pos + 1:] = levels_x[seq_len:].repeat(1, num_evals, 1)
            feed_y = levels_y[:single_eval_pos + 1 + add_seq_len].repeat(1, num_evals, 1)
            feed_y[single_eval_pos, :] = levels_y[single_eval_pos:seq_len].reshape(-1, *levels_y.shape[2:])
            if not use_mean_prediction:
                feed_y[single_eval_pos + 1:] = levels_y[seq_len:].repeat(1, num_evals, 1)

            model.eval()
            means = []
            for feed_x_b, feed_y_b in zip(torch.split(feed_x, hyperparameters['inf_batch_size'], dim=1),
                                          torch.split(feed_y, hyperparameters['inf_batch_size'], dim=1)):
                with torch.cuda.amp.autocast():
                    style = torch.zeros(feed_x_b.shape[1], 1, dtype=torch.int64, device=device) + considered_level - 1
                    if level_0_model is not None and level_0_model.style_encoder is None:
                        style = None
                    out = model(
                        (style, feed_x_b, feed_y_b),
                        single_eval_pos=single_eval_pos + 1, only_return_standard_out=False
                    )
                    if isinstance(out, tuple):
                        output, once_output = out
                    else:
                        output = out
                        once_output = {}
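                    # Two decoding paths: the 'mean_prediction' once-head directly yields a
                    # distribution over the best reachable outcome (queried via its icdf at
                    # an optimistic quantile), while the standard head is reduced by
                    # maximizing per-position means over the fantasized evaluation tail.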
                    if once_output and 'mean_prediction' in once_output:
                        mean_pred_logits = once_output['mean_prediction'].float()
                        assert tuple(mean_pred_logits.shape) == (feed_x_b.shape[1], model.criterion.num_bars), \
                            f"{tuple(mean_pred_logits.shape)} vs {(feed_x_b.shape[1], model.criterion.num_bars)}"
                        means.append(model.criterion.icdf(mean_pred_logits, 1. - 1. / eval_seq_len))
                    else:
                        logits = output['standard'].float()
                        means.append(model.criterion.mean(logits).max(0).values)
            means = torch.cat(means, 0)
            levels_targets_new = means.view(seq_len - single_eval_pos, *levels_y.shape[1:])
            levels_targets[single_eval_pos:seq_len] = levels_targets_new
            model.train()

        levels_x = levels_x[:seq_len]
        levels_y = levels_y[:seq_len]
        levels_targets = levels_targets[:seq_len]

        x.append(levels_x)
        y.append(levels_y)
        targets.append(levels_targets)

    x = torch.cat(x, 1)
    return Batch(x=x, y=torch.cat(y, 1), target_y=torch.cat(targets, 1), style=styles,
                 mean_prediction=predict_mean_distribution.item() if predict_mean_distribution is not None else None)
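

# A minimal wiring sketch (a rough illustration; `base_get_batch`, `my_model` and all
# hyperparameter values are assumptions, not part of this module):
#
#     batch = get_batch(
#         batch_size=8, seq_len=60, num_features=1,
#         get_batch=base_get_batch,   # the wrapped prior's get_batch
#         model=my_model, single_eval_pos=50, epoch=0,
#         hyperparameters={'inf_batch_size': 512, 'max_level': 1,
#                          'sample_only_one_level': True, 'eval_seq_len': 200,
#                          'epochs_per_level': 100},
#     )
#
# At epoch 0 this yields only level-0 batches, i.e. unmodified samples from
# `base_get_batch`, with `batch.style` full of zeros.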