# Zeyue7 / VidMuse_CVPR (revision ffa5ac7)
# Modified from Audiocraft (https://github.com/facebookresearch/audiocraft)
from pathlib import Path
import time
import typing as tp
import warnings
import flashy
import math
import omegaconf
import torch
from torch.nn import functional as F
from . import base, builders
from .compression import CompressionSolver
from .. import metrics as eval_metrics
from .. import models
from ..data.audio_dataset import AudioDataset
from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo
from ..data.audio_utils import normalize_audio
from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
from ..utils.cache import CachedBatchWriter, CachedBatchLoader
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once
class MusicGenSolver(base.StandardSolver):
    """Solver for MusicGen-style LM training, adapted here to video conditioning.

    Used in: https://arxiv.org/abs/2306.05284 (MusicGen). In this variant every batch
    also carries a ``[local_video, global_video]`` pair of tensors. The pair is handed
    to the LM next to the audio tokens, and the attribute-condition list is left empty.
    """
    # Dataset flavour requested from `builders.get_audio_datasets`.
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC
def __init__(self, cfg: omegaconf.DictConfig):
    """Build the solver and optionally attach a compressed-token cache.

    Args:
        cfg (omegaconf.DictConfig): Full solver configuration. When ``cfg.cache.path`` is
            set, the solver either writes audio tokens to that cache (``cfg.cache.write``)
            or reads them back through a ``CachedBatchLoader`` that replaces the train loader.
    """
    super().__init__(cfg)
    lm_cfg = self.cfg.generate.lm
    # Sampling knobs forwarded to the LM at generation time.
    self.generation_params = {
        'use_sampling': lm_cfg.use_sampling,
        'temp': lm_cfg.temp,
        'top_k': lm_cfg.top_k,
        'top_p': lm_cfg.top_p,
    }
    self._best_metric_name: tp.Optional[str] = 'ce'
    self._cached_batch_writer = None
    self._cached_batch_loader = None

    if not cfg.cache.path:
        return
    cache_dir = Path(cfg.cache.path)

    if not cfg.cache.write:
        # Read mode: serve training batches from the cache, but keep the original
        # train loader reachable under 'original_train'.
        self._cached_batch_loader = CachedBatchLoader(
            cache_dir, cfg.dataset.batch_size, cfg.dataset.num_workers,
            min_length=self.cfg.optim.updates_per_epoch or 1)
        self.dataloaders['original_train'] = self.dataloaders['train']
        self.dataloaders['train'] = self._cached_batch_loader  # type: ignore
        return

    # Write mode.
    self._cached_batch_writer = CachedBatchWriter(cache_dir)
    if self.cfg.cache.write_num_shards:
        self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
        self._best_metric_name = None
