import json
import logging
import math
import os
import time
from contextlib import suppress

import numpy as np
import torch
import torch.nn.functional as F

try:
    import wandb
except ImportError:
    wandb = None

from clap_module import LPLoss, LPMetrics, lp_gather_features
from clap_module.utils import do_mixup, get_mix_lambda

from .distributed import is_master
from .zero_shot import zero_shot_eval


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def unwrap_model(model):
    if hasattr(model, "module"):
        return model.module
    else:
        return model


def train_one_epoch(
    model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None, extra_suffix=""
):
    """Run one linear-probe training epoch over data["train"], logging progress on the master process.

    `optimizer` and `scheduler` may each be a single object or a dict of per-branch
    objects; both cases are handled below.
    """
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    model.train()
    loss = LPLoss(args.lp_loss)

    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    # for toy dataset
    if args.dataset_type == "toy":
        dataloader.dataset.generate_queue()

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i

        if isinstance(scheduler, dict):
            for s in scheduler.values():
                s(step)
        else:
            scheduler(step)

        audio = batch  # contains mel_spec, waveform, and longer list
        class_label = batch['class_label']
        # audio = audio.to(device=device, non_blocking=True)
        class_label = class_label.to(device=device, non_blocking=True)

        if args.mixup:
            # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146
            mix_lambda = torch.from_numpy(
                get_mix_lambda(0.5, len(audio["waveform"]))
            ).to(device)
            class_label = do_mixup(class_label, mix_lambda)
        else:
            mix_lambda = None

        data_time_m.update(time.time() - end)
        if isinstance(optimizer, dict):
            for o_ in optimizer.values():
                o_.zero_grad()
        else:
            optimizer.zero_grad()

        with autocast():
            pred = model(audio, mix_lambda=mix_lambda, device=device)
            total_loss = loss(pred, class_label)

        if isinstance(optimizer, dict):
            if scaler is not None:
                scaler.scale(total_loss).backward()
                for o_ in optimizer.values():
                    if args.horovod:
                        o_.synchronize()
                        scaler.unscale_(o_)
                        with o_.skip_synchronize():
                            scaler.step(o_)
                    else:
                        scaler.step(o_)
                scaler.update()
            else:
                total_loss.backward()
                for o_ in optimizer.values():
                    o_.step()
        else:
            if scaler is not None:
                scaler.scale(total_loss).backward()
                if args.horovod:
                    optimizer.synchronize()
                    scaler.unscale_(optimizer)
                    with optimizer.skip_synchronize():
                        scaler.step(optimizer)
                else:
                    scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100))
            unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1

        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            if isinstance(audio, dict):
                batch_size = len(audio["waveform"])
            else:
                batch_size = len(audio)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE loss is coarsely sampled, just master node and per log update
            loss_m.update(total_loss.item(), batch_size)
            if isinstance(optimizer, dict):
                logging.info(
                    f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                    f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                    f"Data (t): {data_time_m.avg:.3f} "
                    f"Batch (t): {batch_time_m.avg:.3f} "
                    f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}"
                )
                log_data = {
                    "loss": loss_m.val,
                    "data_time": data_time_m.val,
                    "batch_time": batch_time_m.val,
                    "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
                }
            else:
                logging.info(
                    f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                    f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                    f"Data (t): {data_time_m.avg:.3f} "
                    f"Batch (t): {batch_time_m.avg:.3f} "
                    f"LR: {optimizer.param_groups[0]['lr']:.5f} "
                )
                # Save train loss / etc. Using non-avg meter values as loggers have their own smoothing
                log_data = {
                    "loss": loss_m.val,
                    "data_time": data_time_m.val,
                    "batch_time": batch_time_m.val,
                    "lr": optimizer.param_groups[0]["lr"],
                }

            for name, val in log_data.items():
                name = f"train{extra_suffix}/{name}"
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, "Please install wandb."
                    wandb.log({name: val, "step": step})

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()
    # end for


def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""):
    """Run linear-probe evaluation on data["val"] and return a dict of metrics (including the epoch)."""
    metrics = {}
    if not args.parallel_eval:
        if not is_master(args):
            return metrics
    device = torch.device(args.device)
    model.eval()

    # CHANGE
    # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
    # metrics.update(zero_shot_metrics)
    if is_master(args):
        print('Evaluating...')
    metric_names = args.lp_metrics.split(',')
    eval_tool = LPMetrics(metric_names=metric_names)

    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    if "val" in data and (
        args.val_frequency
        and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
    ):
        num_samples = 0
        if args.parallel_eval:
            dataloader, sampler = data["val"].dataloader, data["val"].sampler
            if args.distributed and sampler is not None:
                sampler.set_epoch(epoch)
            samples_per_val = dataloader.num_samples
        else:
            dataloader = data["val"].dataloader
            samples_per_val = dataloader.num_samples

        eval_info = {
            'pred': [],
            'target': []
        }
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                audio = batch  # contains mel_spec, waveform, and longer list
                class_label = batch['class_label']
                # audio = audio.to(device=device, non_blocking=True)
                class_label = class_label.to(device=device, non_blocking=True)

                with autocast():
                    pred = model(audio, device=device)
                    if args.parallel_eval:
                        pred, class_label = lp_gather_features(
                            pred, class_label, args.world_size, args.horovod
                        )
                    eval_info['pred'].append(pred)
                    eval_info['target'].append(class_label)

                num_samples += class_label.shape[0]

                if (i % 100) == 0:  # and i != 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
                    )

            if is_master(args):
                eval_info['pred'] = torch.cat(eval_info['pred'], 0).cpu()
                eval_info['target'] = torch.cat(eval_info['target'], 0).cpu()
                # (sic: method name spelling as provided by the imported LPMetrics helper)
                metric_dict = eval_tool.evaluate_mertics(eval_info['pred'], eval_info['target'])
                metrics.update(metric_dict)
                if "epoch" not in metrics:
                    metrics.update({"epoch": epoch})

    if is_master(args):
        if not metrics:
            return metrics

        logging.info(
            f"Eval Epoch: {epoch} "
            + "\n".join([f"{m}: {round(metrics[m], 4):.4f}" for m in metrics])
        )

        if args.save_logs:
            for name, val in metrics.items():
                if tb_writer is not None:
                    tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch)

            with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
                f.write(json.dumps(metrics))
                f.write("\n")

        if args.wandb:
            assert wandb is not None, "Please install wandb."
            for name, val in metrics.items():
                wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch})

        return metrics
    else:
        return metrics
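

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not executed as part of this module).
# A minimal driver loop, assuming an `args` namespace exposing the attributes
# read above (device, precision, lp_loss, lp_metrics, mixup, horovod,
# val_frequency, epochs, wandb, save_logs, checkpoint_path, ...), a `data`
# dict with "train"/"val" entries, and a linear-probe model wrapping a CLAP
# backbone. `start_epoch`, `optimizer`, `scheduler`, and `writer` stand in for
# whatever the caller's own setup code produces.
#
#     scaler = torch.cuda.amp.GradScaler() if args.precision == "amp" else None
#     for epoch in range(start_epoch, args.epochs):
#         train_one_epoch(model, data, epoch, optimizer, scaler, scheduler,
#                         args, tb_writer=writer)
#         metrics = evaluate(model, data, epoch + 1, args, tb_writer=writer)
# ---------------------------------------------------------------------------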