# Modified from Audiocraft (https://github.com/facebookresearch/audiocraft)

from pathlib import Path
import time
import typing as tp
import warnings

import flashy
import math
import omegaconf
import torch
from torch.nn import functional as F

from . import base, builders
from .compression import CompressionSolver
from .. import metrics as eval_metrics
from .. import models
from ..data.audio_dataset import AudioDataset
from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo
from ..data.audio_utils import normalize_audio
from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
from ..utils.cache import CachedBatchWriter, CachedBatchLoader
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once


class MusicGenSolver(base.StandardSolver):
    """Solver for MusicGen training task.

    Used in: https://arxiv.org/abs/2306.05284
    """
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC

    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        # easier access to sampling parameters
        self.generation_params = {
            'use_sampling': self.cfg.generate.lm.use_sampling,
            'temp': self.cfg.generate.lm.temp,
            'top_k': self.cfg.generate.lm.top_k,
            'top_p': self.cfg.generate.lm.top_p,
        }
        self._best_metric_name: tp.Optional[str] = 'ce'
        self._cached_batch_writer = None
        self._cached_batch_loader = None
        if cfg.cache.path:
            if cfg.cache.write:
                self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path))
                if self.cfg.cache.write_num_shards:
                    self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
                    self._best_metric_name = None
            else:
                self._cached_batch_loader = CachedBatchLoader(
                    Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers,
                    min_length=self.cfg.optim.updates_per_epoch or 1)
                self.dataloaders['original_train'] = self.dataloaders['train']
                self.dataloaders['train'] = self._cached_batch_loader  # type: ignore

    @staticmethod
    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                                 device: tp.Optional[str] = None, autocast: bool = True,
                                 batch_size: tp.Optional[int] = None,
                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                                 **kwargs):
        """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
        populating all the proper params, deactivating EMA and FSDP, and loading the best state:
        basically all you need to get a solver ready to "play" with in single GPU mode
        and with minimal memory overhead.

        Args:
            sig (str): signature to load.
            dtype (str or None): potential dtype, as a string, i.e. 'float16'.
            device (str or None): potential device, as a string, i.e. 'cuda'.
            override_cfg (dict or omegaconf.DictConfig or None): additional configuration overrides,
                merged with the defaults set by this helper.
""" from audiocraft import train our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}} our_override_cfg['autocast'] = autocast if dtype is not None: our_override_cfg['dtype'] = dtype if device is not None: our_override_cfg['device'] = device if batch_size is not None: our_override_cfg['dataset'] = {'batch_size': batch_size} if override_cfg is None: override_cfg = {} override_cfg = omegaconf.OmegaConf.merge( omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore solver = train.get_solver_from_sig( sig, override_cfg=override_cfg, load_best=True, disable_fsdp=True, ignore_state_keys=['optimizer', 'ema'], **kwargs) solver.model.eval() return solver def get_formatter(self, stage_name: str) -> flashy.Formatter: return flashy.Formatter({ 'lr': '.2E', 'ce': '.3f', 'ppl': '.3f', 'grad_norm': '.3E', }, exclude_keys=['ce_q*', 'ppl_q*']) @property def best_metric_name(self) -> tp.Optional[str]: return self._best_metric_name def build_model(self) -> None: """Instantiate models and optimizer.""" # we can potentially not use all quantizers with which the EnCodec model was trained # (e.g. we trained the model with quantizers dropout) self.compression_model = CompressionSolver.wrapped_model_from_checkpoint( self.cfg, self.cfg.compression_model_checkpoint, device=self.device) assert self.compression_model.sample_rate == self.cfg.sample_rate, ( f"Compression model sample rate is {self.compression_model.sample_rate} but " f"Solver sample rate is {self.cfg.sample_rate}." ) # ensure we have matching configuration between LM and compression model assert self.cfg.transformer_lm.card == self.compression_model.cardinality, ( "Cardinalities of the LM and compression model don't match: ", f"LM cardinality is {self.cfg.transformer_lm.card} vs ", f"compression model cardinality is {self.compression_model.cardinality}" ) assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, ( "Numbers of codebooks of the LM and compression models don't match: ", f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ", f"compression model numer of codebooks is {self.compression_model.num_codebooks}" ) self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d", self.compression_model.num_codebooks, self.compression_model.cardinality, self.compression_model.frame_rate) # instantiate LM model self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device) if self.cfg.fsdp.use: assert not self.cfg.autocast, "Cannot use autocast with fsdp" self.model = self.wrap_with_fsdp(self.model) self.register_ema('model') # initialize optimization self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim) self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates) self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler') self.register_best_state('model') self.autocast_dtype = { 'float16': torch.float16, 'bfloat16': torch.bfloat16 }[self.cfg.autocast_dtype] self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None if self.cfg.fsdp.use: need_scaler = self.cfg.fsdp.param_dtype == 'float16' else: need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16 if need_scaler: if self.cfg.fsdp.use: from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler self.scaler = ShardedGradScaler() # type: ignore else: self.scaler = torch.cuda.amp.GradScaler() self.register_stateful('scaler') 
    def build_dataloaders(self) -> None:
        """Instantiate audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)

    def show(self) -> None:
        """Show the compression model and LM model."""
        self.logger.info("Compression model:")
        self.log_model_summary(self.compression_model)
        self.logger.info("LM model:")
        self.log_model_summary(self.model)

    def load_state_dict(self, state: dict) -> None:
        if 'condition_provider' in state:
            model_state = state['model']
            condition_provider_state = state.pop('condition_provider')
            prefix = 'condition_provider.'
            for key, value in condition_provider_state.items():
                key = prefix + key
                assert key not in model_state
                model_state[key] = value
        super().load_state_dict(state)

    def load_from_pretrained(self, name: str):
        # TODO: support native HF versions of MusicGen.
        lm_pkg = models.loaders.load_lm_model_ckpt(name)
        state: dict = {
            'best_state': {
                'model': lm_pkg['best_state'],
            },
        }
        return state

    def _compute_cross_entropy(
            self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and model's logits.
        The cross entropy is computed per codebook to provide codebook-level cross entropy.
        Valid timesteps for each of the codebooks are pulled from the mask, where invalid
        timesteps are set to 0.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks.
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # average cross entropy across codebooks
        ce = ce / K
        return ce, ce_per_codebook

    def _prepare_tokens_and_attributes(
            self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
            check_synchronization_points: bool = False
    ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]:
        """Prepare input batches for language model training.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T],
                a list of [local_video, global_video] features, and corresponding metadata as SegmentWithAttributes
                (with B items).
            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
        Returns:
            Video features (list[torch.Tensor]): The [local_video, global_video] features moved to the device.
            Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s], with B the batch size,
                K the number of codebooks, T_s the token timesteps.
            Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
""" if self._cached_batch_loader is None or self.current_stage != "train": audio, video_list, infos = batch assert isinstance(video_list,list) assert len(video_list)==2 [local_video, global_video]=video_list local_video = local_video.to(self.device) global_video = global_video.to(self.device) local_video_tokens = None global_video_tokens = None audio = audio.to(self.device) audio_tokens = None assert audio.size(0) == len(infos), ( f"Mismatch between number of items in audio batch ({audio.size(0)})", f" and in metadata ({len(infos)})" ) else: print(f'----------musicgen.py---line264-------') audio = None # In that case the batch will be a tuple coming from the _cached_batch_writer bit below. infos, = batch # type: ignore assert all([isinstance(info, AudioInfo) for info in infos]) assert all([info.audio_tokens is not None for info in infos]) # type: ignore audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device) # type: ignore audio_tokens = audio_tokens.long() for info in infos: if isinstance(info, MusicInfo): # Careful here, if you want to use this condition_wav (e.b. chroma conditioning), # then you must be using the chroma cache! otherwise the code will try # to use this segment and fail (by that I mean you will see NaN everywhere). info.self_wav = WavCondition( torch.full([1, info.channels, info.total_frames], float('NaN')), length=torch.tensor([info.n_frames]), sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time]) dataset = get_dataset_from_loader(self.dataloaders['original_train']) assert isinstance(dataset, MusicDataset), type(dataset) if dataset.paraphraser is not None and info.description is not None: # Hackingly reapplying paraphraser when using cache. info.description = dataset.paraphraser.sample_paraphrase( info.meta.path, info.description) # Now we should be synchronization free. if self.device == "cuda" and check_synchronization_points: torch.cuda.set_sync_debug_mode("warn") if audio_tokens is None: with torch.no_grad(): audio_tokens, scale = self.compression_model.encode(audio) assert scale is None, "Scaled compression model not supported with LM." # create a padding mask to hold valid vs invalid positions padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device) # replace encodec tokens from padded audio with special_token_id if self.cfg.tokens.padding_with_special_token: audio_tokens = audio_tokens.clone() padding_mask = padding_mask.clone() token_sample_rate = self.compression_model.frame_rate B, K, T_s = audio_tokens.shape for i in range(B): n_samples = infos[i].n_frames audio_sample_rate = infos[i].sample_rate # take the last token generated from actual audio frames (non-padded audio) valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate) audio_tokens[i, :, valid_tokens:] = self.model.special_token_id padding_mask[i, :, valid_tokens:] = 0 if self.device == "cuda" and check_synchronization_points: torch.cuda.set_sync_debug_mode("default") if self._cached_batch_writer is not None and self.current_stage == 'train': assert self._cached_batch_loader is None assert audio_tokens is not None for info, one_audio_tokens in zip(infos, audio_tokens): assert isinstance(info, AudioInfo) if isinstance(info, MusicInfo): assert not info.joint_embed, "joint_embed and cache not supported yet." 
                    info.self_wav = None
                assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
                info.audio_tokens = one_audio_tokens.short().cpu()
            self._cached_batch_writer.save(infos)

        assert isinstance(video_list, list)
        assert len(video_list) == 2
        return [local_video, global_video], audio_tokens, padding_mask

    def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
        """Perform one training or valid step on a given batch."""
        check_synchronization_points = idx == 1 and self.device == 'cuda'

        video_list, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
            batch, check_synchronization_points)
        assert isinstance(video_list, list)
        assert len(video_list) == 2
        local_video, global_video = video_list

        self.deadlock_detect.update('tokens_and_conditions')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('warn')

        with self.autocast:
            assert len(video_list) == 2
            model_output = self.model.compute_predictions(audio_tokens, [], [local_video, global_video])  # type: ignore
            logits = model_output.logits
            mask = padding_mask & model_output.mask
            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
            loss = ce
        self.deadlock_detect.update('loss')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('default')

        if self.is_training:
            metrics['lr'] = self.optimizer.param_groups[0]['lr']
            if self.scaler is not None:
                loss = self.scaler.scale(loss)
            self.deadlock_detect.update('scale')
            if self.cfg.fsdp.use:
                loss.backward()
                flashy.distrib.average_tensors(self.model.buffers())
            elif self.cfg.optim.eager_sync:
                with flashy.distrib.eager_sync_model(self.model):
                    loss.backward()
            else:
                # this should always be slower but can be useful
                # for weird use cases like multiple backwards.
                loss.backward()
                flashy.distrib.sync_model(self.model)
            self.deadlock_detect.update('backward')

            if self.scaler is not None:
                self.scaler.unscale_(self.optimizer)
            if self.cfg.optim.max_norm:
                if self.cfg.fsdp.use:
                    metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
                else:
                    metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.cfg.optim.max_norm
                    )
            if self.scaler is None:
                self.optimizer.step()
            else:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.deadlock_detect.update('optim')
            if self.scaler is not None:
                scale = self.scaler.get_scale()
                metrics['grad_scale'] = scale
            if not loss.isfinite().all():
                raise RuntimeError("Model probably diverged.")

        metrics['ce'] = ce
        metrics['ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'ce_q{k + 1}'] = ce_q
            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)

        return metrics

    @torch.no_grad()
    def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
                          gen_duration: float, prompt_duration: tp.Optional[float] = None,
                          remove_prompt: bool = False,
                          **generation_params) -> dict:
        """Run generate step on a batch of optional audio tensor and corresponding attributes/video features.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Batch with audio tensor, a list of
                [local_video, global_video] features, and the corresponding metadata.
            gen_duration (float): Target audio duration for the generation.
            prompt_duration (float, optional): Duration for the audio prompt to use for continuation.
            remove_prompt (bool, optional): Whether to remove the prompt from the generated audio.
            generation_params: Additional generation parameters.
        Returns:
            gen_outputs (dict): Generation outputs, consisting of audio, audio tokens from both the generation
                and the prompt along with additional information.
        """
        bench_start = time.time()
        audio, video_list, meta = batch
        assert isinstance(video_list, list)
        assert len(video_list) == 2
        local_video, global_video = video_list
        assert audio.size(0) == len(local_video), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})",
            f" and in video ({len(meta)})"
        )
        assert audio.size(0) == len(global_video), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})",
            f" and in video ({len(meta)})"
        )
        local_video_attributes = local_video
        global_video_attributes = global_video
        # prepare attributes
        # attributes = [x.to_condition_attributes() for x in meta]
        # TODO: Add dropout for chroma?

        # prepare audio prompt
        if prompt_duration is None:
            prompt_audio = None
        else:
            assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
            prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
            prompt_audio = audio[..., :prompt_audio_frames]

        # get audio tokens from compression model
        assert isinstance(video_list, list)
        if prompt_audio is None or prompt_audio.nelement() == 0:
            assert len(video_list) == 2
            num_samples = len(local_video_attributes)
            prompt_tokens = None
        else:
            num_samples = None
            prompt_audio = prompt_audio.to(self.device)
            prompt_tokens, scale = self.compression_model.encode(prompt_audio)
            assert scale is None, "Compression model in MusicGen should not require rescaling."

        # generate by sampling from the LM
        with self.autocast:
            total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
            assert len(video_list) == 2
            gen_tokens = self.model.generate(
                prompt_tokens, [local_video_attributes, global_video_attributes],
                max_gen_len=total_gen_len, num_samples=num_samples, **self.generation_params)

        # generate audio from tokens
        assert gen_tokens.dim() == 3
        gen_audio = self.compression_model.decode(gen_tokens, None)

        bench_end = time.time()
        gen_outputs = {
            'rtf': (bench_end - bench_start) / gen_duration,
            'ref_audio': audio,
            'gen_audio': gen_audio,
            'gen_tokens': gen_tokens,
            'prompt_audio': prompt_audio,
            'prompt_tokens': prompt_tokens,
        }
        return gen_outputs

    def generate_audio(self) -> dict:
        """Audio generation stage."""
        generate_stage_name = f'{self.current_stage}'
        sample_manager = SampleManager(self.xp)
        self.logger.info(f"Generating samples in {sample_manager.base_folder}")
        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        dataset = get_dataset_from_loader(loader)
        dataset_duration = dataset.segment_duration
        assert dataset_duration is not None
        assert isinstance(dataset, AudioDataset)
        target_duration = self.cfg.generate.lm.gen_duration
        prompt_duration = self.cfg.generate.lm.prompt_duration
        if target_duration is None:
            target_duration = dataset_duration
        if prompt_duration is None:
            prompt_duration = dataset_duration / 4
        assert prompt_duration < dataset_duration, (
            f"Specified prompt duration ({prompt_duration}s) is longer",
            f" than reference audio duration ({dataset_duration}s)"
        )

        def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
            hydrated_conditions = []
            for sample in [x.to_condition_attributes() for x in meta]:
                cond_dict = {}
                for cond_type in sample.__annotations__.keys():
                    for cond_key, cond_val in getattr(sample, cond_type).items():
                        if cond_key not in self.model.condition_provider.conditioners.keys():
                            continue
                        if is_jsonable(cond_val):
                            cond_dict[cond_key] = cond_val
                        elif isinstance(cond_val, WavCondition):
                            cond_dict[cond_key] = cond_val.path
                        elif isinstance(cond_val, JointEmbedCondition):
                            cond_dict[cond_key] = cond_val.text  # only support text at inference for now
                        else:
                            # if we reached this point, it is not clear how to log the condition
                            # so we just log the type.
                            cond_dict[cond_key] = str(type(cond_val))
                            continue
                hydrated_conditions.append(cond_dict)
            return hydrated_conditions

        metrics: dict = {}
        average = flashy.averager()
        for batch in lp:
            audio, video_list, meta = batch
            assert isinstance(video_list, list)
            assert len(video_list) == 2
            local_video = video_list[0]
            global_video = video_list[1]
            hydrated_conditions = meta[0].meta.video_path
            sample_generation_params = {
                **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
                **self.generation_params
            }
            if self.cfg.generate.lm.unprompted_samples:
                if self.cfg.generate.lm.gen_gt_samples:
                    # get the ground truth instead of generation
                    self.logger.warn(
                        "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
                    gen_unprompted_audio = audio
                    rtf = 1.
                else:
                    gen_unprompted_outputs = self.run_generate_step(
                        batch, gen_duration=target_duration, prompt_duration=None,
                        **self.generation_params)
                    gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
                    rtf = gen_unprompted_outputs['rtf']
                sample_manager.add_samples(
                    gen_unprompted_audio, self.epoch, hydrated_conditions,
                    ground_truth_wavs=audio, generation_args=sample_generation_params)

            if self.cfg.generate.lm.prompted_samples:
                gen_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration, prompt_duration=prompt_duration,
                    **self.generation_params)
                gen_audio = gen_outputs['gen_audio'].cpu()
                prompt_audio = gen_outputs['prompt_audio'].cpu()
                sample_manager.add_samples(
                    gen_audio, self.epoch, hydrated_conditions, prompt_wavs=prompt_audio,
                    ground_truth_wavs=audio, generation_args=sample_generation_params)

            metrics['rtf'] = rtf
            metrics = average(metrics)

        flashy.distrib.barrier()
        return metrics

    def generate(self) -> dict:
        """Generate stage."""
        self.model.eval()
        with torch.no_grad():
            return self.generate_audio()

    def run_epoch(self):
        if self.cfg.cache.write:
            if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard:
                return
        super().run_epoch()

    def train(self):
        """Train stage."""
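        # Keep the cache writer/loader (or the raw dataset) aligned with the current
        # epoch before delegating to the base training loop.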
""" if self._cached_batch_writer is not None: self._cached_batch_writer.start_epoch(self.epoch) if self._cached_batch_loader is None: dataset = get_dataset_from_loader(self.dataloaders['train']) assert isinstance(dataset, AudioDataset) dataset.current_epoch = self.epoch else: self._cached_batch_loader.start_epoch(self.epoch) return super().train() def evaluate_audio_generation(self) -> dict: """Evaluate audio generation with off-the-shelf metrics.""" evaluate_stage_name = f'{self.current_stage}_generation' # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None should_run_eval = False eval_chroma_wavs: tp.Optional[torch.Tensor] = None if self.cfg.evaluate.metrics.fad: fad = builders.get_fad(self.cfg.metrics.fad).to(self.device) should_run_eval = True if self.cfg.evaluate.metrics.kld: kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device) should_run_eval = True if self.cfg.evaluate.metrics.text_consistency: text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device) should_run_eval = True if self.cfg.evaluate.metrics.chroma_cosine: chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device) # if we have predefind wavs for chroma we should purge them for computing the cosine metric has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \ self.model.condition_provider.conditioners['self_wav'].has_eval_wavs() if has_predefined_eval_chromas: warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! 
" 'Resetting eval chromas to None for evaluation.') eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs # type: ignore self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None) # type: ignore should_run_eval = True def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor: audio_tokens, scale = self.compression_model.encode(audio.to(self.device)) compressed_audio = self.compression_model.decode(audio_tokens, scale) return compressed_audio[..., :audio.shape[-1]] metrics: dict = {} if should_run_eval: loader = self.dataloaders['evaluate'] updates = len(loader) lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates) average = flashy.averager() dataset = get_dataset_from_loader(loader) assert isinstance(dataset, AudioDataset) self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples") for idx, batch in enumerate(lp): audio, meta = batch assert all([self.cfg.sample_rate == m.sample_rate for m in meta]) target_duration = audio.shape[-1] / self.cfg.sample_rate if self.cfg.evaluate.fixed_generation_duration: target_duration = self.cfg.evaluate.fixed_generation_duration gen_outputs = self.run_generate_step( batch, gen_duration=target_duration, **self.generation_params ) y_pred = gen_outputs['gen_audio'].detach() y_pred = y_pred[..., :audio.shape[-1]] normalize_kwargs = dict(self.cfg.generate.audio) normalize_kwargs.pop('format', None) y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu() y = audio.cpu() # should already be on CPU but just in case sizes = torch.tensor([m.n_frames for m in meta]) # actual sizes without padding sample_rates = torch.tensor([m.sample_rate for m in meta]) # sample rates for audio samples audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta] if fad is not None: if self.cfg.metrics.fad.use_gt: y_pred = get_compressed_audio(y).cpu() fad.update(y_pred, y, sizes, sample_rates, audio_stems) if kldiv is not None: if self.cfg.metrics.kld.use_gt: y_pred = get_compressed_audio(y).cpu() kldiv.update(y_pred, y, sizes, sample_rates) if text_consistency is not None: texts = [m.description for m in meta] if self.cfg.metrics.text_consistency.use_gt: y_pred = y text_consistency.update(y_pred, texts, sizes, sample_rates) if chroma_cosine is not None: if self.cfg.metrics.chroma_cosine.use_gt: y_pred = get_compressed_audio(y).cpu() chroma_cosine.update(y_pred, y, sizes, sample_rates) # restore chroma conditioner's eval chroma wavs if eval_chroma_wavs is not None: self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs) flashy.distrib.barrier() if fad is not None: metrics['fad'] = fad.compute() if kldiv is not None: kld_metrics = kldiv.compute() metrics.update(kld_metrics) if text_consistency is not None: metrics['text_consistency'] = text_consistency.compute() if chroma_cosine is not None: metrics['chroma_cosine'] = chroma_cosine.compute() metrics = average(metrics) metrics = flashy.distrib.average_metrics(metrics, len(loader)) return metrics def evaluate(self) -> dict: """Evaluate stage.""" self.model.eval() with torch.no_grad(): metrics: dict = {} if self.cfg.evaluate.metrics.base: metrics.update(self.common_train_valid('evaluate')) gen_metrics = self.evaluate_audio_generation() return {**metrics, **gen_metrics}