@staticmethod
def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                             device: tp.Optional[str] = None, autocast: bool = True,
                             batch_size: tp.Optional[int] = None,
                             override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                             **kwargs):
    """Load a solver from an experiment signature, ready for single-GPU evaluation.

    Convenience wrapper around ``audiocraft.train.get_solver_from_sig``. It disables EMA
    and FSDP, loads the best state, skips the optimizer and EMA states, and puts the
    model in eval mode.

    Args:
        sig (str): Signature of the experiment to load.
        dtype (str or None): Optional dtype override, as a string, e.g. 'float16'.
        device (str or None): Optional device override, as a string, e.g. 'cuda'.
        autocast (bool): Value written to ``cfg.autocast``.
        batch_size (int or None): Optional override for ``cfg.dataset.batch_size``.
        override_cfg (dict or omegaconf.DictConfig or None): Extra config overrides. They are
            merged *before* the overrides computed here, so EMA/autocast/dtype/device/batch_size
            set by this function take precedence on conflicting keys.
        **kwargs: Forwarded to ``train.get_solver_from_sig``.
    Returns:
        The loaded solver, with ``solver.model`` in eval mode.
    """
    from audiocraft import train
    our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
    our_override_cfg['autocast'] = autocast
    if dtype is not None:
        our_override_cfg['dtype'] = dtype
    if device is not None:
        our_override_cfg['device'] = device
    if batch_size is not None:
        our_override_cfg['dataset'] = {'batch_size': batch_size}
    if override_cfg is None:
        override_cfg = {}
    # OmegaConf.merge: later configs win, so our overrides are applied last.
    override_cfg = omegaconf.OmegaConf.merge(
        omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
    solver = train.get_solver_from_sig(
        sig, override_cfg=override_cfg,
        load_best=True, disable_fsdp=True,
        ignore_state_keys=['optimizer', 'ema'], **kwargs)
    solver.model.eval()
    return solver
def get_formatter(self, stage_name: str) -> flashy.Formatter:
    """Return the metric formatter used for every stage (``stage_name`` is ignored).

    Per-codebook metrics matching ``ce_q*`` / ``ppl_q*`` are passed as ``exclude_keys``.
    """
    number_formats = {
        'lr': '.2E',
        'ce': '.3f',
        'ppl': '.3f',
        'grad_norm': '.3E',
    }
    per_codebook_patterns = ['ce_q*', 'ppl_q*']
    return flashy.Formatter(number_formats, exclude_keys=per_codebook_patterns)
@property
def best_metric_name(self) -> tp.Optional[str]:
    """Return the configured best-metric name.

    This is ``'ce'`` unless multi-shard cache writing reset it to ``None`` in ``__init__``.
    """
    return getattr(self, '_best_metric_name')
def build_model(self) -> None:
    """Instantiate the compression model, the LM, optimizer/scheduler and optional grad scaler.

    Raises:
        AssertionError: if the compression model's sample rate, codebook cardinality or
            number of codebooks disagree with the solver/LM configuration, or if autocast
            is combined with FSDP.
    """
    # we can potentially not use all quantizers with which the EnCodec model was trained
    # (e.g. we trained the model with quantizers dropout)
    self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
        self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
    assert self.compression_model.sample_rate == self.cfg.sample_rate, (
        f"Compression model sample rate is {self.compression_model.sample_rate} but "
        f"Solver sample rate is {self.cfg.sample_rate}."
    )
    # ensure we have matching configuration between LM and compression model.
    # Message pieces are concatenated implicitly (no commas), otherwise the
    # assertion message would be rendered as a tuple.
    assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
        "Cardinalities of the LM and compression model don't match: "
        f"LM cardinality is {self.cfg.transformer_lm.card} vs "
        f"compression model cardinality is {self.compression_model.cardinality}"
    )
    assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
        "Numbers of codebooks of the LM and compression models don't match: "
        f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs "
        f"compression model number of codebooks is {self.compression_model.num_codebooks}"
    )
    self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
                     self.compression_model.num_codebooks, self.compression_model.cardinality,
                     self.compression_model.frame_rate)
    # instantiate LM model
    self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)
    if self.cfg.fsdp.use:
        assert not self.cfg.autocast, "Cannot use autocast with fsdp"
        self.model = self.wrap_with_fsdp(self.model)
    self.register_ema('model')
    # initialize optimization
    self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
    self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
    self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
    self.register_best_state('model')
    self.autocast_dtype = {
        'float16': torch.float16, 'bfloat16': torch.bfloat16
    }[self.cfg.autocast_dtype]
    # A gradient scaler is only needed when computing in float16.
    self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
    if self.cfg.fsdp.use:
        need_scaler = self.cfg.fsdp.param_dtype == 'float16'
    else:
        need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
    if need_scaler:
        if self.cfg.fsdp.use:
            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
            self.scaler = ShardedGradScaler()  # type: ignore
        else:
            self.scaler = torch.cuda.amp.GradScaler()
        self.register_stateful('scaler')
def build_dataloaders(self) -> None:
    """Create the per-stage audio dataloaders from ``self.cfg``, using ``DATASET_TYPE``."""
    dataset_type = self.DATASET_TYPE
    self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=dataset_type)
def show(self) -> None:
    """Log a summary of the compression model, then of the LM."""
    sections = (
        ("Compression model:", 'compression_model'),
        ("LM model:", 'model'),
    )
    for title, attr_name in sections:
        self.logger.info(title)
        self.log_model_summary(getattr(self, attr_name))
def load_state_dict(self, state: dict) -> None:
    """Load solver state, folding a top-level ``condition_provider`` entry into the model state.

    If ``state`` contains ``'condition_provider'``, that entry is popped (``state`` is
    mutated in place). Each of its keys is re-inserted into ``state['model']`` with a
    ``condition_provider.`` prefix before delegating to the parent implementation.

    Raises:
        AssertionError: if a prefixed key already exists in ``state['model']``.
    """
    if 'condition_provider' in state:
        model_state = state['model']
        provider_state = state.pop('condition_provider')
        for name, tensor in provider_state.items():
            full_name = 'condition_provider.' + name
            assert full_name not in model_state
            model_state[full_name] = tensor
    super().load_state_dict(state)
