# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Main model for using MelodyFlow. This will combine all the required components
and provide easy access to the generation API.
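
Typical usage (a minimal sketch; the checkpoint name is the default used by
get_pretrained below, and the text prompt is a placeholder)::

    model = MelodyFlow.get_pretrained('facebook/melodyflow-t24-30secs')
    model.set_generation_params(solver='midpoint', steps=64, duration=10.0)
    wav = model.generate(['lo-fi beat with a warm piano loop'])  # batch of waveforms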
"""

import typing as tp
from audiocraft.utils.autocast import TorchAutocast
import torch

from .genmodel import BaseGenModel
from ..modules.conditioners import ConditioningAttributes
from ..utils.utils import vae_sample
from .loaders import load_compression_model, load_dit_model_melodyflow


class MelodyFlow(BaseGenModel):
    """MelodyFlow main model with convenient generation API.
    Args:
       See MelodyFlow class.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.set_generation_params()
        self.set_editing_params()
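        # bfloat16 autocast is only enabled on CUDA-like devices;
        # CPU and MPS run in full precision.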
        if self.device.type == 'cpu' or self.device.type == 'mps':
            self.autocast = TorchAutocast(enabled=False)
        else:
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.bfloat16)

    @staticmethod
    def get_pretrained(name: str = 'facebook/melodyflow-t24-30secs', device=None):
        # TODO complete the list of pretrained models
        """
        """
        if device is None:
            if torch.cuda.device_count():
                device = 'cuda'
            elif torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'

        compression_model = load_compression_model(name, device=device)

        def _remove_weight_norm(module):
            if hasattr(module, "conv"):
                if hasattr(module.conv, "conv"):
                    torch.nn.utils.parametrize.remove_parametrizations(
                        module.conv.conv, "weight"
                    )
            if hasattr(module, "convtr"):
                if hasattr(module.convtr, "convtr"):
                    torch.nn.utils.parametrize.remove_parametrizations(
                        module.convtr.convtr, "weight"
                    )

        def _clear_weight_norm(module):
            _remove_weight_norm(module)
            for child in module.children():
                _clear_weight_norm(child)

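        # Fold weight normalization into the plain convolution weights for
        # inference: the parametrizations are stripped on CPU, then the
        # model is moved back to the target device.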
        compression_model.to('cpu')
        _clear_weight_norm(compression_model)
        compression_model.to(device)

        lm = load_dit_model_melodyflow(name, device=device)

        kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
        return MelodyFlow(**kwargs)

    def set_generation_params(
        self,
        solver: str = "midpoint",
        steps: int = 64,
        duration: float = 10.0,
    ) -> None:
        """Set generation parameters for MelodyFlow.

        Args:
            solver (str, optional): ODE solver, either euler or midpoint.
            steps (int, optional): Number of inference steps.
            duration (float, optional): Duration of the generated waveform, in seconds.
        """
        self.generation_params = {
            'solver': solver,
            'steps': steps,
            'duration': duration,
        }

    def set_editing_params(
        self,
        solver: str = "euler",
        steps: int = 25,
        target_flowstep: float = 0.0,
        regularize: bool = True,
        regularize_iters: int = 4,
        keep_last_k_iters: int = 2,
        lambda_kl: float = 0.2,
    ) -> None:
        """Set regularized inversion parameters for MelodyFlow.

        Args:
            solver (str, optional): ODE solver, either euler or midpoint.
            steps (int, optional): Number of inference steps.
            target_flowstep (float): Target flow step.
            regularize (bool): Regularize each solver step.
            regularize_iters (int, optional): Number of regularization iterations.
            keep_last_k_iters (int, optional): Number of meaningful regularization iterations for moving average computation.
            lambda_kl (float, optional): KL regularization loss weight.
        """
        self.editing_params = {
            'solver': solver,
            'steps': steps,
            'target_flowstep': target_flowstep,
            'regularize': regularize,
            'regularize_iters': regularize_iters,
            'keep_last_k_iters': keep_last_k_iters,
            'lambda_kl': lambda_kl,
        }

    def encode_audio(self, waveform: torch.Tensor) -> torch.Tensor:
        """Generate Audio from tokens."""
        assert waveform.dim() == 3
        with torch.no_grad():
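            # encode() returns a tuple whose first element holds the latents,
            # with a singleton codebook dimension that is squeezed away.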
            latent_sequence = self.compression_model.encode(waveform)[0].squeeze(1)
        return latent_sequence

    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
        """Generate Audio from tokens."""
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            if self.lm.latent_mean.shape[1] != gen_tokens.shape[1]:
                # tokens come straight from the VAE encoder: split the
                # concatenated (mean, scale) channels and sample a latent
                mean, scale = gen_tokens.chunk(2, dim=1)
                gen_tokens = vae_sample(mean, scale)
            else:
                # tokens come from the generator: de-normalize them back
                # to the VAE latent space
                gen_tokens = gen_tokens * (self.lm.latent_std + 1e-5) + self.lm.latent_mean
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio

    def generate_unconditional(self, num_samples: int, progress: bool = False,
                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                                        tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples in an unconditional manner.

        Args:
            num_samples (int): Number of samples to be generated.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool): Whether to return the generated tokens. Defaults to False.
        """
        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes=attributes,
                                       prompt_tokens=prompt_tokens,
                                       progress=progress,
                                       **self.generation_params,
                                       )
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool): Whether to return the generated tokens. Defaults to False.
        """
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes=attributes,
                                       prompt_tokens=prompt_tokens,
                                       progress=progress,
                                       **self.generation_params,
                                       )
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def edit(self,
             prompt_tokens: torch.Tensor,
             descriptions: tp.List[str],
             src_descriptions: tp.Optional[tp.List[str]] = None,
             progress: bool = False,
             return_tokens: bool = False,
             ) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text.

        Args:
            prompt_tokens (torch.Tensor, optional): Audio prompt used as initial latent sequence.
            descriptions (list of str): A list of strings used as editing conditioning.
            inversion (str): Inversion method (either ddim or fm_renoise)
            target_flowstep (float): Target flow step pivot in [0, 1[.
            steps (int): number of solver steps.
            src_descriptions (list of str): A list of strings used as conditioning during latent inversion.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool): Whether to return the generated tokens.
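
        Example (a minimal sketch; the waveform and the text prompt are placeholders)::

            latents = model.encode_audio(waveform)  # waveform of shape [B, C, T]
            model.set_editing_params(steps=25, target_flowstep=0.0)
            edited = model.edit(latents, ['solo piano version'])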
        """
        empty_attributes, no_tokens = self._prepare_tokens_and_attributes(
            [""] if src_descriptions is None else src_descriptions, None)
        assert no_tokens is None
        edit_attributes, no_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert no_tokens is None

        inversion_params = self.editing_params.copy()
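        # Total step budget reported to the progress callback: the inversion
        # pass runs `steps` solver steps per regularization iteration (or a
        # single pass without regularization), followed by `steps` steps for
        # the final regeneration pass.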
        override_total_steps = (
            inversion_params["steps"] * (inversion_params["regularize_iters"] + 1)
            if inversion_params["regularize"]
            else inversion_params["steps"] * 2
        )
        current_step_offset: int = 0

        def _progress_callback(elapsed_steps: int, total_steps: int):
            elapsed_steps += current_step_offset
            if self._progress_callback is not None:
                self._progress_callback(elapsed_steps, override_total_steps)
            else:
                print(f'{elapsed_steps: 6d} / {override_total_steps: 6d}', end='\r')

        intermediate_tokens = self._generate_tokens(attributes=empty_attributes,
                                                    prompt_tokens=prompt_tokens,
                                                    source_flowstep=1.0,
                                                    progress=progress,
                                                    callback=_progress_callback,
                                                    **inversion_params,
                                                    )
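        # If a single prompt was inverted for several target descriptions,
        # tile the inverted latents across the batch.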
        if intermediate_tokens.shape[0] < len(descriptions):
            intermediate_tokens = intermediate_tokens.repeat(len(descriptions)//intermediate_tokens.shape[0], 1, 1)
        current_step_offset += (
            inversion_params["steps"] * inversion_params["regularize_iters"]
            if inversion_params["regularize"]
            else inversion_params["steps"]
        )
        inversion_params.pop("regularize")
        final_tokens = self._generate_tokens(attributes=edit_attributes,
                                             prompt_tokens=intermediate_tokens,
                                             source_flowstep=inversion_params.pop("target_flowstep"),
                                             target_flowstep=1.0,
                                             progress=progress,
                                             callback=_progress_callback,
                                             **inversion_params,)
        if return_tokens:
            return self.generate_audio(final_tokens), final_tokens
        return self.generate_audio(final_tokens)

    def _generate_tokens(self,
                         attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor],
                         progress: bool = False,
                         callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                         **kwargs) -> torch.Tensor:
        """Generate continuous audio tokens given audio prompt and/or conditions.

        Args:
            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
            prompt_tokens (torch.Tensor, optional): Audio prompt used as initial latent sequence.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
            torch.Tensor: Generated latent sequence, of shape [B, C, T], where T is defined
                by the duration generation parameter (or by the prompt length when a prompt is given).
        """
        generate_params = kwargs.copy()
        if prompt_tokens is not None:
            total_gen_len = prompt_tokens.shape[-1]
        else:
            total_gen_len = int(generate_params.pop('duration') * self.frame_rate)
        current_step_offset: int = 0

        def _progress_callback(elapsed_steps: int, total_steps: int):
            elapsed_steps += current_step_offset
            if self._progress_callback is not None:
                self._progress_callback(elapsed_steps, total_steps)
            else:
                print(f'{elapsed_steps: 6d} / {total_steps: 6d}', end='\r')

        if progress and callback is None:
            callback = _progress_callback

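        # Never request more frames than the model supports.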
        assert total_gen_len <= int(self.max_duration * self.frame_rate)

        with self.autocast:
            gen_tokens = self.lm.generate(
                prompt=prompt_tokens,
                conditions=attributes,
                callback=callback,
                max_gen_len=total_gen_len,
                **generate_params,
            )

        return gen_tokens