Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import multiprocessing | |
from pathlib import Path | |
import typing as tp | |
import flashy | |
import omegaconf | |
import torch | |
from torch import nn | |
from . import base, builders | |
from .. import models, quantization | |
from ..utils import checkpoint | |
from ..utils.samples.manager import SampleManager | |
from ..utils.utils import get_pool_executor | |
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
class CompressionSolver(base.StandardSolver):
    """Solver for compression task.

    The compression task combines a set of perceptual and objective losses
    to train an EncodecModel (composed of an encoder-decoder and a quantizer)
    to perform high fidelity audio reconstruction.
    """

    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.rng: torch.Generator  # set at each epoch, see `run_epoch`
        self.adv_losses = builders.get_adversarial_losses(self.cfg)
        # Losses with a positive weight are optimized through the balancer (`aux_losses`);
        # the remaining ones are only computed without gradients and reported (`info_losses`).
        self.aux_losses = nn.ModuleDict()
        self.info_losses = nn.ModuleDict()
        assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
        loss_weights = dict()
        for loss_name, weight in self.cfg.losses.items():
            if loss_name in ['adv', 'feat']:
                # 'adv' and 'feat' weights are shared by all adversaries: expand them
                # into one balancer entry per adversarial loss (e.g. 'adv_<name>').
                for adv_name, _ in self.adv_losses.items():
                    loss_weights[f'{loss_name}_{adv_name}'] = weight
            elif weight > 0:
                self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
                loss_weights[loss_name] = weight
            else:
                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
        self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
        self.register_stateful('adv_losses')

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        """Return None: no metric selects the best state for the compression model.

        NOTE(review): the decorator line appears to have been lost from this copy of the
        file. It is restored as a property on the assumption that `base.StandardSolver`
        exposes `best_metric_name` as a property. Confirm against the base class.
        """
        # best model is the last for the compression model
        return None

    def build_model(self):
        """Instantiate model and optimizer."""
        # Model and optimizer
        self.model = models.builders.get_compression_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Instantiate audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        """Show the compression model and employed adversarial loss."""
        self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
        self.log_model_summary(self.model)
        self.logger.info("Adversarial loss:")
        self.log_model_summary(self.adv_losses)
        self.logger.info("Auxiliary losses:")
        self.logger.info(self.aux_losses)
        self.logger.info("Info losses:")
        self.logger.info(self.info_losses)

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch.

        Args:
            idx (int): Index of the step within the stage (not used by this solver).
            batch (torch.Tensor): Input batch. It is fed to the model and also used,
                cloned, as the reconstruction target.
            metrics (dict): Metrics dictionary, updated in place.
        Returns:
            dict: The updated `metrics`.
        """
        x = batch.to(self.device)
        y = x.clone()  # reconstruction target
        qres = self.model(x)
        assert isinstance(qres, quantization.QuantizedResult)
        y_pred = qres.x
        # Log bandwidth in kb/s
        metrics['bandwidth'] = qres.bandwidth.mean()

        if self.is_training:
            d_losses: dict = {}
            # Run the adversaries' `train_adv` step with probability 1 / `adversarial.every`.
            # Drawing from the per-epoch seeded `self.rng` keeps this schedule reproducible.
            if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every:
                for adv_name, adversary in self.adv_losses.items():
                    disc_loss = adversary.train_adv(y_pred, y)
                    d_losses[f'd_{adv_name}'] = disc_loss
                metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values())))
            metrics.update(d_losses)

        balanced_losses: dict = {}
        other_losses: dict = {}

        # penalty from quantization
        if qres.penalty is not None and qres.penalty.requires_grad:
            other_losses['penalty'] = qres.penalty  # penalty term from the quantizer

        # adversarial losses
        for adv_name, adversary in self.adv_losses.items():
            adv_loss, feat_loss = adversary(y_pred, y)
            balanced_losses[f'adv_{adv_name}'] = adv_loss
            balanced_losses[f'feat_{adv_name}'] = feat_loss

        # auxiliary losses
        for loss_name, criterion in self.aux_losses.items():
            loss = criterion(y_pred, y)
            balanced_losses[loss_name] = loss

        # weighted losses
        metrics.update(balanced_losses)
        metrics.update(other_losses)
        metrics.update(qres.metrics)

        if self.is_training:
            # backprop losses that are not handled by balancer
            other_loss = torch.tensor(0., device=self.device)
            if 'penalty' in other_losses:
                other_loss += other_losses['penalty']
            if other_loss.requires_grad:
                # Keep the graph alive: the balancer backpropagates through y_pred again below.
                other_loss.backward(retain_graph=True)
                # Global L2 norm of the gradients produced by the non-balanced losses.
                ratio1 = sum(p.grad.data.norm(p=2).pow(2)
                             for p in self.model.parameters() if p.grad is not None)
                assert isinstance(ratio1, torch.Tensor)
                metrics['ratio1'] = ratio1.sqrt()

            # balancer losses backward, returns effective training loss
            # with effective weights at the current batch.
            metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred)
            # add metrics corresponding to weight ratios
            metrics.update(self.balancer.metrics)
            # Global L2 norm of the gradients after the balancer's backward. Gradients
            # accumulate, so this includes any contribution from `other_loss` above.
            ratio2 = sum(p.grad.data.norm(p=2).pow(2)
                         for p in self.model.parameters() if p.grad is not None)
            assert isinstance(ratio2, torch.Tensor)
            metrics['ratio2'] = ratio2.sqrt()

            # optim
            flashy.distrib.sync_model(self.model)
            if self.cfg.optim.max_norm:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.optim.max_norm
                )
            self.optimizer.step()
            self.optimizer.zero_grad()

        # informative losses only
        info_losses: dict = {}
        with torch.no_grad():
            for loss_name, criterion in self.info_losses.items():
                loss = criterion(y_pred, y)
                info_losses[loss_name] = loss

        metrics.update(info_losses)

        # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups
        adv_values = [loss for loss_name, loss in metrics.items() if loss_name.startswith('adv')]
        if adv_values:
            metrics['adv'] = torch.sum(torch.stack(adv_values))
        feat_values = [loss for loss_name, loss in metrics.items() if loss_name.startswith('feat')]
        if feat_values:
            metrics['feat'] = torch.sum(torch.stack(feat_values))

        return metrics

    def run_epoch(self):
        """Run one epoch after re-seeding `self.rng` with a value derived from the epoch."""
        # reset random seed at the beginning of the epoch
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        # run epoch
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage. Runs audio reconstruction evaluation.

        Model inference runs in this process. Metric computation for each batch is
        submitted to a pool created with the 'spawn' multiprocessing context.

        Returns:
            dict: Metrics averaged over batches, then passed through
                `flashy.distrib.average_metrics`.
        """
        self.model.eval()
        evaluate_stage_name = str(self.current_stage)

        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
        average = flashy.averager()

        pendings = []
        ctx = multiprocessing.get_context('spawn')
        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
            for batch in lp:
                x = batch.to(self.device)
                with torch.no_grad():
                    qres = self.model(x)

                y_pred = qres.x.cpu()
                y = batch.cpu()  # should already be on CPU but just in case
                pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))

            metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
            # Initialised so that an empty loader does not leave `metrics` unbound.
            metrics: dict = {}
            for pending in metrics_lp:
                metrics = pending.result()
                metrics = average(metrics)

        metrics = flashy.distrib.average_metrics(metrics, len(loader))
        return metrics

    def generate(self):
        """Generate stage.

        Reconstruct each batch of the 'generate' loader and store the estimates,
        together with their references, through a `SampleManager`.
        """
        self.model.eval()
        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
        generate_stage_name = str(self.current_stage)

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            reference, _ = batch
            reference = reference.to(self.device)
            with torch.no_grad():
                qres = self.model(reference)
            assert isinstance(qres, quantization.QuantizedResult)

            reference = reference.cpu()
            estimate = qres.x.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)

        flashy.distrib.barrier()

    def load_from_pretrained(self, name: str) -> dict:
        """Load the weights of a named pretrained compression model, for fine-tuning.

        Args:
            name (str): Name passed to `models.CompressionModel.get_pretrained`.
        Returns:
            dict: A checkpoint-like dict of the form `{'best_state': {'model': state_dict}}`.
        Raises:
            RuntimeError: For DAC models or any model type other than Encodec variants.
        """
        model = models.CompressionModel.get_pretrained(name)
        if isinstance(model, models.DAC):
            raise RuntimeError("Cannot fine tune a DAC model.")
        elif isinstance(model, models.HFEncodecCompressionModel):
            self.logger.warning('Trying to automatically convert a HuggingFace model '
                                'to AudioCraft, this might fail!')
            state = model.model.state_dict()
            new_state = {}
            # Rename HuggingFace Encodec parameter names to the AudioCraft layout.
            for k, v in state.items():
                if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
                    # We need to determine if this a convtr or a regular conv.
                    layer = int(k.split('.')[2])
                    if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
                        k = k.replace('.conv.', '.convtr.')
                k = k.replace('encoder.layers.', 'encoder.model.')
                k = k.replace('decoder.layers.', 'decoder.model.')
                k = k.replace('conv.', 'conv.conv.')
                k = k.replace('convtr.', 'convtr.convtr.')
                k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
                k = k.replace('.codebook.', '._codebook.')
                new_state[k] = v
            state = new_state
        elif isinstance(model, models.EncodecModel):
            state = model.state_dict()
        else:
            raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
        return {
            'best_state': {'model': state}
        }

    @staticmethod
    def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
                              device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a CompressionModel from a given checkpoint path or dora sig.

        This method is a convenient endpoint to load a CompressionModel to use in other solvers.

        Args:
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
                This also supports pre-trained models by using a path of the form //pretrained/NAME,
                which are loaded through `models.CompressionModel.get_pretrained`.
            device (torch.device or str): Device on which the model is loaded.
        Returns:
            models.CompressionModel: The loaded model. Models loaded from a checkpoint are put in eval mode.
        Raises:
            AssertionError: If the checkpoint cannot be resolved or loaded, has no best state,
                or is an exported checkpoint that should be loaded with the //pretrained/ prefix.
        """
        checkpoint_path = str(checkpoint_path)
        if checkpoint_path.startswith('//pretrained/'):
            # '//pretrained/NAME' -> 'NAME' (NAME may itself contain '/').
            name = checkpoint_path.split('/', 3)[-1]
            return models.CompressionModel.get_pretrained(name, device)
        logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
        _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
        assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
        state = checkpoint.load_checkpoint(_checkpoint_path)
        assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
        cfg = state['xp.cfg']
        cfg.device = device
        compression_model = models.builders.get_compression_model(cfg).to(device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"

        assert 'best_state' in state and state['best_state'] != {}
        assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
        compression_model.load_state_dict(state['best_state']['model'])
        compression_model.eval()
        logger.info("Compression model loaded!")
        return compression_model

    @staticmethod
    def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
                                      checkpoint_path: tp.Union[Path, str],
                                      device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.

        Args:
            cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
                Accepts the same forms as `model_from_checkpoint`, including //pretrained/NAME.
            device (torch.device or str): Device on which the model is loaded.
        Returns:
            models.CompressionModel: The loaded model, wrapped by
                `models.builders.get_wrapped_compression_model`.
        """
        compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
        compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
        return compression_model
def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict:
    """Audio reconstruction evaluation method that can be conveniently pickled.

    Kept at module level so it can be submitted to the spawn-based process pool
    used during the evaluate stage.

    Args:
        y_pred (torch.Tensor): Predicted (reconstructed) signal.
        y (torch.Tensor): Reference signal.
        cfg (omegaconf.DictConfig): Solver configuration. Reads `evaluate.metrics.visqol`,
            `metrics.visqol` and `sample_rate`, and is forwarded to `builders.get_loss`.
    Returns:
        dict: Always contains 'sisnr'. Also contains 'visqol' (inserted first)
            when `cfg.evaluate.metrics.visqol` is truthy.
    """
    results = {}
    visqol_enabled = cfg.evaluate.metrics.visqol
    if visqol_enabled:
        visqol_metric = builders.get_visqol(cfg.metrics.visqol)
        results['visqol'] = visqol_metric(y_pred, y, cfg.sample_rate)
    sisnr_metric = builders.get_loss('sisnr', cfg)
    results['sisnr'] = sisnr_metric(y_pred, y)
    return results