Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import typing as tp | |
import flashy | |
import julius | |
import omegaconf | |
import torch | |
import torch.nn.functional as F | |
from . import builders | |
from . import base | |
from .. import models | |
from ..modules.diffusion_schedule import NoiseSchedule | |
from ..metrics import RelativeVolumeMel | |
from ..models.builders import get_processor | |
from ..utils.samples.manager import SampleManager | |
from ..solvers.compression import CompressionSolver | |
class PerStageMetrics:
    """Handle prompting the metrics per stage.

    The diffusion steps are split into ``num_stages`` equal ranges and each loss is
    reported separately per range, e.g. the avg loss when t in [250, 500).

    Args:
        num_steps (int): Total number of diffusion steps.
        num_stages (int): Number of ranges the steps are split into.
    """
    def __init__(self, num_steps: int, num_stages: int = 4):
        self.num_steps = num_steps
        self.num_stages = num_stages

    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]) -> dict:
        """Split ``losses`` into per-stage entries named ``f"{name}_{stage}"``.

        Args:
            losses (dict): Mapping from loss name to value. When ``step`` is a tensor,
                each value is expected to hold one loss per element of ``step``.
            step (int or torch.Tensor): Diffusion step of the whole batch (int) or of
                each batch item (tensor).
        Returns:
            dict: For an int step, every loss keyed by its single stage. For a tensor
            step, for every stage that contains at least one item, the mean loss over
            the items of that stage; empty stages are omitted.
        Raises:
            TypeError: If ``step`` is neither an int nor a tensor.
        """
        if isinstance(step, torch.Tensor):
            stage_tensor = ((step / self.num_steps) * self.num_stages).long()
            out: tp.Dict[str, tp.Any] = {}
            for stage_idx in range(self.num_stages):
                mask = stage_tensor == stage_idx
                count = mask.sum()
                if count == 0:  # no element falls into this stage, nothing to report
                    continue
                for name, loss in losses.items():
                    out[f"{name}_{stage_idx}"] = (mask * loss).sum() / count
            return out
        if isinstance(step, int):
            stage = int((step / self.num_steps) * self.num_stages)
            return {f"{name}_{stage}": loss for name, loss in losses.items()}
        # Previously this fell through and returned None, which only failed later.
        raise TypeError(f"step must be an int or a torch.Tensor, got {type(step).__name__}.")
class DataProcess:
    """Apply filtering or resampling to a batch of audio.

    Args:
        initial_sr (int): Sample rate of the dataset.
        target_sr (int): Sample rate after resampling.
        use_resampling (bool): Whether to resample from ``initial_sr`` to ``target_sr``.
        use_filter (bool): When True, filter the data to keep only one frequency band.
        n_bands (int): Number of bands used when ``cutoffs`` is None.
        idx_band (int): Index of the frequency band to keep: 0 is the lowest band,
            ``n_bands - 1`` the highest.
        device (torch.device or str): Device the band-splitting filter is moved to.
        cutoffs (list or None): Cutoff frequencies of the band filtering.
            If None then we use mel scale bands.
        boost (bool): Make the data scale match our music dataset by normalizing each
            item to a standard deviation of 0.22.
    """
    def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
                 use_filter: bool = False, n_bands: int = 4,
                 idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
        """Build the processor; see the class docstring for the arguments."""
        assert idx_band < n_bands
        self.idx_band = idx_band
        if use_filter:
            if cutoffs is not None:
                self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
            else:
                self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
        self.use_filter = use_filter
        self.use_resampling = use_resampling
        self.target_sr = target_sr
        self.initial_sr = initial_sr
        self.boost = boost

    def process_data(self, x, metric=False):
        """Apply boost scaling, band filtering and resampling to ``x``.

        Args:
            x (torch.Tensor or None): Audio batch; the boost path computes the std over
                dims 1 and 2, so it assumes a 3D (batch, channels, time) tensor.
            metric (bool): When True, skip the band filtering.
        Returns:
            The processed tensor (a new tensor; ``x`` itself is left untouched),
            or None if ``x`` is None.
        """
        if x is None:
            return None
        if self.boost:
            # Out-of-place ops: the caller may reuse ``x`` (e.g. as an evaluation
            # reference), so it must not be modified.
            x = x / torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
            x = x * 0.22
        if self.use_filter and not metric:
            x = self.filter(x)[self.idx_band]
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
        return x

    def inverse_process(self, x):
        """Upsampling only: bring ``x`` from ``target_sr`` back to ``initial_sr``.

        Band filtering and boost scaling are not inverted.
        """
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.initial_sr)
        return x