def load_from_pretrained(self, name: str):
    """Fetch a pretrained LM checkpoint and wrap it as a solver state dict.

    Args:
        name (str): Checkpoint identifier passed to ``models.loaders.load_lm_model_ckpt``.
    Returns:
        dict: ``{'best_state': {'model': checkpoint['best_state']}}``.
    """
    # TODO: support native HF versions of MusicGen.
    checkpoint = models.loaders.load_lm_model_ckpt(name)
    return {'best_state': {'model': checkpoint['best_state']}}
def _compute_cross_entropy(
    self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
    """Cross entropy between multi-codebook targets and the model's logits.

    The loss is computed separately for every codebook, using only the timesteps where
    ``mask`` is set, and then averaged over codebooks.

    Args:
        logits (torch.Tensor): Model's logits of shape [B, K, T, card].
        targets (torch.Tensor): Target codes, of shape [B, K, T].
        mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
    Returns:
        ce (torch.Tensor): Cross entropy averaged over the codebooks.
        ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
    """
    _, num_codebooks, _ = targets.shape  # also enforces 3-D targets
    assert logits.shape[:-1] == targets.shape
    assert mask.shape == targets.shape
    card = logits.size(-1)

    total = torch.zeros([], device=targets.device)
    per_codebook: tp.List[torch.Tensor] = []
    for k in range(num_codebooks):
        # Flatten batch and time, then keep only valid positions.
        flat_mask = mask[:, k].reshape(-1)
        codebook_logits = logits[:, k].reshape(-1, card)[flat_mask]
        codebook_targets = targets[:, k].reshape(-1)[flat_mask]
        codebook_ce = F.cross_entropy(codebook_logits, codebook_targets)
        total += codebook_ce
        per_codebook.append(codebook_ce.detach())
    return total / num_codebooks, per_codebook
def _prepare_tokens_and_attributes(
    self, batch: tp.Tuple[torch.Tensor, tp.List[torch.Tensor], tp.List[SegmentWithAttributes]],
    check_synchronization_points: bool = False
) -> tp.Tuple[tp.List[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Prepare an input batch for language model training.

    Args:
        batch (tuple): ``(audio, [local_video, global_video], infos)`` where ``audio`` has
            shape [B, C, T] and ``infos`` holds the B metadata items (SegmentWithAttributes).
        check_synchronization_points (bool): Whether to put CUDA in sync-debug "warn" mode
            while tokenizing, to surface synchronization points slowing down training.
    Returns:
        Video list (list[torch.Tensor]): ``[local_video, global_video]`` moved to ``self.device``.
        Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s],
            with B the batch size, K the number of codebooks, T_s the token timesteps.
        Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
    Raises:
        NotImplementedError: When training reads batches from the token cache. Cached batches
            only hold metadata with pre-computed audio tokens, not the video tensors the LM needs.
    """
    if self._cached_batch_loader is not None and self.current_stage == "train":
        # Cached batches are 1-tuples of infos (only audio tokens are written by the cache
        # writer below), so there are no video tensors to unpack. Fail loudly instead of
        # crashing later on unbound video variables.
        raise NotImplementedError(
            "Training from a cached batch loader is not supported: cached batches do not "
            "contain the local/global video tensors required by the model.")

    audio, video_list, infos = batch
    assert isinstance(video_list, list)
    assert len(video_list) == 2
    local_video, global_video = video_list
    local_video = local_video.to(self.device)
    global_video = global_video.to(self.device)
    audio = audio.to(self.device)
    assert audio.size(0) == len(infos), (
        f"Mismatch between number of items in audio batch ({audio.size(0)})"
        f" and in metadata ({len(infos)})"
    )

    debug_sync = self.device == "cuda" and check_synchronization_points
    # Now we should be synchronization free.
    if debug_sync:
        torch.cuda.set_sync_debug_mode("warn")
    try:
        with torch.no_grad():
            audio_tokens, scale = self.compression_model.encode(audio)
            assert scale is None, "Scaled compression model not supported with LM."
        # create a padding mask to hold valid vs invalid positions
        padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
        # replace encodec tokens from padded audio with special_token_id
        if self.cfg.tokens.padding_with_special_token:
            audio_tokens = audio_tokens.clone()
            padding_mask = padding_mask.clone()
            token_sample_rate = self.compression_model.frame_rate
            for i, info in enumerate(infos):
                # take the last token generated from actual audio frames (non-padded audio)
                valid_tokens = math.floor(float(info.n_frames) / info.sample_rate * token_sample_rate)
                audio_tokens[i, :, valid_tokens:] = self.model.special_token_id
                padding_mask[i, :, valid_tokens:] = 0
    finally:
        # Always restore the default mode, even if tokenization failed.
        if debug_sync:
            torch.cuda.set_sync_debug_mode("default")

    if self._cached_batch_writer is not None and self.current_stage == 'train':
        # NOTE: only the audio tokens are cached; the video tensors are not written.
        assert self._cached_batch_loader is None
        for info, one_audio_tokens in zip(infos, audio_tokens):
            assert isinstance(info, AudioInfo)
            if isinstance(info, MusicInfo):
                assert not info.joint_embed, "joint_embed and cache not supported yet."
                info.self_wav = None
            # Tokens are stored as int16, so they must fit below 2**15.
            assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
            info.audio_tokens = one_audio_tokens.short().cpu()
        self._cached_batch_writer.save(infos)

    return [local_video, global_video], audio_tokens, padding_mask
def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
    """Run one train/valid step: tokenize the batch, score it with the LM and, if training, update weights.

    Args:
        idx (int): Step index within the stage; CUDA sync-point checking is enabled when
            ``idx == 1`` on a CUDA device.
        batch: Raw dataloader batch, forwarded to ``_prepare_tokens_and_attributes``.
        metrics (dict): Metrics dict, updated in place and returned.
    Returns:
        dict: ``metrics`` with ``ce``, ``ppl`` and per-codebook ``ce_q{k}`` / ``ppl_q{k}``.
        While training it also holds ``lr``, plus ``grad_norm`` (when clipping) and
        ``grad_scale`` (when a scaler is used).
    Raises:
        RuntimeError: if the (possibly scaled) training loss is not finite.
    """
    sync_check = idx == 1 and self.device == 'cuda'
    videos, tokens, pad_mask = self._prepare_tokens_and_attributes(batch, sync_check)
    assert isinstance(videos, list)
    assert len(videos) == 2
    local_video, global_video = videos
    self.deadlock_detect.update('tokens_and_conditions')

    if sync_check:
        torch.cuda.set_sync_debug_mode('warn')
    with self.autocast:
        assert len(videos) == 2
        # The attribute-condition slot is left empty; the video tensors go in the third argument.
        output = self.model.compute_predictions(tokens, [], [local_video, global_video])  # type: ignore
        logits = output.logits
        valid = pad_mask & output.mask
        ce, ce_per_codebook = self._compute_cross_entropy(logits, tokens, valid)
        loss = ce
    self.deadlock_detect.update('loss')
    if sync_check:
        torch.cuda.set_sync_debug_mode('default')

    if self.is_training:
        metrics['lr'] = self.optimizer.param_groups[0]['lr']
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        self.deadlock_detect.update('scale')

        # Backward pass + gradient synchronization, per distribution strategy.
        if self.cfg.fsdp.use:
            loss.backward()
            flashy.distrib.average_tensors(self.model.buffers())
        elif self.cfg.optim.eager_sync:
            with flashy.distrib.eager_sync_model(self.model):
                loss.backward()
        else:
            # Slower path, kept for unusual use cases such as multiple backwards.
            loss.backward()
            flashy.distrib.sync_model(self.model)
        self.deadlock_detect.update('backward')

        if self.scaler is not None:
            # Gradients must be unscaled before they are clipped.
            self.scaler.unscale_(self.optimizer)
        max_norm = self.cfg.optim.max_norm
        if max_norm:
            if self.cfg.fsdp.use:
                grad_norm = self.model.clip_grad_norm_(max_norm)  # type: ignore
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)
            metrics['grad_norm'] = grad_norm

        if self.scaler is not None:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        if self.lr_scheduler:
            self.lr_scheduler.step()
        self.optimizer.zero_grad()
        self.deadlock_detect.update('optim')

        if self.scaler is not None:
            metrics['grad_scale'] = self.scaler.get_scale()
        if not loss.isfinite().all():
            raise RuntimeError("Model probably diverged.")

    metrics['ce'] = ce
    metrics['ppl'] = torch.exp(ce)
    for q, ce_q in enumerate(ce_per_codebook, start=1):
        metrics[f'ce_q{q}'] = ce_q
        metrics[f'ppl_q{q}'] = torch.exp(ce_q)
    return metrics
@torch.no_grad()
def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[torch.Tensor], tp.List[SegmentWithAttributes]],
                      gen_duration: float, prompt_duration: tp.Optional[float] = None,
                      remove_prompt: bool = False,
                      **generation_params) -> dict:
    """Generate audio for a batch, conditioned on its video tensors and optionally an audio prompt.

    Args:
        batch (tuple): ``(audio, [local_video, global_video], meta)``. ``audio`` is the reference
            audio of shape [B, C, T], and both video tensors must have B items along dim 0.
        gen_duration (float): Target audio duration for the generation, in seconds.
        prompt_duration (float, optional): Duration of the audio prompt taken from the start of
            ``audio`` for continuation. ``None`` (or an empty prompt) means unprompted generation.
        remove_prompt (bool, optional): Currently unused; kept for interface compatibility.
        generation_params: Currently ignored; ``self.generation_params`` is what reaches the LM.
    Returns:
        gen_outputs (dict): ``rtf`` (wall-clock time / gen_duration), ``ref_audio``,
            ``gen_audio``, ``gen_tokens``, ``prompt_audio`` and ``prompt_tokens``.
    """
    bench_start = time.time()
    audio, video_list, _meta = batch
    assert isinstance(video_list, list)
    assert len(video_list) == 2
    local_video, global_video = video_list
    assert audio.size(0) == len(local_video), (
        f"Mismatch between number of items in audio batch ({audio.size(0)})"
        f" and in local video batch ({len(local_video)})"
    )
    assert audio.size(0) == len(global_video), (
        f"Mismatch between number of items in audio batch ({audio.size(0)})"
        f" and in global video batch ({len(global_video)})"
    )

    # prepare audio prompt
    if prompt_duration is None:
        prompt_audio = None
    else:
        assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
        prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
        prompt_audio = audio[..., :prompt_audio_frames]

    # get audio tokens from compression model
    if prompt_audio is None or prompt_audio.nelement() == 0:
        # Unprompted: the LM needs to be told how many samples to generate.
        num_samples = len(local_video)
        prompt_tokens = None
    else:
        num_samples = None
        prompt_audio = prompt_audio.to(self.device)
        prompt_tokens, scale = self.compression_model.encode(prompt_audio)
        assert scale is None, "Compression model in MusicGen should not require rescaling."

    # generate by sampling from the LM
    with self.autocast:
        total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
        gen_tokens = self.model.generate(
            prompt_tokens, [local_video, global_video], max_gen_len=total_gen_len,
            num_samples=num_samples, **self.generation_params)

    # generate audio from tokens
    assert gen_tokens.dim() == 3
    gen_audio = self.compression_model.decode(gen_tokens, None)

    bench_end = time.time()
    return {
        'rtf': (bench_end - bench_start) / gen_duration,
        'ref_audio': audio,
        'gen_audio': gen_audio,
        'gen_tokens': gen_tokens,
        'prompt_audio': prompt_audio,
        'prompt_tokens': prompt_tokens,
    }
def generate_audio(self) -> dict:
    """Audio generation stage.

    For every batch of the ``generate`` dataloader, produce unprompted and/or prompted
    generations (as selected by ``cfg.generate.lm``) and hand them to a ``SampleManager``.

    Returns:
        dict: Running-averaged metrics. It holds ``rtf`` (real-time factor) when at least one
        generation mode ran: the unprompted value takes priority, else the prompted one.
    """
    generate_stage_name = f'{self.current_stage}'
    sample_manager = SampleManager(self.xp)
    self.logger.info(f"Generating samples in {sample_manager.base_folder}")
    loader = self.dataloaders['generate']
    updates = len(loader)
    lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

    dataset = get_dataset_from_loader(loader)
    assert isinstance(dataset, AudioDataset)
    dataset_duration = dataset.segment_duration
    assert dataset_duration is not None
    target_duration = self.cfg.generate.lm.gen_duration
    prompt_duration = self.cfg.generate.lm.prompt_duration
    if target_duration is None:
        target_duration = dataset_duration
    if prompt_duration is None:
        prompt_duration = dataset_duration / 4
    assert prompt_duration < dataset_duration, (
        f"Specified prompt duration ({prompt_duration}s) is longer"
        f" than reference audio duration ({dataset_duration}s)"
    )

    metrics: dict = {}
    average = flashy.averager()
    for batch in lp:
        audio, video_list, meta = batch
        assert isinstance(video_list, list)
        assert len(video_list) == 2
        # NOTE(review): only the first item's video path is passed as the conditioning for the
        # whole batch; confirm this matches what SampleManager.add_samples expects.
        hydrated_conditions = meta[0].meta.video_path
        sample_generation_params = {
            **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
            **self.generation_params
        }
        # Stays None when neither generation mode is enabled.
        rtf: tp.Optional[float] = None
        if self.cfg.generate.lm.unprompted_samples:
            if self.cfg.generate.lm.gen_gt_samples:
                # get the ground truth instead of generation
                self.logger.warning(
                    "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
                gen_unprompted_audio = audio
                rtf = 1.
            else:
                gen_unprompted_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration, prompt_duration=None,
                    **self.generation_params)
                gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
                rtf = gen_unprompted_outputs['rtf']
            sample_manager.add_samples(
                gen_unprompted_audio, self.epoch, hydrated_conditions,
                ground_truth_wavs=audio, generation_args=sample_generation_params)

        if self.cfg.generate.lm.prompted_samples:
            gen_outputs = self.run_generate_step(
                batch, gen_duration=target_duration, prompt_duration=prompt_duration,
                **self.generation_params)
            if rtf is None:
                rtf = gen_outputs['rtf']
            gen_audio = gen_outputs['gen_audio'].cpu()
            prompt_audio = gen_outputs['prompt_audio'].cpu()
            sample_manager.add_samples(
                gen_audio, self.epoch, hydrated_conditions,
                prompt_wavs=prompt_audio, ground_truth_wavs=audio,
                generation_args=sample_generation_params)

        if rtf is not None:
            metrics['rtf'] = rtf
        metrics = average(metrics)

    flashy.distrib.barrier()
    return metrics
def generate(self) -> dict:
    """Generate stage: switch the LM to eval mode and run ``generate_audio`` without autograd."""
    self.model.eval()
    no_grad = torch.no_grad()
    with no_grad:
        generated_metrics = self.generate_audio()
    return generated_metrics
def run_epoch(self):
    """Run one epoch, skipping epochs that belong to another cache-writing shard.

    With ``cfg.cache.write`` set and a truthy ``cfg.cache.write_num_shards``, epochs are
    assigned to shards round-robin via ``(epoch - 1) % write_num_shards``. Only epochs of
    shard ``cfg.cache.write_shard`` are run. A falsy shard count (0/None) is treated as
    unsharded, consistent with ``__init__``, instead of raising ZeroDivisionError/TypeError.
    """
    cache_cfg = self.cfg.cache
    if cache_cfg.write and cache_cfg.write_num_shards:
        if ((self.epoch - 1) % cache_cfg.write_num_shards) != cache_cfg.write_shard:
            return
    super().run_epoch()
def train(self):
    """Train stage: advance cache/dataset epoch counters, then delegate to the base solver."""
    writer = self._cached_batch_writer
    if writer is not None:
        writer.start_epoch(self.epoch)
    loader = self._cached_batch_loader
    if loader is not None:
        loader.start_epoch(self.epoch)
    else:
        train_dataset = get_dataset_from_loader(self.dataloaders['train'])
        assert isinstance(train_dataset, AudioDataset)
        train_dataset.current_epoch = self.epoch
    return super().train()
def evaluate_audio_generation(self) -> dict:
    """Evaluate audio generation with off-the-shelf metrics.

    Every metric enabled under ``cfg.evaluate.metrics`` (FAD, KLD, text consistency,
    chroma cosine) is updated with generations for each batch of the ``evaluate`` loader.
    A metric's ``use_gt`` flag replaces only *its own* prediction with a reference: the
    compressed ground truth, or the raw ground truth for text consistency.

    Returns:
        dict: Averaged metric values, or an empty dict when no generation metric is enabled.
    """
    evaluate_stage_name = f'{self.current_stage}_generation'
    # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation
    fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
    kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
    text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
    chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
    should_run_eval = False
    eval_chroma_wavs: tp.Optional[torch.Tensor] = None
    if self.cfg.evaluate.metrics.fad:
        fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
        should_run_eval = True
    if self.cfg.evaluate.metrics.kld:
        kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
        should_run_eval = True
    if self.cfg.evaluate.metrics.text_consistency:
        text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
        should_run_eval = True
    if self.cfg.evaluate.metrics.chroma_cosine:
        chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
        # if we have predefind wavs for chroma we should purge them for computing the cosine metric
        has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
            self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
        if has_predefined_eval_chromas:
            warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
                                   'Resetting eval chromas to None for evaluation.')
            eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
            self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
        should_run_eval = True

    def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
        """Round-trip ``audio`` through the compression model, trimmed to the input length."""
        audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
        compressed_audio = self.compression_model.decode(audio_tokens, scale)
        return compressed_audio[..., :audio.shape[-1]]

    metrics: dict = {}
    if not should_run_eval:
        return metrics

    loader = self.dataloaders['evaluate']
    updates = len(loader)
    lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
    average = flashy.averager()
    dataset = get_dataset_from_loader(loader)
    assert isinstance(dataset, AudioDataset)
    self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")

    for batch in lp:
        # Batches are (audio, [local_video, global_video], meta), the layout run_generate_step expects.
        audio, _video_list, meta = batch
        assert all([self.cfg.sample_rate == m.sample_rate for m in meta])
        target_duration = audio.shape[-1] / self.cfg.sample_rate
        if self.cfg.evaluate.fixed_generation_duration:
            target_duration = self.cfg.evaluate.fixed_generation_duration
        gen_outputs = self.run_generate_step(
            batch, gen_duration=target_duration,
            **self.generation_params
        )
        y_pred = gen_outputs['gen_audio'].detach()
        y_pred = y_pred[..., :audio.shape[-1]]

        normalize_kwargs = dict(self.cfg.generate.audio)
        normalize_kwargs.pop('format', None)
        y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
        y = audio.cpu()  # should already be on CPU but just in case
        sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
        sample_rates = torch.tensor([m.sample_rate for m in meta])  # sample rates for audio samples
        audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]

        # Each metric picks its own prediction so one metric's `use_gt` cannot leak into the next.
        if fad is not None:
            fad_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.fad.use_gt else y_pred
            fad.update(fad_pred, y, sizes, sample_rates, audio_stems)
        if kldiv is not None:
            kld_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.kld.use_gt else y_pred
            kldiv.update(kld_pred, y, sizes, sample_rates)
        if text_consistency is not None:
            texts = [m.description for m in meta]
            text_pred = y if self.cfg.metrics.text_consistency.use_gt else y_pred
            text_consistency.update(text_pred, texts, sizes, sample_rates)
        if chroma_cosine is not None:
            chroma_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.chroma_cosine.use_gt else y_pred
            chroma_cosine.update(chroma_pred, y, sizes, sample_rates)

    # restore chroma conditioner's eval chroma wavs
    if eval_chroma_wavs is not None:
        self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)

    flashy.distrib.barrier()
    if fad is not None:
        metrics['fad'] = fad.compute()
    if kldiv is not None:
        kld_metrics = kldiv.compute()
        metrics.update(kld_metrics)
    if text_consistency is not None:
        metrics['text_consistency'] = text_consistency.compute()
    if chroma_cosine is not None:
        metrics['chroma_cosine'] = chroma_cosine.compute()
    metrics = average(metrics)
    metrics = flashy.distrib.average_metrics(metrics, len(loader))
    return metrics
def evaluate(self) -> dict:
    """Evaluate stage: optional LM loss metrics merged with audio-generation metrics.

    Returns:
        dict: Metrics from ``common_train_valid('evaluate')``, included only when
        ``cfg.evaluate.metrics.base`` is set, updated with the output of
        ``evaluate_audio_generation``. Generation metrics win on key clashes.
    """
    self.model.eval()
    with torch.no_grad():
        include_base = self.cfg.evaluate.metrics.base
        loss_metrics = dict(self.common_train_valid('evaluate')) if include_base else {}
        generation_metrics = self.evaluate_audio_generation()
        return {**loss_metrics, **generation_metrics}