from __future__ import annotations import itertools import time import yaml from contextlib import nullcontext from tqdm import tqdm import torch from torch import nn from torch.cuda.amp import autocast, GradScaler from . import utils from .priors import prior from . import priors from .transformer import TransformerModel from .bar_distribution import BarDistribution, FullSupportBarDistribution, get_bucket_limits, get_custom_bar_dist from .utils import get_cosine_schedule_with_warmup, get_openai_lr, StoreDictKeyPair, get_weighted_single_eval_pos_sampler, get_uniform_single_eval_pos_sampler from . import positional_encodings from .utils import init_dist class Losses(): gaussian = nn.GaussianNLLLoss(full=True, reduction='none') mse = nn.MSELoss(reduction='none') ce = lambda num_classes: nn.CrossEntropyLoss(reduction='none', weight=torch.ones(num_classes)) bce = nn.BCEWithLogitsLoss(reduction='none') get_BarDistribution = BarDistribution def train(priordataloader_class_or_get_batch: prior.PriorDataLoader | callable, criterion, encoder_generator, emsize=200, nhid=200, nlayers=6, nhead=2, dropout=0.0, epochs=10, steps_per_epoch=100, batch_size=200, seq_len=10, lr=None, weight_decay=0.0, warmup_epochs=10, input_normalization=False, y_encoder_generator=None, pos_encoder_generator=None, decoder_dict={}, extra_prior_kwargs_dict={}, scheduler=get_cosine_schedule_with_warmup, load_weights_from_this_state_dict=None, validation_period=10, single_eval_pos_gen=None, gpu_device='cuda:0', aggregate_k_gradients=1, verbose=True, style_encoder_generator=None, epoch_callback=None, step_callback=None, continue_model=None, initializer=None, initialize_with_model=None, train_mixed_precision=False, efficient_eval_masking=True, border_decoder=None , num_global_att_tokens=0, progress_bar=False, **model_extra_args): device = gpu_device if torch.cuda.is_available() else 'cpu:0' print(f'Using {device} device') using_dist, rank, device = init_dist(device) single_eval_pos_gen = single_eval_pos_gen if callable(single_eval_pos_gen) else lambda: single_eval_pos_gen if not isinstance(priordataloader_class_or_get_batch, prior.PriorDataLoader): priordataloader_class = priors.utils.get_batch_to_dataloader(priordataloader_class_or_get_batch) else: priordataloader_class = priordataloader_class_or_get_batch def eval_pos_seq_len_sampler(): single_eval_pos = single_eval_pos_gen() return single_eval_pos, seq_len dl = priordataloader_class(num_steps=steps_per_epoch, batch_size=batch_size, eval_pos_seq_len_sampler=eval_pos_seq_len_sampler, seq_len_maximum=seq_len, device=device, **extra_prior_kwargs_dict) test_batch: prior.Batch = dl.get_test_batch() style_def = test_batch.style print(f'Style definition of first 3 examples: {style_def[:3] if style_def is not None else None}') style_encoder = style_encoder_generator(style_def.shape[1], emsize) if (style_def is not None) else None pos_encoder = (pos_encoder_generator or positional_encodings.NoPositionalEncoding)(emsize, seq_len * 2) if isinstance(criterion, nn.GaussianNLLLoss): n_out = 2 elif isinstance(criterion, BarDistribution) or "BarDistribution" in criterion.__class__.__name__: # TODO remove this fix (only for dev) n_out = criterion.num_bars elif isinstance(criterion, nn.CrossEntropyLoss): n_out = criterion.weight.shape[0] else: n_out = 1 #border_decoder = None if border_decoder is None else border_decoder(emsize, criterion.num_bars + 1).to(device) if continue_model: model = continue_model else: decoder_dict = decoder_dict if decoder_dict else {'standard': (None, n_out)} decoder_once_dict = {} if test_batch.mean_prediction is not None: decoder_once_dict['mean_prediction'] = decoder_dict['standard'] encoder = encoder_generator(dl.num_features, emsize) model = TransformerModel(encoder=encoder , nhead=nhead , ninp=emsize , nhid=nhid , nlayers=nlayers , dropout=dropout , style_encoder=style_encoder , y_encoder=y_encoder_generator(1, emsize) , input_normalization=input_normalization , pos_encoder=pos_encoder , decoder_dict=decoder_dict , init_method=initializer , efficient_eval_masking=efficient_eval_masking , decoder_once_dict=decoder_once_dict , num_global_att_tokens=num_global_att_tokens , **model_extra_args ) model.criterion = criterion if load_weights_from_this_state_dict is not None: model.load_state_dict(load_weights_from_this_state_dict) if initialize_with_model is not None: model.init_from_small_model(initialize_with_model) print(f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters") try: for (k, v), (k2, v2) in zip(model.state_dict().items(), initialize_with_model.state_dict().items()): print(k, ((v - v2) / v).abs().mean(), v.shape) except Exception: pass model.to(device) if using_dist: print("Distributed training") model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, broadcast_buffers=False, find_unused_parameters=test_batch.mean_prediction is not None) dl.model = model.module # use local model, should not use multi-gpu functionality.. else: dl.model = model # learning rate if lr is None: lr = get_openai_lr(model) print(f"Using OpenAI max lr of {lr}.") optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) scheduler = scheduler(optimizer, warmup_epochs, epochs if epochs is not None else 100) # when training for fixed time lr schedule takes 100 steps scaler = GradScaler() if train_mixed_precision else None # check that everything uses up-to-date APIs utils.check_compatibility(dl) def train_epoch(): model.train() # Turn on the train mode total_loss = 0. total_positional_losses = 0. total_positional_losses_recorded = 0 nan_steps = 0 ignore_steps = 0 before_get_batch = time.time() assert len(dl) % aggregate_k_gradients == 0, 'Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it.' tqdm_iter = tqdm(range(len(dl)), desc='Training Epoch') if rank==0 and progress_bar else None # , disable=not verbose for batch, full_data in enumerate(dl): data, targets, single_eval_pos = (full_data.style, full_data.x, full_data.y), full_data.target_y, full_data.single_eval_pos def get_metrics(): return total_loss / steps_per_epoch, ( total_positional_losses / total_positional_losses_recorded).tolist(), \ time_to_get_batch, forward_time, step_time, nan_steps.cpu().item() / (batch + 1), \ ignore_steps.cpu().item() / (batch + 1) tqdm_iter.update() if tqdm_iter is not None else None if using_dist and not (batch % aggregate_k_gradients == aggregate_k_gradients - 1): cm = model.no_sync() else: cm = nullcontext() with cm: time_to_get_batch = time.time() - before_get_batch before_forward = time.time() try: metrics_to_log = {} with autocast(enabled=scaler is not None): # If style is set to None, it should not be transferred to device out = model(tuple(e.to(device) if torch.is_tensor(e) else e for e in data), single_eval_pos=single_eval_pos, only_return_standard_out=False) # this handling is for training old models only, this can be deleted soon(ish) # to only support models that return a tuple of dicts out, output_once = out if isinstance(out, tuple) else (out, None) output = out['standard'] if isinstance(out, dict) else out forward_time = time.time() - before_forward if single_eval_pos is not None: targets = targets[single_eval_pos:] if len(targets.shape) == len(output.shape): # this implies the prior uses a trailing 1 dimesnion # below we assume this not to be the case targets = targets.squeeze(-1) assert targets.shape == output.shape[:-1], f"Target shape {targets.shape} " \ "does not match output shape {output.shape}" if isinstance(criterion, nn.GaussianNLLLoss): assert output.shape[-1] == 2, \ 'need to write a little bit of code to handle multiple regression targets at once' mean_pred = output[..., 0] var_pred = output[..., 1].abs() losses = criterion(mean_pred.flatten(), targets.flatten(), var=var_pred.flatten()) elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)): targets[torch.isnan(targets)] = -100 losses = criterion(output.flatten(), targets.flatten()) elif isinstance(criterion, nn.CrossEntropyLoss): targets[torch.isnan(targets)] = -100 print(f"{targets.min()=}, {targets.max()=}") losses = criterion(output.reshape(-1, n_out), targets.long().flatten()) elif border_decoder is not None: def apply_batch_wise_criterion(i): output_, targets_, borders_ = output_adaptive[:, i], targets[:, i], borders[i] criterion_ = get_custom_bar_dist(borders_, criterion).to(device) return criterion_(output_, targets_) output_adaptive, borders = out['adaptive_bar'], output_once['borders'] losses_adaptive_bar = torch.stack([apply_batch_wise_criterion(i) for i in range(output_adaptive.shape[1])], 1) losses_fixed_bar = criterion(output, targets) losses = (losses_adaptive_bar + losses_fixed_bar) / 2 metrics_to_log = {**metrics_to_log, **{'loss_fixed_bar': losses_fixed_bar.mean().cpu().detach().item(), 'loss_adaptive_bar': losses_adaptive_bar.mean().cpu().detach().item()}} elif isinstance(criterion, BarDistribution) and full_data.mean_prediction: assert 'mean_prediction' in output_once utils.print_once('Using mean prediction for loss') losses = criterion(output, targets, mean_prediction_logits=output_once['mean_prediction']) # the mean pred loss appears as the last per sequence else: losses = criterion(output, targets) losses = losses.view(-1, output.shape[1]) # sometimes the seq length can be one off # that is because bar dist appends the mean loss, nan_share = utils.torch_nanmean(losses.mean(0), return_nanshare=True) loss_scaled = loss / aggregate_k_gradients if scaler: loss_scaled = scaler.scale(loss_scaled) loss_scaled.backward() if batch % aggregate_k_gradients == aggregate_k_gradients - 1: if scaler: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.) if scaler: scaler.step(optimizer) scaler.update() else: optimizer.step() optimizer.zero_grad() step_time = time.time() - before_forward if not torch.isnan(loss): total_loss += loss.cpu().detach().item() total_positional_losses += losses.mean(1).cpu().detach() if single_eval_pos is None else \ nn.functional.one_hot(torch.tensor(single_eval_pos), seq_len)*\ utils.torch_nanmean(losses[:seq_len-single_eval_pos].mean(0)).cpu().detach() total_positional_losses_recorded += torch.ones(seq_len) if single_eval_pos is None else \ nn.functional.one_hot(torch.tensor(single_eval_pos), seq_len) metrics_to_log = {**metrics_to_log, **{f"loss": loss, "single_eval_pos": single_eval_pos}} if step_callback is not None and rank == 0: step_callback(metrics_to_log) nan_steps += nan_share ignore_steps += (targets == -100).float().mean() except Exception as e: print("Invalid step encountered, skipping...") print(e) raise(e) #total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time, nan_share, ignore_share = get_metrics() if tqdm_iter: tqdm_iter.set_postfix({'data_time': time_to_get_batch, 'step_time': step_time, 'mean_loss': total_loss / (batch+1)}) before_get_batch = time.time() return get_metrics() total_loss = float('inf') total_positional_losses = float('inf') try: # Initially test the epoch callback function if epoch_callback is not None and rank == 0: epoch_callback(model, 1, data_loader=dl, scheduler=scheduler) for epoch in (range(1, epochs + 1) if epochs is not None else itertools.count(1)): epoch_start_time = time.time() try: total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time, nan_share, ignore_share =\ train_epoch() except Exception as e: print("Invalid epoch encountered, skipping...") print(e) raise (e) if hasattr(dl, 'validate') and epoch % validation_period == 0: with torch.no_grad(): val_score = dl.validate(model) else: val_score = None if verbose: print('-' * 89) print( f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | ' f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}" f' data time {time_to_get_batch:5.2f} step time {step_time:5.2f}' f' forward time {forward_time:5.2f}' f' nan share {nan_share:5.2f} ignore share (for classification tasks) {ignore_share:5.4f}' + (f'val score {val_score}' if val_score is not None else '')) print('-' * 89) # stepping with wallclock time based scheduler if epoch_callback is not None and rank == 0: epoch_callback(model, epoch, data_loader=dl, scheduler=scheduler) scheduler.step() except KeyboardInterrupt: pass if rank == 0: # trivially true for non-parallel training if isinstance(model, torch.nn.parallel.DistributedDataParallel): model = model.module dl = None return total_loss, total_positional_losses, model.to('cpu'), dl def _parse_args(config_parser, parser): # Do we have a config file to parse? args_config, remaining = config_parser.parse_known_args() if args_config.config: with open(args_config.config, 'r') as f: cfg = yaml.safe_load(f) parser.set_defaults(**cfg) # The main arg parser parses the rest of the args, the usual # defaults will have been overridden if config file specified. args = parser.parse_args(remaining) # Cache the args as a text string to save them in the output dir later args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) return args, args_text