class DiffusionSolver(base.StandardSolver):
    """Solver for the diffusion task.

    The diffusion task allows for MultiBand diffusion model training: a diffusion
    model learns to regenerate audio, conditioned on the latent embedding produced by
    the compression model loaded from ``cfg.compression_model_checkpoint``.

    Args:
        cfg (DictConfig): Configuration.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.cfg = cfg
        self.device = cfg.device
        self.sample_rate: int = self.cfg.sample_rate
        # Compression model used to compute the conditioning (see get_condition).
        self.codec_model = CompressionSolver.model_from_checkpoint(
            cfg.compression_model_checkpoint, device=self.device)
        self.codec_model.set_num_codebooks(cfg.n_q)
        assert self.codec_model.sample_rate == self.sample_rate, (
            f"Codec model sample rate is {self.codec_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
        )
        self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
        self.register_stateful('sample_processor')
        self.sample_processor.to(self.device)
        self.schedule = NoiseSchedule(
            **cfg.schedule, device=self.device, sample_processor=self.sample_processor)
        self.eval_metric: tp.Optional[torch.nn.Module] = None
        self.rvm = RelativeVolumeMel()
        self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
                                          use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
                                          use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
                                          idx_band=cfg.filter.idx_band, device=self.device)

    def best_metric_name(self) -> tp.Optional[str]:
        """Metric used to select the best state: 'rvm' while evaluating, 'loss' otherwise."""
        if self._current_stage == "evaluate":
            return 'rvm'
        else:
            return 'loss'

    def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode ``wav`` with the codec model and decode the codes to a latent embedding.

        Raises:
            AssertionError: If the codec model returns a scale (scaled models unsupported).
        """
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.codec_model.decode_latent(codes)
        return emb

    def build_model(self):
        """Build model and optimizer as well as optional Exponential Moving Average of the model.
        """
        # Model and optimizer
        self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Build audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        # TODO
        raise NotImplementedError()

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch.

        The loss of the model output is divided by the loss of the noisy input itself
        raised to ``cfg.loss.norm_power``. Gradients are applied only when training.
        Per-stage metrics rely on ``self.per_stage``, which is created in run_epoch.
        """
        x = batch.to(self.device)
        loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss

        condition = self.get_condition(x)  # [bs, 128, T/hop, n_emb]
        sample = self.data_processor.process_data(x)

        input_, target, step = self.schedule.get_training_item(sample,
                                                               tensor_step=self.cfg.schedule.variable_step_batch)
        out = self.model(input_, step, condition=condition).sample

        base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
        # Loss obtained by returning the noisy input unchanged, used as a normalizer.
        reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
        loss = base_loss / reference_loss ** self.cfg.loss.norm_power
        normed_loss = base_loss / reference_loss

        if self.is_training:
            loss.mean().backward()
            flashy.distrib.sync_model(self.model)
            self.optimizer.step()
            self.optimizer.zero_grad()

        metrics = {
            'loss': loss.mean(), 'normed_loss': normed_loss.mean(),
        }
        metrics.update(self.per_stage({'loss': loss, 'normed_loss': normed_loss}, step))
        metrics.update({
            'std_in': input_.std(), 'std_out': out.std()})
        return metrics

    def run_epoch(self):
        """Reset the epoch RNG and per-stage metric helper, then run the epoch."""
        # reset random seed at the beginning of the epoch
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
        # run epoch
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage.

        Runs audio reconstruction evaluation and returns the mean, over batches, of the
        relative volume mel metrics, averaged across distributed workers.
        """
        self.model.eval()
        evaluate_stage_name = f'{self.current_stage}'
        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)

        metrics: dict = {}
        n = 0  # number of batches already accumulated into `metrics`
        for batch in lp:
            x = batch.to(self.device)
            with torch.no_grad():
                y_pred = self.regenerate(x)

            y_pred = y_pred.cpu()
            y = batch.cpu()  # should already be on CPU but just in case
            rvm = self.rvm(y_pred, y)
            lp.update(**rvm)
            if n == 0:
                metrics = dict(rvm)
            else:
                # Incremental mean: each batch has the same weight.
                for key, value in rvm.items():
                    metrics[key] = (metrics[key] * n + value) / (n + 1)
            n += 1
        metrics = flashy.distrib.average_metrics(metrics)
        return metrics

    def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
        """Regenerate the given waveform.

        Args:
            wav (torch.Tensor): Input audio at the solver sample rate.
            step_list (list, optional): Subset of diffusion steps forwarded to
                ``schedule.generate_subsampled``.
        """
        condition = self.get_condition(wav)
        initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav))  # sampling rate changes.
        result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
                                                   step_list=step_list)
        result = self.data_processor.inverse_process(result)
        return result

    def generate(self):
        """Generate stage: regenerate the 'generate' dataset and store the samples."""
        sample_manager = SampleManager(self.xp)
        self.model.eval()
        generate_stage_name = f'{self.current_stage}'

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            reference, _ = batch
            reference = reference.to(self.device)
            # Inference only: no autograd graph is needed (matches evaluate).
            with torch.no_grad():
                estimate = self.regenerate(reference)
            reference = reference.cpu()
            estimate = estimate.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
        flashy.distrib.barrier()