# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Main model for using MelodyFlow. This will combine all the required components
and provide easy access to the generation API.
"""

import typing as tp

import torch

from audiocraft.utils.autocast import TorchAutocast

from .genmodel import BaseGenModel
from .loaders import load_compression_model, load_dit_model_melodyflow
from ..modules.conditioners import ConditioningAttributes
from ..utils.utils import vae_sample


class MelodyFlow(BaseGenModel):
    """MelodyFlow main model with convenient generation API.

    Args:
        See BaseGenModel; keyword arguments cover the model name, the
        compression model and the flow-matching LM.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.set_generation_params()
        self.set_editing_params()
        if self.device.type == 'cpu' or self.device.type == 'mps':
            self.autocast = TorchAutocast(enabled=False)
        else:
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.bfloat16)

    @staticmethod
    def get_pretrained(name: str = 'facebook/melodyflow-t24-30secs', device=None):
        # TODO complete the list of pretrained models
        """Return a pretrained MelodyFlow model.

        Args:
            name (str): Name of the pretrained model checkpoint.
            device (str, optional): Device on which to load the model.
                Defaults to CUDA if available, then MPS, then CPU.
        """
        if device is None:
            if torch.cuda.device_count():
                device = 'cuda'
            elif torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'

        compression_model = load_compression_model(name, device=device)

        def _remove_weight_norm(module):
            # Drop the weight-norm parametrization from wrapped (transposed)
            # convolutions, baking the normalized weights into the module.
            if hasattr(module, "conv"):
                if hasattr(module.conv, "conv"):
                    torch.nn.utils.parametrize.remove_parametrizations(
                        module.conv.conv, "weight"
                    )
            if hasattr(module, "convtr"):
                if hasattr(module.convtr, "convtr"):
                    torch.nn.utils.parametrize.remove_parametrizations(
                        module.convtr.convtr, "weight"
                    )

        def _clear_weight_norm(module):
            _remove_weight_norm(module)
            for child in module.children():
                _clear_weight_norm(child)

        # Remove weight norm on CPU, then move the model back to the target device.
        compression_model.to('cpu')
        _clear_weight_norm(compression_model)
        compression_model.to(device)
        lm = load_dit_model_melodyflow(name, device=device)

        kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
        return MelodyFlow(**kwargs)
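
    # Example of loading the model (a minimal sketch; the checkpoint name is
    # the default documented above, and the device argument is optional):
    #
    #     model = MelodyFlow.get_pretrained('facebook/melodyflow-t24-30secs')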
""" self.editing_params = { 'solver': solver, 'steps': steps, 'target_flowstep': target_flowstep, 'regularize': regularize, 'regularize_iters': regularize_iters, 'keep_last_k_iters': keep_last_k_iters, 'lambda_kl': lambda_kl, } def encode_audio(self, waveform: torch.Tensor) -> torch.Tensor: """Generate Audio from tokens.""" assert waveform.dim() == 3 with torch.no_grad(): latent_sequence = self.compression_model.encode(waveform)[0].squeeze(1) return latent_sequence def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor: """Generate Audio from tokens.""" assert gen_tokens.dim() == 3 with torch.no_grad(): if self.lm.latent_mean.shape[1] != gen_tokens.shape[1]: # tokens directly emanate from the VAE encoder mean, scale = gen_tokens.chunk(2, dim=1) gen_tokens = vae_sample(mean, scale) else: # tokens emanate from the generator gen_tokens = gen_tokens * (self.lm.latent_std + 1e-5) + self.lm.latent_mean gen_audio = self.compression_model.decode(gen_tokens, None) return gen_audio def generate_unconditional(self, num_samples: int, progress: bool = False, return_tokens: bool = False) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples in an unconditional manner. Args: num_samples (int): Number of samples to be generated. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. """ descriptions: tp.List[tp.Optional[str]] = [None] * num_samples attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) assert prompt_tokens is None tokens = self._generate_tokens(attributes=attributes, prompt_tokens=prompt_tokens, progress=progress, **self.generation_params, ) if return_tokens: return self.generate_audio(tokens), tokens return self.generate_audio(tokens) def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples conditioned on text. Args: descriptions (list of str): A list of strings used as text conditioning. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. """ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) assert prompt_tokens is None tokens = self._generate_tokens(attributes=attributes, prompt_tokens=prompt_tokens, progress=progress, **self.generation_params, ) if return_tokens: return self.generate_audio(tokens), tokens return self.generate_audio(tokens) def edit(self, prompt_tokens: torch.Tensor, descriptions: tp.List[str], src_descriptions: tp.Optional[tp.List[str]] = None, progress: bool = False, return_tokens: bool = False, ) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples conditioned on text. Args: prompt_tokens (torch.Tensor, optional): Audio prompt used as initial latent sequence. descriptions (list of str): A list of strings used as editing conditioning. inversion (str): Inversion method (either ddim or fm_renoise) target_flowstep (float): Target flow step pivot in [0, 1[. steps (int): number of solver steps. src_descriptions (list of str): A list of strings used as conditioning during latent inversion. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. return_tokens (bool): Whether to return the generated tokens. 
""" empty_attributes, no_tokens = self._prepare_tokens_and_attributes( [""] if src_descriptions is None else src_descriptions, None) assert no_tokens is None edit_attributes, no_tokens = self._prepare_tokens_and_attributes(descriptions, None) assert no_tokens is None inversion_params = self.editing_params.copy() override_total_steps = inversion_params["steps"] * ( inversion_params["regularize_iters"] + 1) if inversion_params["regularize"] else inversion_params["steps"] * 2 current_step_offset: int = 0 def _progress_callback(elapsed_steps: int, total_steps: int): elapsed_steps += current_step_offset if self._progress_callback is not None: self._progress_callback(elapsed_steps, override_total_steps) else: print(f'{elapsed_steps: 6d} / {override_total_steps: 6d}', end='\r') intermediate_tokens = self._generate_tokens(attributes=empty_attributes, prompt_tokens=prompt_tokens, source_flowstep=1.0, progress=progress, callback=_progress_callback, **inversion_params, ) if intermediate_tokens.shape[0] < len(descriptions): intermediate_tokens = intermediate_tokens.repeat(len(descriptions)//intermediate_tokens.shape[0], 1, 1) current_step_offset += inversion_params["steps"] * ( inversion_params["regularize_iters"]) if inversion_params["regularize"] else inversion_params["steps"] inversion_params.pop("regularize") final_tokens = self._generate_tokens(attributes=edit_attributes, prompt_tokens=intermediate_tokens, source_flowstep=inversion_params.pop("target_flowstep"), target_flowstep=1.0, progress=progress, callback=_progress_callback, **inversion_params,) if return_tokens: return self.generate_audio(final_tokens), final_tokens return self.generate_audio(final_tokens) def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False, callback: tp.Optional[tp.Callable[[int, int], None]] = None, **kwargs) -> torch.Tensor: """Generate continuous audio tokens given audio prompt and/or conditions. Args: attributes (list of ConditioningAttributes): Conditions used for generation (here text). prompt_tokens (torch.Tensor, optional): Audio prompt used as initial latent sequence. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. Returns: torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. """ generate_params = kwargs.copy() total_gen_len = prompt_tokens.shape[-1] if prompt_tokens is not None else int( generate_params.pop('duration') * self.frame_rate) current_step_offset: int = 0 def _progress_callback(elapsed_steps: int, total_steps: int): elapsed_steps += current_step_offset if self._progress_callback is not None: self._progress_callback(elapsed_steps, total_steps) else: print(f'{elapsed_steps: 6d} / {total_steps: 6d}', end='\r') if progress and callback is None: callback = _progress_callback assert total_gen_len <= int(self.max_duration * self.frame_rate) with self.autocast: gen_tokens = self.lm.generate( prompt=prompt_tokens, conditions=attributes, callback=callback, max_gen_len=total_gen_len, **generate_params, ) return gen_tokens