unpairedelectron07 committed
Commit 9dd6fe2
1 Parent(s): 16f7f52

Upload 11 files

audiocraft/modules/activations.py ADDED
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Union, Callable
+
+
+class CustomGLU(nn.Module):
+    """Custom Gated Linear Unit activation.
+    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
+    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
+    function (e.g. sigmoid, swish, etc.).
+
+    Args:
+        activation (nn.Module): The custom activation to apply in the Gated Linear Unit
+        dim (int): the dimension on which to split the input. Default: -1
+
+    Shape:
+        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
+          dimensions
+        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+    Examples::
+        >>> m = CustomGLU(nn.Sigmoid())
+        >>> input = torch.randn(4, 2)
+        >>> output = m(input)
+    """
+    def __init__(self, activation: nn.Module, dim: int = -1):
+        super(CustomGLU, self).__init__()
+        self.dim = dim
+        self.activation = activation
+
+    def forward(self, x: Tensor):
+        assert x.shape[self.dim] % 2 == 0  # M = N / 2
+        a, b = torch.chunk(x, 2, dim=self.dim)
+        return a * self.activation(b)
+
+
+class SwiGLU(CustomGLU):
+    """SiLU Gated Linear Unit activation.
+    Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(SwiGLU, self).__init__(nn.SiLU(), dim)
+
+
+class GeGLU(CustomGLU):
+    """GeLU Gated Linear Unit activation.
+    Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(GeGLU, self).__init__(nn.GELU(), dim)
+
+
+class ReGLU(CustomGLU):
+    """ReLU Gated Linear Unit activation.
+    Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(ReGLU, self).__init__(nn.ReLU(), dim)
+
+
+def get_activation_fn(
+    activation: Union[str, Callable[[Tensor], Tensor]]
+) -> Union[str, Callable[[Tensor], Tensor]]:
+    """Helper function to map an activation string to the activation class.
+    If the supplied activation is not a string that is recognized, the activation is passed back.
+
+    Args:
+        activation (str, or Callable[[Tensor], Tensor]): Activation to check
+    """
+    if isinstance(activation, str):
+        if activation == "reglu":
+            return ReGLU()
+        elif activation == "geglu":
+            return GeGLU()
+        elif activation == "swiglu":
+            return SwiGLU()
+    return activation
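
The module above is self-contained, so a short sketch can show how the gated units behave; the tensor shapes below are arbitrary illustration values, not anything prescribed by this commit. Note that get_activation_fn only maps the three GLU keywords and passes any other value through untouched:

import torch
from audiocraft.modules.activations import get_activation_fn

x = torch.randn(2, 10, 512)            # (batch, time, channels); any shape with an even last dim works
swiglu = get_activation_fn("swiglu")   # returns SwiGLU(), which splits dim=-1 into halves a and b
y = swiglu(x)                          # computes a * SiLU(b)
print(y.shape)                         # torch.Size([2, 10, 256])

other = get_activation_fn("gelu")      # not one of the GLU keywords, so the string is returned as-is
print(other)                           # gelu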
audiocraft/modules/chroma.py ADDED
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import typing as tp
+
+from einops import rearrange
+from librosa import filters
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+
+
+class ChromaExtractor(nn.Module):
+    """Chroma extraction and quantization.
+
+    Args:
+        sample_rate (int): Sample rate for the chroma extraction.
+        n_chroma (int): Number of chroma bins for the chroma extraction.
+        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
+        nfft (int, optional): Number of FFT.
+        winlen (int, optional): Window length.
+        winhop (int, optional): Window hop size.
+        argmax (bool, optional): Whether to use argmax. Defaults to False.
+        norm (float, optional): Norm for chroma normalization. Defaults to inf.
+    """
+    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
+                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
+                 norm: float = torch.inf):
+        super().__init__()
+        self.winlen = winlen or 2 ** radix2_exp
+        self.nfft = nfft or self.winlen
+        self.winhop = winhop or (self.winlen // 4)
+        self.sample_rate = sample_rate
+        self.n_chroma = n_chroma
+        self.norm = norm
+        self.argmax = argmax
+        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
+                                                                       n_chroma=self.n_chroma)), persistent=False)
+        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
+                                                      hop_length=self.winhop, power=2, center=True,
+                                                      pad=0, normalized=True)
+
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        T = wav.shape[-1]
+        # in case we are getting a wav that was dropped out (nullified)
+        # from the conditioner, make sure wav length is no less than nfft
+        if T < self.nfft:
+            pad = self.nfft - T
+            r = 0 if pad % 2 == 0 else 1
+            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
+
+        spec = self.spec(wav).squeeze(1)
+        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
+        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
+
+        if self.argmax:
+            idx = norm_chroma.argmax(-1, keepdim=True)
+            norm_chroma[:] = 0
+            norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+        return norm_chroma
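
As a rough usage sketch of the extractor above (the 32 kHz sample rate and the one-second batch are arbitrary choices for illustration): a waveform of shape [B, 1, T] is mapped to a per-frame chroma sequence of shape [B, frames, n_chroma], max-normalized (p=inf) along the chroma axis, or one-hot per frame when argmax=True.

import torch
from audiocraft.modules.chroma import ChromaExtractor

extractor = ChromaExtractor(sample_rate=32000, n_chroma=12, radix2_exp=12)
wav = torch.randn(2, 1, 32000)   # [B, C=1, T]: one second of audio per item
chroma = extractor(wav)          # [B, frames, n_chroma]; frame count follows from winhop = winlen // 4
print(chroma.shape)              # torch.Size([2, 32, 12]) for these settings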
audiocraft/modules/codebooks_patterns.py ADDED
@@ -0,0 +1,548 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import namedtuple
+from dataclasses import dataclass
+from functools import lru_cache
+import logging
+import typing as tp
+
+from abc import ABC, abstractmethod
+import torch
+
+LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])  # (timestep, codebook index)
+PatternLayout = tp.List[tp.List[LayoutCoord]]  # Sequence of coordinates
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Pattern:
+    """Base implementation of a pattern over a sequence with multiple codebooks.
+
+    The codebook pattern consists of a layout, defining for each sequence step
+    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
+    The first item of the pattern is always an empty list in order to properly insert a special token
+    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
+    and ``timesteps`` the number of timesteps corresponding to the original sequence.
+
+    The pattern provides convenient methods to build and revert interleaved sequences from it:
+    ``build_pattern_sequence`` maps a given dense input tensor of multi-codebook sequence from [B, K, T]
+    to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
+    K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
+    for the output sequence. The unfilled positions are replaced with a special token and the built sequence
+    is returned along with a mask indicating valid tokens.
+    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
+    of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
+    to fill and specify invalid positions if needed.
+    See the dedicated methods for more details.
+    """
+    # Pattern layout, for each sequence step, we have a list of coordinates
+    # corresponding to the original codebook timestep and position.
+    # The first list is always an empty list in order to properly insert
+    # a special token to start with.
+    layout: PatternLayout
+    timesteps: int
+    n_q: int
+
+    def __post_init__(self):
+        assert len(self.layout) > 0
+        self._validate_layout()
+        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
+        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
+        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
+
+    def _validate_layout(self):
+        """Runs checks on the layout to ensure a valid pattern is defined.
+        A pattern is considered invalid if:
+            - Multiple timesteps for the same codebook are defined in the same sequence step
+            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
+              (this would mean that we have future timesteps before past timesteps).
+        """
+        q_timesteps = {q: 0 for q in range(self.n_q)}
+        for s, seq_coords in enumerate(self.layout):
+            if len(seq_coords) > 0:
+                qs = set()
+                for coord in seq_coords:
+                    qs.add(coord.q)
+                    last_q_timestep = q_timesteps[coord.q]
+                    assert coord.t >= last_q_timestep, \
+                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
+                    q_timesteps[coord.q] = coord.t
+                # each sequence step contains at max 1 coordinate per codebook
+                assert len(qs) == len(seq_coords), \
+                    f"Multiple entries for a same codebook are found at step {s}"
+
+    @property
+    def num_sequence_steps(self):
+        return len(self.layout) - 1
+
+    @property
+    def max_delay(self):
+        max_t_in_seq_coords = 0
+        for seq_coords in self.layout[1:]:
+            for coords in seq_coords:
+                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
+        return max_t_in_seq_coords - self.timesteps
+
+    @property
+    def valid_layout(self):
+        valid_step = len(self.layout) - self.max_delay
+        return self.layout[:valid_step]
+
+    def starts_with_special_token(self):
+        return self.layout[0] == []
+
+    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
+        """Get codebook coordinates in the layout that corresponds to the specified timestep t
+        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
+        and the actual codebook coordinates.
+        """
+        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
+        if q is not None:
+            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
+        coords = []
+        for s, seq_codes in enumerate(self.layout):
+            for code in seq_codes:
+                if code.t == t and (q is None or code.q == q):
+                    coords.append((s, code))
+        return coords
+
+    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
+        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
+
+    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
+        steps_with_timesteps = self.get_steps_with_timestep(t, q)
+        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
+
+    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
+                                                device: tp.Union[torch.device, str] = 'cpu'):
+        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
+
+        Args:
+            timesteps (int): Maximum number of timesteps steps to consider.
+            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
+            device (torch.device or str): Device for created tensors.
+        Returns:
+            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
+        """
+        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
+        # use the proper layout based on whether we limit ourselves to valid steps only or not,
+        # note that using the valid_layout will result in a truncated sequence up to the valid steps
+        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
+        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
+        # fill indexes with last sequence step value that will correspond to our special token
+        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
+        # which will correspond to the index: n_q * timesteps
+        indexes[:] = n_q * timesteps
+        # iterate over the pattern and fill scattered indexes and mask
+        for s, sequence_coords in enumerate(ref_layout):
+            for coords in sequence_coords:
+                if coords.t < timesteps:
+                    indexes[coords.q, s] = coords.t + coords.q * timesteps
+                    mask[coords.q, s] = 1
+        indexes = torch.from_numpy(indexes).to(device)
+        mask = torch.from_numpy(mask).to(device)
+        return indexes, mask
+
+    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+        """Build sequence corresponding to the pattern from the input tensor z.
+        The sequence is built using up to sequence_steps if specified, and non-pattern
+        coordinates are filled with the special token.
+
+        Args:
+            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
+            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
+            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+                Steps that are beyond valid steps will be replaced by the special_token in that case.
+        Returns:
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
+                corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
+            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
+        """
+        B, K, T = z.shape
+        indexes, mask = self._build_pattern_sequence_scatter_indexes(
+            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
+        )
+        z = z.view(B, -1)
+        # we append the special token as the last index of our flattened z tensor
+        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
+        values = z[:, indexes.view(-1)]
+        values = values.view(B, K, indexes.shape[-1])
+        return values, indexes, mask
+
+    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
+                                                 keep_only_valid_steps: bool = False,
+                                                 is_model_output: bool = False,
+                                                 device: tp.Union[torch.device, str] = 'cpu'):
+        """Builds scatter indexes required to retrieve the original multi-codebook sequence
+        from interleaving pattern.
+
+        Args:
+            sequence_steps (int): Sequence steps.
+            n_q (int): Number of codebooks.
+            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+                Steps that are beyond valid steps will be replaced by the special_token in that case.
+            is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
+            device (torch.device or str): Device for created tensors.
+        Returns:
+            indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+        """
+        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
+        timesteps = self.timesteps
+        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+        assert sequence_steps <= len(ref_layout), \
+            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
+
+        # ensure we take the appropriate indexes to keep the model output from the first special token as well
+        if is_model_output and self.starts_with_special_token():
+            ref_layout = ref_layout[1:]
+
+        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
+        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
+        # fill indexes with last sequence step value that will correspond to our special token
+        indexes[:] = n_q * sequence_steps
+        for s, sequence_codes in enumerate(ref_layout):
+            if s < sequence_steps:
+                for code in sequence_codes:
+                    if code.t < timesteps:
+                        indexes[code.q, code.t] = s + code.q * sequence_steps
+                        mask[code.q, code.t] = 1
+        indexes = torch.from_numpy(indexes).to(device)
+        mask = torch.from_numpy(mask).to(device)
+        return indexes, mask
+
+    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
+        The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
+        are filled with the special token.
+
+        Args:
+            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
+        Returns:
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
+                corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+        """
+        B, K, S = s.shape
+        indexes, mask = self._build_reverted_sequence_scatter_indexes(
+            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
+        )
+        s = s.view(B, -1)
+        # we append the special token as the last index of our flattened z tensor
+        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+        values = s[:, indexes.view(-1)]
+        values = values.view(B, K, indexes.shape[-1])
+        return values, indexes, mask
+
+    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
+        """Revert model logits obtained on a sequence built from the pattern
+        back to a tensor matching the original sequence.
+
+        This method is similar to ``revert_pattern_sequence`` with the following specificities:
+        1. It is designed to work with the extra cardinality dimension
+        2. We return the logits for the first sequence item that matches the special_token and
+           which matching target in the original sequence is the first item of the sequence,
+           while we skip the last logits as there is no matching target
+        """
+        B, card, K, S = logits.shape
+        indexes, mask = self._build_reverted_sequence_scatter_indexes(
+            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+        )
+        logits = logits.reshape(B, card, -1)
+        # we append the special token as the last index of our flattened z tensor
+        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
+        values = logits[:, :, indexes.view(-1)]
+        values = values.view(B, card, K, indexes.shape[-1])
+        return values, indexes, mask
+
+
+class CodebooksPatternProvider(ABC):
+    """Abstraction around providing pattern for interleaving codebooks.
+
+    The CodebooksPatternProvider abstraction allows to implement various strategies to
+    define interleaving pattern of sequences composed of multiple codebooks. For a given
+    number of codebooks `n_q`, the pattern provider can generate a specified pattern
+    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
+    can be used to construct a new sequence from the original codes respecting the specified
+    pattern. The pattern is defined as a list of list of code coordinates, code coordinate
+    being a tuple with the original timestep and codebook to build the new sequence.
+    Note that all patterns must start with an empty list that is then used to insert a first
+    sequence step of special tokens in the newly generated sequence.
+
+    Args:
+        n_q (int): number of codebooks.
+        cached (bool): if True, patterns for a given length are cached. In general
+            that should be true for efficiency reason to avoid synchronization points.
+    """
+    def __init__(self, n_q: int, cached: bool = True):
+        assert n_q > 0
+        self.n_q = n_q
+        self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore
+
+    @abstractmethod
+    def get_pattern(self, timesteps: int) -> Pattern:
+        """Builds pattern with specific interleaving between codebooks.
+
+        Args:
+            timesteps (int): Total number of timesteps.
+        """
+        raise NotImplementedError()
+
+
+class DelayedPatternProvider(CodebooksPatternProvider):
+    """Provider for delayed pattern across delayed codebooks.
+    Codebooks are delayed in the sequence and sequence steps will contain codebooks
+    from different timesteps.
+
+    Example:
+        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        The resulting sequence obtained from the returned pattern is:
+        [[S, 1, 2, 3, 4],
+         [S, S, 1, 2, 3],
+         [S, S, S, 1, 2]]
+        (with S being a special token)
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (list of int, optional): Delay for each of the codebooks.
+            If delays not defined, each codebook is delayed by 1 compared to the previous one.
+        flatten_first (int): Flatten the first N timesteps.
+        empty_initial (int): Prepend with N empty list of coordinates.
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
+                 flatten_first: int = 0, empty_initial: int = 0):
+        super().__init__(n_q)
+        if delays is None:
+            delays = list(range(n_q))
+        self.delays = delays
+        self.flatten_first = flatten_first
+        self.empty_initial = empty_initial
+        assert len(self.delays) == self.n_q
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        omit_special_token = self.empty_initial < 0
+        out: PatternLayout = [] if omit_special_token else [[]]
+        max_delay = max(self.delays)
+        if self.empty_initial:
+            out += [[] for _ in range(self.empty_initial)]
+        if self.flatten_first:
+            for t in range(min(timesteps, self.flatten_first)):
+                for q in range(self.n_q):
+                    out.append([LayoutCoord(t, q)])
+        for t in range(self.flatten_first, timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= self.flatten_first:
+                    v.append(LayoutCoord(t_for_q, q))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class ParallelPatternProvider(DelayedPatternProvider):
+    """Provider for parallel pattern across codebooks.
+    This pattern provider is a special case of the delayed pattern with actually no delay,
+    hence delays=repeat(0, n_q).
+
+    Args:
+        n_q (int): Number of codebooks.
+        empty_initial (int): Prepend with N empty list of coordinates.
+    """
+    def __init__(self, n_q: int, empty_initial: int = 0):
+        super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
+
+
+class UnrolledPatternProvider(CodebooksPatternProvider):
+    """Provider for unrolling codebooks pattern.
+    This pattern provider enables to represent the codebook flattened completely or only to some extent
+    while also specifying a given delay between the flattened codebooks representation, allowing to
+    unroll the codebooks in the sequence.
+
+    Example:
+        1. Flattening of the codebooks.
+        By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
+        taking n_q = 3 and timesteps = 4:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
+         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+        2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
+        for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
+        taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+        3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
+        allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
+        same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
+        and delays = [0, 3, 3]:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, S, S, 1, S, 2, S, 3, S, 4],
+         [S, S, S, 1, S, 2, S, 3, S, 4],
+         [1, 2, 3, S, 4, S, 5, S, 6, S]]
+
+    Args:
+        n_q (int): Number of codebooks.
+        flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
+            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
+            have n_q extra steps for each timestep.
+        delays (list of int, optional): Delay for each of the codebooks. If not defined,
+            no delay is added and therefore will default to [0] * ``n_q``.
+            Note that two codebooks that will be flattened to the same inner step
+            should have the same delay, otherwise the pattern is considered as invalid.
+    """
+    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
+
+    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
+                 delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if flattening is None:
+            flattening = list(range(n_q))
+        if delays is None:
+            delays = [0] * n_q
+        assert len(flattening) == n_q
+        assert len(delays) == n_q
+        assert sorted(flattening) == flattening
+        assert sorted(delays) == delays
+        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
+        self.max_delay = max(delays)
+
+    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
+        """Build a flattened codebooks representation as a dictionary of inner step
+        and the actual codebook indices corresponding to the flattened codebook. For convenience, we
+        also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
+        """
+        flattened_codebooks: dict = {}
+        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
+            if inner_step not in flattened_codebooks:
+                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
+            else:
+                flat_codebook = flattened_codebooks[inner_step]
+                assert flat_codebook.delay == delay, (
+                    "Delay and flattening between codebooks is inconsistent: ",
+                    "two codebooks flattened to the same position should have the same delay."
+                )
+                flat_codebook.codebooks.append(q)
+            flattened_codebooks[inner_step] = flat_codebook
+        return flattened_codebooks
+
+    @property
+    def _num_inner_steps(self):
+        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
+        """
+        return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
+
+    def num_virtual_steps(self, timesteps: int) -> int:
+        return timesteps * self._num_inner_steps + 1
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        """Builds pattern for delay across codebooks.
+
+        Args:
+            timesteps (int): Total number of timesteps.
+        """
+        # the PatternLayout is built as a tuple of sequence position and list of coordinates
+        # so that it can be reordered properly given the required delay between codebooks of given timesteps
+        indexed_out: list = [(-1, [])]
+        max_timesteps = timesteps + self.max_delay
+        for t in range(max_timesteps):
+            # for each timestep, we unroll the flattened codebooks,
+            # emitting the sequence step with the corresponding delay
+            for step in range(self._num_inner_steps):
+                if step in self._flattened_codebooks:
+                    # we have codebooks at this virtual step to emit
+                    step_codebooks = self._flattened_codebooks[step]
+                    t_for_q = t + step_codebooks.delay
+                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+                    if t_for_q < max_timesteps and t < max_timesteps:
+                        indexed_out.append((t_for_q, coords))
+                else:
+                    # there is no codebook in this virtual step so we emit an empty list
+                    indexed_out.append((t, []))
+        out = [coords for _, coords in sorted(indexed_out)]
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class CoarseFirstPattern(CodebooksPatternProvider):
+    """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
+    potentially with delays.
+
+    ..Warning:: You must always generate the full training duration at test time, for instance,
+        30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
+        location. This is due to the non causality of the remaining codebooks with respect to
+        the first ones.
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (list of int, optional): Delay for each of the codebooks.
+            If delays not defined, each codebook is delayed by 1 compared to the previous one.
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if delays is None:
+            delays = [0] * (n_q - 1)
+        self.delays = delays
+        assert len(self.delays) == self.n_q - 1
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for t in range(timesteps):
+            out.append([LayoutCoord(t, 0)])
+        max_delay = max(self.delays)
+        for t in range(timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= 0:
+                    v.append(LayoutCoord(t_for_q, q + 1))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class MusicLMPattern(CodebooksPatternProvider):
+    """Almost MusicLM style pattern. This is equivalent to full flattening
+    but in a different order.
+
+    Args:
+        n_q (int): Number of codebooks.
+        group_by (int): Number of codebooks to group together.
+    """
+    def __init__(self, n_q: int, group_by: int = 2):
+        super().__init__(n_q)
+        self.group_by = group_by
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for offset in range(0, self.n_q, self.group_by):
+            for t in range(timesteps):
+                for q in range(offset, offset + self.group_by):
+                    out.append([LayoutCoord(t, q)])
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
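
To make the delay-pattern docstrings above concrete, here is a minimal round-trip sketch; the token values and the special token id are arbitrary, and the default delays of DelayedPatternProvider are [0, 1, ..., n_q - 1]:

import torch
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider

n_q, timesteps, special_token = 3, 4, 999
provider = DelayedPatternProvider(n_q=n_q)            # default delays = [0, 1, 2]
pattern = provider.get_pattern(timesteps)

z = torch.arange(1, 1 + n_q * timesteps).view(1, n_q, timesteps)   # [B=1, K=3, T=4]
values, indexes, mask = pattern.build_pattern_sequence(z, special_token)
# values: [1, 3, len(pattern.layout)]; codebook q is shifted right by q steps, with
# special_token filling the first step and the not-yet-valid positions.

reverted, _, rev_mask = pattern.revert_pattern_sequence(values, special_token)
valid = rev_mask.bool()[None]                         # [1, K, T], 1 where a token could be recovered
assert torch.equal(reverted[valid], z[valid])         # the round trip restores the original codes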
audiocraft/modules/conditioners.py ADDED
@@ -0,0 +1,1416 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from copy import deepcopy
9
+ from dataclasses import dataclass, field
10
+ from itertools import chain
11
+ import logging
12
+ import math
13
+ from pathlib import Path
14
+ import random
15
+ import re
16
+ import typing as tp
17
+ import warnings
18
+
19
+ import einops
20
+ from num2words import num2words
21
+ import spacy
22
+ from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ from torch.nn.utils.rnn import pad_sequence
27
+
28
+ from .chroma import ChromaExtractor
29
+ from .streaming import StreamingModule
30
+ from .transformer import create_sin_embedding
31
+ from ..data.audio import audio_read
32
+ from ..data.audio_dataset import SegmentInfo
33
+ from ..data.audio_utils import convert_audio
34
+ from ..environment import AudioCraftEnvironment
35
+ from ..quantization import ResidualVectorQuantizer
36
+ from ..utils.autocast import TorchAutocast
37
+ from ..utils.cache import EmbeddingCache
38
+ from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
39
+
40
+
41
+ logger = logging.getLogger(__name__)
42
+ TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
43
+ ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
44
+
45
+
46
+ class WavCondition(tp.NamedTuple):
47
+ wav: torch.Tensor
48
+ length: torch.Tensor
49
+ sample_rate: tp.List[int]
50
+ path: tp.List[tp.Optional[str]] = []
51
+ seek_time: tp.List[tp.Optional[float]] = []
52
+
53
+
54
+ class JointEmbedCondition(tp.NamedTuple):
55
+ wav: torch.Tensor
56
+ text: tp.List[tp.Optional[str]]
57
+ length: torch.Tensor
58
+ sample_rate: tp.List[int]
59
+ path: tp.List[tp.Optional[str]] = []
60
+ seek_time: tp.List[tp.Optional[float]] = []
61
+
62
+
63
+ @dataclass
64
+ class ConditioningAttributes:
65
+ text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
66
+ wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
67
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
68
+
69
+ def __getitem__(self, item):
70
+ return getattr(self, item)
71
+
72
+ @property
73
+ def text_attributes(self):
74
+ return self.text.keys()
75
+
76
+ @property
77
+ def wav_attributes(self):
78
+ return self.wav.keys()
79
+
80
+ @property
81
+ def joint_embed_attributes(self):
82
+ return self.joint_embed.keys()
83
+
84
+ @property
85
+ def attributes(self):
86
+ return {
87
+ "text": self.text_attributes,
88
+ "wav": self.wav_attributes,
89
+ "joint_embed": self.joint_embed_attributes,
90
+ }
91
+
92
+ def to_flat_dict(self):
93
+ return {
94
+ **{f"text.{k}": v for k, v in self.text.items()},
95
+ **{f"wav.{k}": v for k, v in self.wav.items()},
96
+ **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
97
+ }
98
+
99
+ @classmethod
100
+ def from_flat_dict(cls, x):
101
+ out = cls()
102
+ for k, v in x.items():
103
+ kind, att = k.split(".")
104
+ out[kind][att] = v
105
+ return out
106
+
107
+
108
+ class SegmentWithAttributes(SegmentInfo):
109
+ """Base class for all dataclasses that are used for conditioning.
110
+ All child classes should implement `to_condition_attributes` that converts
111
+ the existing attributes to a dataclass of type ConditioningAttributes.
112
+ """
113
+ def to_condition_attributes(self) -> ConditioningAttributes:
114
+ raise NotImplementedError()
115
+
116
+
117
+ def nullify_condition(condition: ConditionType, dim: int = 1):
118
+ """Transform an input condition to a null condition.
119
+ The way it is done by converting it to a single zero vector similarly
120
+ to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
121
+
122
+ Args:
123
+ condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
124
+ dim (int): The dimension that will be truncated (should be the time dimension)
125
+ WARNING!: dim should not be the batch dimension!
126
+ Returns:
127
+ ConditionType: A tuple of null condition and mask
128
+ """
129
+ assert dim != 0, "dim cannot be the batch dimension!"
130
+ assert isinstance(condition, tuple) and \
131
+ isinstance(condition[0], torch.Tensor) and \
132
+ isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
133
+ cond, mask = condition
134
+ B = cond.shape[0]
135
+ last_dim = cond.dim() - 1
136
+ out = cond.transpose(dim, last_dim)
137
+ out = 0. * out[..., :1]
138
+ out = out.transpose(dim, last_dim)
139
+ mask = torch.zeros((B, 1), device=out.device).int()
140
+ assert cond.dim() == out.dim()
141
+ return out, mask
142
+
143
+
144
+ def nullify_wav(cond: WavCondition) -> WavCondition:
145
+ """Transform a WavCondition to a nullified WavCondition.
146
+ It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
147
+
148
+ Args:
149
+ cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
150
+ Returns:
151
+ WavCondition: Nullified wav condition.
152
+ """
153
+ null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
154
+ return WavCondition(
155
+ wav=null_wav,
156
+ length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
157
+ sample_rate=cond.sample_rate,
158
+ path=[None] * cond.wav.shape[0],
159
+ seek_time=[None] * cond.wav.shape[0],
160
+ )
161
+
162
+
163
+ def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
164
+ """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
165
+ and replacing metadata by dummy attributes.
166
+
167
+ Args:
168
+ cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
169
+ """
170
+ null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
171
+ return JointEmbedCondition(
172
+ wav=null_wav, text=[None] * len(embed.text),
173
+ length=torch.LongTensor([0]).to(embed.wav.device),
174
+ sample_rate=embed.sample_rate,
175
+ path=[None] * embed.wav.shape[0],
176
+ seek_time=[0] * embed.wav.shape[0],
177
+ )
178
+
179
+
180
+ class Tokenizer:
181
+ """Base tokenizer implementation
182
+ (in case we want to introduce more advances tokenizers in the future).
183
+ """
184
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
185
+ raise NotImplementedError()
186
+
187
+
188
+ class WhiteSpaceTokenizer(Tokenizer):
189
+ """This tokenizer should be used for natural language descriptions.
190
+ For example:
191
+ ["he didn't, know he's going home.", 'shorter sentence'] =>
192
+ [[78, 62, 31, 4, 78, 25, 19, 34],
193
+ [59, 77, 0, 0, 0, 0, 0, 0]]
194
+ """
195
+ PUNCTUATION = "?:!.,;"
196
+
197
+ def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
198
+ lemma: bool = True, stopwords: bool = True) -> None:
199
+ self.n_bins = n_bins
200
+ self.pad_idx = pad_idx
201
+ self.lemma = lemma
202
+ self.stopwords = stopwords
203
+ try:
204
+ self.nlp = spacy.load(language)
205
+ except IOError:
206
+ spacy.cli.download(language) # type: ignore
207
+ self.nlp = spacy.load(language)
208
+
209
+ @tp.no_type_check
210
+ def __call__(self, texts: tp.List[tp.Optional[str]],
211
+ return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
212
+ """Take a list of strings and convert them to a tensor of indices.
213
+
214
+ Args:
215
+ texts (list[str]): List of strings.
216
+ return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
217
+ Returns:
218
+ tuple[torch.Tensor, torch.Tensor]:
219
+ - Indices of words in the LUT.
220
+ - And a mask indicating where the padding tokens are
221
+ """
222
+ output, lengths = [], []
223
+ texts = deepcopy(texts)
224
+ for i, text in enumerate(texts):
225
+ # if current sample doesn't have a certain attribute, replace with pad token
226
+ if text is None:
227
+ output.append(torch.Tensor([self.pad_idx]))
228
+ lengths.append(0)
229
+ continue
230
+
231
+ # convert numbers to words
232
+ text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore
233
+ # normalize text
234
+ text = self.nlp(text) # type: ignore
235
+ # remove stopwords
236
+ if self.stopwords:
237
+ text = [w for w in text if not w.is_stop] # type: ignore
238
+ # remove punctuation
239
+ text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore
240
+ # lemmatize if needed
241
+ text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
242
+
243
+ texts[i] = " ".join(text)
244
+ lengths.append(len(text))
245
+ # convert to tensor
246
+ tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
247
+ output.append(tokens)
248
+
249
+ mask = length_to_mask(torch.IntTensor(lengths)).int()
250
+ padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
251
+ if return_text:
252
+ return padded_output, mask, texts # type: ignore
253
+ return padded_output, mask
254
+
255
+
256
+ class NoopTokenizer(Tokenizer):
257
+ """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
258
+ The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
259
+ strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
260
+ split it to ["Jeff", "Buckley"] and return an index per word.
261
+
262
+ For example:
263
+ ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
264
+ ["Metal", "Rock", "Classical"] => [0, 223, 51]
265
+ """
266
+ def __init__(self, n_bins: int, pad_idx: int = 0):
267
+ self.n_bins = n_bins
268
+ self.pad_idx = pad_idx
269
+
270
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
271
+ output, lengths = [], []
272
+ for text in texts:
273
+ # if current sample doesn't have a certain attribute, replace with pad token
274
+ if text is None:
275
+ output.append(self.pad_idx)
276
+ lengths.append(0)
277
+ else:
278
+ output.append(hash_trick(text, self.n_bins))
279
+ lengths.append(1)
280
+
281
+ tokens = torch.LongTensor(output).unsqueeze(1)
282
+ mask = length_to_mask(torch.IntTensor(lengths)).int()
283
+ return tokens, mask
284
+
285
+
286
+ class BaseConditioner(nn.Module):
287
+ """Base model for all conditioner modules.
288
+ We allow the output dim to be different than the hidden dim for two reasons:
289
+ 1) keep our LUTs small when the vocab is large;
290
+ 2) make all condition dims consistent.
291
+
292
+ Args:
293
+ dim (int): Hidden dim of the model.
294
+ output_dim (int): Output dim of the conditioner.
295
+ """
296
+ def __init__(self, dim: int, output_dim: int):
297
+ super().__init__()
298
+ self.dim = dim
299
+ self.output_dim = output_dim
300
+ self.output_proj = nn.Linear(dim, output_dim)
301
+
302
+ def tokenize(self, *args, **kwargs) -> tp.Any:
303
+ """Should be any part of the processing that will lead to a synchronization
304
+ point, e.g. BPE tokenization with transfer to the GPU.
305
+
306
+ The returned value will be saved and return later when calling forward().
307
+ """
308
+ raise NotImplementedError()
309
+
310
+ def forward(self, inputs: tp.Any) -> ConditionType:
311
+ """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
312
+ Outputs a ConditionType, after the input data was embedded as a dense vector.
313
+
314
+ Returns:
315
+ ConditionType:
316
+ - A tensor of size [B, T, D] where B is the batch size, T is the length of the
317
+ output embedding and D is the dimension of the embedding.
318
+ - And a mask indicating where the padding tokens.
319
+ """
320
+ raise NotImplementedError()
321
+
322
+
323
+ class TextConditioner(BaseConditioner):
324
+ ...
325
+
326
+
327
+ class LUTConditioner(TextConditioner):
328
+ """Lookup table TextConditioner.
329
+
330
+ Args:
331
+ n_bins (int): Number of bins.
332
+ dim (int): Hidden dim of the model (text-encoder/LUT).
333
+ output_dim (int): Output dim of the conditioner.
334
+ tokenizer (str): Name of the tokenizer.
335
+ pad_idx (int, optional): Index for padding token. Defaults to 0.
336
+ """
337
+ def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
338
+ super().__init__(dim, output_dim)
339
+ self.embed = nn.Embedding(n_bins, dim)
340
+ self.tokenizer: Tokenizer
341
+ if tokenizer == 'whitespace':
342
+ self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
343
+ elif tokenizer == 'noop':
344
+ self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
345
+ else:
346
+ raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
347
+
348
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
349
+ device = self.embed.weight.device
350
+ tokens, mask = self.tokenizer(x)
351
+ tokens, mask = tokens.to(device), mask.to(device)
352
+ return tokens, mask
353
+
354
+ def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
355
+ tokens, mask = inputs
356
+ embeds = self.embed(tokens)
357
+ embeds = self.output_proj(embeds)
358
+ embeds = (embeds * mask.unsqueeze(-1))
359
+ return embeds, mask
360
+
361
+
362
+ class T5Conditioner(TextConditioner):
363
+ """T5-based TextConditioner.
364
+
365
+ Args:
366
+ name (str): Name of the T5 model.
367
+ output_dim (int): Output dim of the conditioner.
368
+ finetune (bool): Whether to fine-tune T5 at train time.
369
+ device (str): Device for T5 Conditioner.
370
+ autocast_dtype (tp.Optional[str], optional): Autocast dtype.
371
+ word_dropout (float, optional): Word dropout probability.
372
+ normalize_text (bool, optional): Whether to apply text normalization.
373
+ """
374
+ MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
375
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
376
+ "google/flan-t5-xl", "google/flan-t5-xxl"]
377
+ MODELS_DIMS = {
378
+ "t5-small": 512,
379
+ "t5-base": 768,
380
+ "t5-large": 1024,
381
+ "t5-3b": 1024,
382
+ "t5-11b": 1024,
383
+ "google/flan-t5-small": 512,
384
+ "google/flan-t5-base": 768,
385
+ "google/flan-t5-large": 1024,
386
+ "google/flan-t5-3b": 1024,
387
+ "google/flan-t5-11b": 1024,
388
+ }
389
+
390
+ def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
391
+ autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
392
+ normalize_text: bool = False):
393
+ assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
394
+ super().__init__(self.MODELS_DIMS[name], output_dim)
395
+ self.device = device
396
+ self.name = name
397
+ self.finetune = finetune
398
+ self.word_dropout = word_dropout
399
+ if autocast_dtype is None or self.device == 'cpu':
400
+ self.autocast = TorchAutocast(enabled=False)
401
+ if self.device != 'cpu':
402
+ logger.warning("T5 has no autocast, this might lead to NaN")
403
+ else:
404
+ dtype = getattr(torch, autocast_dtype)
405
+ assert isinstance(dtype, torch.dtype)
406
+ logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
407
+ self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
408
+ # Let's disable logging temporarily because T5 will vomit some errors otherwise.
409
+ # thanks https://gist.github.com/simon-weber/7853144
410
+ previous_level = logging.root.manager.disable
411
+ logging.disable(logging.ERROR)
412
+ with warnings.catch_warnings():
413
+ warnings.simplefilter("ignore")
414
+ try:
415
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
416
+ t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
417
+ finally:
418
+ logging.disable(previous_level)
419
+ if finetune:
420
+ self.t5 = t5
421
+ else:
422
+ # this makes sure that the t5 models is not part
423
+ # of the saved checkpoint
424
+ self.__dict__['t5'] = t5.to(device)
425
+
426
+ self.normalize_text = normalize_text
427
+ if normalize_text:
428
+ self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
429
+
430
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
431
+ # if current sample doesn't have a certain attribute, replace with empty string
432
+ entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
433
+ if self.normalize_text:
434
+ _, _, entries = self.text_normalizer(entries, return_text=True)
435
+ if self.word_dropout > 0. and self.training:
436
+ new_entries = []
437
+ for entry in entries:
438
+ words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
439
+ new_entries.append(" ".join(words))
440
+ entries = new_entries
441
+
442
+ empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
443
+
444
+ inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
445
+ mask = inputs['attention_mask']
446
+ mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
447
+ return inputs
448
+
449
+ def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
450
+ mask = inputs['attention_mask']
451
+ with torch.set_grad_enabled(self.finetune), self.autocast:
452
+ embeds = self.t5(**inputs).last_hidden_state
453
+ embeds = self.output_proj(embeds.to(self.output_proj.weight))
454
+ embeds = (embeds * mask.unsqueeze(-1))
455
+ return embeds, mask
456
+
457
+
458
+ class WaveformConditioner(BaseConditioner):
459
+ """Base class for all conditioners that take a waveform as input.
460
+ Classes that inherit must implement `_get_wav_embedding` that outputs
461
+ a continuous tensor, and `_downsampling_factor` that returns the down-sampling
462
+ factor of the embedding model.
463
+
464
+ Args:
465
+ dim (int): The internal representation dimension.
466
+ output_dim (int): Output dimension.
467
+ device (tp.Union[torch.device, str]): Device.
468
+ """
469
+ def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
470
+ super().__init__(dim, output_dim)
471
+ self.device = device
472
+ # if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample.
473
+ self._use_masking = True
474
+
475
+ def tokenize(self, x: WavCondition) -> WavCondition:
476
+ wav, length, sample_rate, path, seek_time = x
477
+ assert length is not None
478
+ return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
479
+
480
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
481
+ """Gets as input a WavCondition and returns a dense embedding."""
482
+ raise NotImplementedError()
483
+
484
+ def _downsampling_factor(self):
485
+ """Returns the downsampling factor of the embedding model."""
486
+ raise NotImplementedError()
487
+
488
+ def forward(self, x: WavCondition) -> ConditionType:
489
+ """Extract condition embedding and mask from a waveform and its metadata.
490
+ Args:
491
+ x (WavCondition): Waveform condition containing raw waveform and metadata.
492
+ Returns:
493
+ ConditionType: a dense vector representing the conditioning along with its mask
494
+ """
495
+ wav, lengths, *_ = x
496
+ with torch.no_grad():
497
+ embeds = self._get_wav_embedding(x)
498
+ embeds = embeds.to(self.output_proj.weight)
499
+ embeds = self.output_proj(embeds)
500
+
501
+ if lengths is not None and self._use_masking:
502
+ lengths = lengths / self._downsampling_factor()
503
+ mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
504
+ else:
505
+ mask = torch.ones_like(embeds[..., 0])
506
+ embeds = (embeds * mask.unsqueeze(-1))
507
+ return embeds, mask
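To make the contract above concrete, here is a hypothetical minimal subclass (the name, the frame-energy embedding and the example shapes are made up for illustration); it assumes `BaseConditioner` exposes the `output_proj` linear projection used by `forward` above and that `length_to_mask` builds a [B, T] mask from per-item lengths:

class FramePoolConditioner(WaveformConditioner):  # hypothetical example, not part of the codebase
    def __init__(self, dim: int, output_dim: int, frame_size: int, device: str = 'cpu'):
        super().__init__(dim, output_dim, device)
        self.frame_size = frame_size

    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
        wav = x.wav.mean(dim=1)  # [B, T] mono
        B, T = wav.shape
        T = (T // self.frame_size) * self.frame_size
        frames = wav[:, :T].reshape(B, -1, self.frame_size)  # [B, F, frame_size]
        # per-frame energy, broadcast to the internal dimension
        return frames.pow(2).mean(dim=-1, keepdim=True).expand(-1, -1, self.dim)

    def _downsampling_factor(self):
        return self.frame_size

conditioner = FramePoolConditioner(dim=8, output_dim=64, frame_size=640)
cond = WavCondition(torch.randn(2, 1, 32000), torch.tensor([32000, 16000]),
                    [16000, 16000], [None, None], [0.0, 0.0])
embeds, mask = conditioner(conditioner.tokenize(cond))  # embeds: [2, 50, 64], mask: [2, 50]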
508
+
509
+
510
+ class ChromaStemConditioner(WaveformConditioner):
511
+ """Chroma conditioner based on stems.
512
+ The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
513
+ the drums and bass often dominate the chroma leading to the chroma features
514
+ not containing information about the melody.
515
+
516
+ Args:
517
+ output_dim (int): Output dimension for the conditioner.
518
+ sample_rate (int): Sample rate for the chroma extractor.
519
+ n_chroma (int): Number of chroma bins for the chroma extractor.
520
+ radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
521
+ duration (float): duration used during training. This is later used for correct padding
522
+ in case we are using chroma as prefix.
523
+ match_len_on_eval (bool, optional): if True then all chromas are padded to the training
524
+ duration. Defaults to True.
525
+ eval_wavs (str, optional): path to a dataset manifest with waveforms; these waveforms are used as
526
+ conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
527
+ Defaults to None.
528
+ n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
529
+ device (tp.Union[torch.device, str], optional): Device for the conditioner.
530
+ **kwargs: Additional parameters for the chroma extractor.
531
+ """
532
+ def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
533
+ duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
534
+ n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
535
+ device: tp.Union[torch.device, str] = 'cpu', **kwargs):
536
+ from demucs import pretrained
537
+ super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
538
+ self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
539
+ self.sample_rate = sample_rate
540
+ self.match_len_on_eval = match_len_on_eval
541
+ if match_len_on_eval:
542
+ self._use_masking = False
543
+ self.duration = duration
544
+ self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
545
+ stem_sources: list = self.demucs.sources # type: ignore
546
+ self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
547
+ self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
548
+ radix2_exp=radix2_exp, **kwargs).to(device)
549
+ self.chroma_len = self._get_chroma_len()
550
+ self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
551
+ self.cache = None
552
+ if cache_path is not None:
553
+ self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
554
+ compute_embed_fn=self._get_full_chroma_for_cache,
555
+ extract_embed_fn=self._extract_chroma_chunk)
556
+
557
+ def _downsampling_factor(self) -> int:
558
+ return self.chroma.winhop
559
+
560
+ def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
561
+ """Load pre-defined waveforms from a json.
562
+ These waveforms will be used for chroma extraction during evaluation.
563
+ This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
564
+ """
565
+ if path is None:
566
+ return None
567
+
568
+ logger.info(f"Loading evaluation wavs from {path}")
569
+ from audiocraft.data.audio_dataset import AudioDataset
570
+ dataset: AudioDataset = AudioDataset.from_meta(
571
+ path, segment_duration=self.duration, min_audio_duration=self.duration,
572
+ sample_rate=self.sample_rate, channels=1)
573
+
574
+ if len(dataset) > 0:
575
+ eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
576
+ logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
577
+ return eval_wavs
578
+ else:
579
+ raise ValueError("Could not find evaluation wavs, check lengths of wavs")
580
+
581
+ def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
582
+ self.eval_wavs = eval_wavs
583
+
584
+ def has_eval_wavs(self) -> bool:
585
+ return self.eval_wavs is not None
586
+
587
+ def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
588
+ """Sample wavs from a predefined list."""
589
+ assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
590
+ total_eval_wavs = len(self.eval_wavs)
591
+ out = self.eval_wavs
592
+ if num_samples > total_eval_wavs:
593
+ out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
594
+ return out[torch.randperm(len(out))][:num_samples]
595
+
596
+ def _get_chroma_len(self) -> int:
597
+ """Get length of chroma during training."""
598
+ dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
599
+ dummy_chr = self.chroma(dummy_wav)
600
+ return dummy_chr.shape[1]
601
+
602
+ @torch.no_grad()
603
+ def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
604
+ """Get parts of the wav that holds the melody, extracting the main stems from the wav."""
605
+ from demucs.apply import apply_model
606
+ from demucs.audio import convert_audio
607
+ with self.autocast:
608
+ wav = convert_audio(
609
+ wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
610
+ stems = apply_model(self.demucs, wav, device=self.device)
611
+ stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
612
+ mix_wav = stems.sum(1) # merge extracted stems to single waveform
613
+ mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
614
+ return mix_wav
615
+
616
+ @torch.no_grad()
617
+ def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
618
+ """Extract chroma features from the waveform."""
619
+ with self.autocast:
620
+ return self.chroma(wav)
621
+
622
+ @torch.no_grad()
623
+ def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
624
+ """Compute wav embedding, applying stem and chroma extraction."""
625
+ # avoid 0-size tensors when we are working with null conds
626
+ if wav.shape[-1] == 1:
627
+ return self._extract_chroma(wav)
628
+ stems = self._get_stemmed_wav(wav, sample_rate)
629
+ chroma = self._extract_chroma(stems)
630
+ return chroma
631
+
632
+ @torch.no_grad()
633
+ def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
634
+ """Extract chroma from the whole audio waveform at the given path."""
635
+ wav, sr = audio_read(path)
636
+ wav = wav[None].to(self.device)
637
+ wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
638
+ chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
639
+ return chroma
640
+
641
+ def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
642
+ """Extract a chunk of chroma from the full chroma derived from the full waveform."""
643
+ wav_length = x.wav.shape[-1]
644
+ seek_time = x.seek_time[idx]
645
+ assert seek_time is not None, (
646
+ "WavCondition seek_time is required "
647
+ "when extracting chroma chunks from pre-computed chroma.")
648
+ full_chroma = full_chroma.float()
649
+ frame_rate = self.sample_rate / self._downsampling_factor()
650
+ target_length = int(frame_rate * wav_length / self.sample_rate)
651
+ index = int(frame_rate * seek_time)
652
+ out = full_chroma[index: index + target_length]
653
+ out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
654
+ return out.to(self.device)
655
+
656
+ @torch.no_grad()
657
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
658
+ """Get the wav embedding from the WavCondition.
659
+ The conditioner will either extract the embedding on-the-fly, computing it from the condition wav directly,
660
+ or will rely on the embedding cache to load the pre-computed embedding if relevant.
661
+ """
662
+ sampled_wav: tp.Optional[torch.Tensor] = None
663
+ if not self.training and self.eval_wavs is not None:
664
+ warn_once(logger, "Using precomputed evaluation wavs!")
665
+ sampled_wav = self._sample_eval_wavs(len(x.wav))
666
+
667
+ no_undefined_paths = all(p is not None for p in x.path)
668
+ no_nullified_cond = x.wav.shape[-1] > 1
669
+ if sampled_wav is not None:
670
+ chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
671
+ elif self.cache is not None and no_undefined_paths and no_nullified_cond:
672
+ paths = [Path(p) for p in x.path if p is not None]
673
+ chroma = self.cache.get_embed_from_cache(paths, x)
674
+ else:
675
+ assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
676
+ chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
677
+
678
+ if self.match_len_on_eval:
679
+ B, T, C = chroma.shape
680
+ if T > self.chroma_len:
681
+ chroma = chroma[:, :self.chroma_len]
682
+ logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
683
+ elif T < self.chroma_len:
684
+ n_repeat = int(math.ceil(self.chroma_len / T))
685
+ chroma = chroma.repeat(1, n_repeat, 1)
686
+ chroma = chroma[:, :self.chroma_len]
687
+ logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
688
+
689
+ return chroma
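The repeat-and-trim logic above is easy to check in isolation; the shapes below are arbitrary example values:

import math
import torch

chroma_len = 50                    # target length seen at training time
chroma = torch.randn(1, 30, 12)    # [B, T, C] with T shorter than chroma_len
n_repeat = int(math.ceil(chroma_len / chroma.shape[1]))    # 2
chroma = chroma.repeat(1, n_repeat, 1)[:, :chroma_len]     # [1, 50, 12]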
690
+
691
+ def tokenize(self, x: WavCondition) -> WavCondition:
692
+ """Apply WavConditioner tokenization and populate cache if needed."""
693
+ x = super().tokenize(x)
694
+ no_undefined_paths = all(p is not None for p in x.path)
695
+ if self.cache is not None and no_undefined_paths:
696
+ paths = [Path(p) for p in x.path if p is not None]
697
+ self.cache.populate_embed_cache(paths, x)
698
+ return x
699
+
700
+
701
+ class JointEmbeddingConditioner(BaseConditioner):
702
+ """Joint embedding conditioning supporting both audio or text conditioning.
703
+
704
+ Args:
705
+ dim (int): Dimension.
706
+ output_dim (int): Output dimension.
707
+ device (str): Device.
708
+ attribute (str): Attribute used by the conditioner.
709
+ autocast_dtype (str): Autocast for the conditioner.
710
+ quantize (bool): Whether to quantize the CLAP embedding.
711
+ n_q (int): Number of residual quantizers (used if quantize is true).
712
+ bins (int): Quantizers' codebooks size (used if quantize is true).
713
+ kwargs: Additional parameters for residual vector quantizer.
714
+ """
715
+ def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
716
+ autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
717
+ n_q: int = 12, bins: int = 1024, **kwargs):
718
+ super().__init__(dim=dim, output_dim=output_dim)
719
+ self.device = device
720
+ self.attribute = attribute
721
+ if autocast_dtype is None or device == 'cpu':
722
+ self.autocast = TorchAutocast(enabled=False)
723
+ logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
724
+ else:
725
+ dtype = getattr(torch, autocast_dtype)
726
+ assert isinstance(dtype, torch.dtype)
727
+ logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
728
+ self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
729
+ # residual vector quantizer to discretize the conditioned embedding
730
+ self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
731
+ if quantize:
732
+ self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
733
+
734
+ def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
735
+ """Get joint embedding in latent space from the inputs.
736
+
737
+ Returns:
738
+ tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
739
+ and corresponding empty indexes.
740
+ """
741
+ raise NotImplementedError()
742
+
743
+ def forward(self, x: JointEmbedCondition) -> ConditionType:
744
+ with self.autocast:
745
+ embed, empty_idx = self._get_embed(x)
746
+ if self.quantizer is not None:
747
+ embed = embed.view(-1, self.dim, 1)
748
+ q_res = self.quantizer(embed, frame_rate=1)
749
+ out_embed = q_res.x.view(-1, self.dim)
750
+ else:
751
+ out_embed = embed
752
+ out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
753
+ mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
754
+ mask[empty_idx, :] = 0  # zero out indices where the input is non-existent
755
+ out_embed = (out_embed * mask.unsqueeze(-1))
756
+ return out_embed, mask
757
+
758
+ def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
759
+ return x
760
+
761
+
762
+ class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
763
+ """Joint Embedding conditioner based on pre-trained CLAP model.
764
+
765
+ This CLAP-based conditioner supports a caching mechanism
766
+ over the computed embeddings for faster training.
767
+
768
+ Args:
769
+ dim (int): Dimension.
770
+ output_dim (int): Output dimension.
771
+ device (str): Device.
772
+ attribute (str): Attribute used by the conditioner.
773
+ quantize (bool): Whether to quantize the CLAP embedding.
774
+ n_q (int): Number of residual quantizers (used if quantize is true).
775
+ bins (int): Quantizers' codebooks size (used if quantize is true).
776
+ checkpoint (str): Path to CLAP checkpoint.
777
+ model_arch (str): CLAP model architecture.
778
+ enable_fusion (bool): Enable fusion for CLAP model.
779
+ sample_rate (int): Sample rate used by CLAP model.
780
+ max_audio_length (float): Maximum audio length for CLAP model.
781
+ audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
782
+ normalize (bool): Whether to normalize the CLAP embedding.
783
+ text_p (float): Probability of using text representation instead of audio at train time.
784
+ batch_size (Optional[int]): Batch size for CLAP embedding computation.
785
+ autocast_dtype (str): Autocast for the conditioner.
786
+ cache_path (Optional[str]): Path for pre-computed embeddings caching.
787
+ kwargs: Additional parameters for residual vector quantizer.
788
+ """
789
+ def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
790
+ quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
791
+ enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
792
+ normalize: bool, text_p: float, batch_size: tp.Optional[int] = None,
793
+ autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
794
+ try:
795
+ import laion_clap # type: ignore
796
+ except ImportError:
797
+ raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
798
+ warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
799
+ "Please retrain all models.")
800
+ checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
801
+ clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
802
+ clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
803
+ load_clap_state_dict(clap_model, checkpoint)
804
+ clap_model.eval()
805
+ clap_model.to(device)
806
+ super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
807
+ autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
808
+ **kwargs)
809
+ self.checkpoint = checkpoint
810
+ self.enable_fusion = enable_fusion
811
+ self.model_arch = model_arch
812
+ self.clap: laion_clap.CLAP_Module
813
+ self.clap_tokenize: RobertaTokenizer
814
+ self.clap_sample_rate = sample_rate
815
+ self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
816
+ self.clap_stride = int(self.clap_sample_rate * audio_stride)
817
+ self.batch_size = batch_size or 1
818
+ self.normalize = normalize
819
+ self.text_p = text_p
820
+ self.__dict__['clap_tokenize'] = clap_tokenize
821
+ self.__dict__['clap'] = clap_model
822
+ self.wav_cache, self.text_cache = None, None
823
+ if cache_path is not None:
824
+ self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
825
+ compute_embed_fn=self._get_wav_embedding_for_cache,
826
+ extract_embed_fn=self._extract_wav_embedding_chunk)
827
+ self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
828
+ compute_embed_fn=self._get_text_embedding_for_cache)
829
+
830
+ def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
831
+ # we use the default params from CLAP module here as well
832
+ return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
833
+
834
+ def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
835
+ """Compute text embedding from CLAP model on a given a batch of text.
836
+
837
+ Args:
838
+ text (list[str]): List of text for the batch, with B items.
839
+ Returns:
840
+ torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
841
+ """
842
+ with torch.no_grad():
843
+ embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
844
+ return embed.view(embed.size(0), 1, embed.size(-1))
845
+
846
+ def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
847
+ x: JointEmbedCondition, idx: int) -> torch.Tensor:
848
+ """Get text embedding function for the cache."""
849
+ text = x.text[idx]
850
+ text = text if text is not None else ""
851
+ return self._compute_text_embedding([text])[0]
852
+
853
+ def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
854
+ """Preprocess wav to expected format by CLAP model.
855
+
856
+ Args:
857
+ wav (torch.Tensor): Audio wav, of shape [B, C, T].
858
+ length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
859
+ sample_rates (list[int]): Sample rates for each sample in the batch
860
+ Returns:
861
+ torch.Tensor: Audio wav of shape [B, T].
862
+ """
863
+ assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
864
+ if sample_rates is not None:
865
+ _wav = []
866
+ for i, audio in enumerate(wav):
867
+ sr = sample_rates[i]
868
+ audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
869
+ _wav.append(audio)
870
+ wav = torch.stack(_wav, dim=0)
871
+ wav = wav.mean(dim=1)
872
+ return wav
873
+
874
+ def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
875
+ sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
876
+ """Compute audio wave embedding from CLAP model.
877
+
878
+ Since CLAP operates on fixed-length audio inputs and we need to process longer audio sequences,
879
+ we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
880
+ average the resulting embeddings.
881
+
882
+ Args:
883
+ wav (torch.Tensor): Audio wav, of shape [B, C, T].
884
+ length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
885
+ sample_rates (list[int]): Sample rates for each sample in the batch.
886
+ reduce_mean (bool): Whether to get the average tensor.
887
+ Returns:
888
+ torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
889
+ """
890
+ with torch.no_grad():
891
+ wav = self._preprocess_wav(wav, length, sample_rates)
892
+ B, T = wav.shape
893
+ if T >= self.clap_max_frames:
894
+ wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
895
+ else:
896
+ wav = wav.view(-1, 1, T) # [B, F, T] with F=1
897
+ wav = einops.rearrange(wav, 'b f t -> (b f) t')
898
+ embed_list = []
899
+ for i in range(0, wav.size(0), self.batch_size):
900
+ _wav = wav[i:i+self.batch_size, ...]
901
+ _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
902
+ embed_list.append(_embed)
903
+ embed = torch.cat(embed_list, dim=0)
904
+ embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
905
+ if reduce_mean:
906
+ embed = embed.mean(dim=1, keepdim=True)
907
+ return embed # [B, F, D] with F=1 if reduce_mean is True
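The windowing arithmetic above can be reproduced with plain tensors; the window and stride values below are arbitrary examples, not the defaults of any released model:

import torch

sample_rate = 48_000
clap_max_frames = 10 * sample_rate     # 10-second window (example value)
clap_stride = 5 * sample_rate          # 5-second hop (example value)
wav = torch.randn(2, 30 * sample_rate)                    # [B, T], 30 seconds per item
chunks = wav.unfold(-1, clap_max_frames, clap_stride)     # [B, F, clap_max_frames]
assert chunks.shape[1] == (wav.shape[-1] - clap_max_frames) // clap_stride + 1  # F = 5
# each chunk is embedded independently and the F embeddings are then averaged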
908
+
909
+ def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
910
+ x: JointEmbedCondition, idx: int) -> torch.Tensor:
911
+ """Compute audio wave embedding for the cache.
912
+ The embedding is computed on a given audio read from file.
913
+
914
+ Args:
915
+ path (str or Path): Path to the full audio file.
916
+ Returns:
917
+ torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
918
+ """
919
+ wav, sr = audio_read(path) # [C, T]
920
+ wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
921
+ wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
922
+ embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
923
+ return embed.squeeze(0) # [F, D]
924
+
925
+ def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
926
+ """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
927
+
928
+ Args:
929
+ full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
930
+ x (JointEmbedCondition): Joint embedding condition for the full batch.
931
+ idx (int): Index considered for the given embedding to extract.
932
+ Returns:
933
+ torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
934
+ """
935
+ sample_rate = x.sample_rate[idx]
936
+ seek_time = x.seek_time[idx]
937
+ seek_time = 0. if seek_time is None else seek_time
938
+ clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
939
+ end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
940
+ start_offset = int(seek_time * sample_rate // clap_stride)
941
+ end_offset = int(end_seek_time * sample_rate // clap_stride)
942
+ wav_embed = full_embed[start_offset:end_offset, ...]
943
+ wav_embed = wav_embed.mean(dim=0, keepdim=True)
944
+ return wav_embed.to(self.device) # [F, D]
945
+
946
+ def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
947
+ """Get CLAP embedding from a batch of text descriptions."""
948
+ no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from the cache when the condition was dropped out
949
+ if self.text_cache is not None and no_nullified_cond:
950
+ assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
951
+ paths = [Path(p) for p in x.path if p is not None]
952
+ embed = self.text_cache.get_embed_from_cache(paths, x)
953
+ else:
954
+ text = [xi if xi is not None else "" for xi in x.text]
955
+ embed = self._compute_text_embedding(text)
956
+ if self.normalize:
957
+ embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
958
+ return embed
959
+
960
+ def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
961
+ """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
962
+ no_undefined_paths = all(p is not None for p in x.path)
963
+ no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from the cache when the condition was dropped out
964
+ if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
965
+ paths = [Path(p) for p in x.path if p is not None]
966
+ embed = self.wav_cache.get_embed_from_cache(paths, x)
967
+ else:
968
+ embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
969
+ if self.normalize:
970
+ embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
971
+ return embed
972
+
973
+ def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
974
+ # Try to limit sync points as much as possible when the cache is warm.
975
+ no_undefined_paths = all(p is not None for p in x.path)
976
+ if self.wav_cache is not None and no_undefined_paths:
977
+ assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
978
+ paths = [Path(p) for p in x.path if p is not None]
979
+ self.wav_cache.populate_embed_cache(paths, x)
980
+ if self.text_cache is not None and no_undefined_paths:
981
+ assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
982
+ paths = [Path(p) for p in x.path if p is not None]
983
+ self.text_cache.populate_embed_cache(paths, x)
984
+ return x
985
+
986
+ def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
987
+ """Extract shared latent representation from either the wav or the text using CLAP."""
988
+ # decide whether to use text embedding at train time or not
989
+ use_text_embed = random.random() < self.text_p
990
+ if self.training and not use_text_embed:
991
+ embed = self._get_wav_embedding(x)
992
+ empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
993
+ else:
994
+ embed = self._get_text_embedding(x)
995
+ empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
996
+ return embed, empty_idx
997
+
998
+
999
+ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
1000
+ """Utility function for nullifying an attribute inside an ConditioningAttributes object.
1001
+ If the condition is of type "wav" or "joint_embed", nullify it with the corresponding helper
1002
+ (`nullify_wav` / `nullify_joint_embed`). If the condition is of any other type, set its value to None.
1003
+ Works in-place.
1004
+ """
1005
+ if condition_type not in ['text', 'wav', 'joint_embed']:
1006
+ raise ValueError(
1007
+ "dropout_condition got an unexpected condition type!"
1008
+ f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
1009
+ )
1010
+
1011
+ if condition not in getattr(sample, condition_type):
1012
+ raise ValueError(
1013
+ "dropout_condition received an unexpected condition!"
1014
+ f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
1015
+ f" but got '{condition}' of type '{condition_type}'!"
1016
+ )
1017
+
1018
+ if condition_type == 'wav':
1019
+ wav_cond = sample.wav[condition]
1020
+ sample.wav[condition] = nullify_wav(wav_cond)
1021
+ elif condition_type == 'joint_embed':
1022
+ embed = sample.joint_embed[condition]
1023
+ sample.joint_embed[condition] = nullify_joint_embed(embed)
1024
+ else:
1025
+ sample.text[condition] = None
1026
+
1027
+ return sample
1028
+
1029
+
1030
+ class DropoutModule(nn.Module):
1031
+ """Base module for all dropout modules."""
1032
+ def __init__(self, seed: int = 1234):
1033
+ super().__init__()
1034
+ self.rng = torch.Generator()
1035
+ self.rng.manual_seed(seed)
1036
+
1037
+
1038
+ class AttributeDropout(DropoutModule):
1039
+ """Dropout with a given probability per attribute.
1040
+ This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
1041
+ to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
1042
+ This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
1043
+ must also be dropped.
1044
+
1045
+ Args:
1046
+ p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
1047
+ ...
1048
+ "genre": 0.1,
1049
+ "artist": 0.5,
1050
+ "wav": 0.25,
1051
+ ...
1052
+ active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
1053
+ seed (int, optional): Random seed.
1054
+ """
1055
+ def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
1056
+ super().__init__(seed=seed)
1057
+ self.active_on_eval = active_on_eval
1058
+ # construct dicts that return the values from p, and 0 otherwise
1059
+ self.p = {}
1060
+ for condition_type, probs in p.items():
1061
+ self.p[condition_type] = defaultdict(lambda: 0, probs)
1062
+
1063
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
1064
+ """
1065
+ Args:
1066
+ samples (list[ConditioningAttributes]): List of conditions.
1067
+ Returns:
1068
+ list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
1069
+ """
1070
+ if not self.training and not self.active_on_eval:
1071
+ return samples
1072
+
1073
+ samples = deepcopy(samples)
1074
+ for condition_type, ps in self.p.items(): # for condition types [text, wav]
1075
+ for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
1076
+ if torch.rand(1, generator=self.rng).item() < p:
1077
+ for sample in samples:
1078
+ dropout_condition(sample, condition_type, condition)
1079
+ return samples
1080
+
1081
+ def __repr__(self):
1082
+ return f"AttributeDropout({dict(self.p)})"
1083
+
1084
+
1085
+ class ClassifierFreeGuidanceDropout(DropoutModule):
1086
+ """Classifier Free Guidance dropout.
1087
+ All attributes are dropped with the same probability.
1088
+
1089
+ Args:
1090
+ p (float): Probability to apply condition dropout during training.
1091
+ seed (int): Random seed.
1092
+ """
1093
+ def __init__(self, p: float, seed: int = 1234):
1094
+ super().__init__(seed=seed)
1095
+ self.p = p
1096
+
1097
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
1098
+ """
1099
+ Args:
1100
+ samples (list[ConditioningAttributes]): List of conditions.
1101
+ Returns:
1102
+ list[ConditioningAttributes]: List of conditions after all attributes were set to None.
1103
+ """
1104
+ if not self.training:
1105
+ return samples
1106
+
1107
+ # decide on which attributes to drop in a batched fashion
1108
+ drop = torch.rand(1, generator=self.rng).item() < self.p
1109
+ if not drop:
1110
+ return samples
1111
+
1112
+ # nullify conditions of all attributes
1113
+ samples = deepcopy(samples)
1114
+ for condition_type in ["wav", "text"]:
1115
+ for sample in samples:
1116
+ for condition in sample.attributes[condition_type]:
1117
+ dropout_condition(sample, condition_type, condition)
1118
+ return samples
1119
+
1120
+ def __repr__(self):
1121
+ return f"ClassifierFreeGuidanceDropout(p={self.p})"
1122
+
1123
+
1124
+ class ConditioningProvider(nn.Module):
1125
+ """Prepare and provide conditions given all the supported conditioners.
1126
+
1127
+ Args:
1128
+ conditioners (dict): Dictionary of conditioners.
1129
+ device (torch.device or str, optional): Device for conditioners and output condition types.
1130
+ """
1131
+ def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
1132
+ super().__init__()
1133
+ self.device = device
1134
+ self.conditioners = nn.ModuleDict(conditioners)
1135
+
1136
+ @property
1137
+ def joint_embed_conditions(self):
1138
+ return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
1139
+
1140
+ @property
1141
+ def has_joint_embed_conditions(self):
1142
+ return len(self.joint_embed_conditions) > 0
1143
+
1144
+ @property
1145
+ def text_conditions(self):
1146
+ return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
1147
+
1148
+ @property
1149
+ def wav_conditions(self):
1150
+ return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
1151
+
1152
+ @property
1153
+ def has_wav_condition(self):
1154
+ return len(self.wav_conditions) > 0
1155
+
1156
+ def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
1157
+ """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
1158
+ This should be called before starting any real GPU work to avoid synchronization points.
1159
+ This will return a dict matching conditioner names to their arbitrary tokenized representations.
1160
+
1161
+ Args:
1162
+ inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
1163
+ text and wav conditions.
1164
+ """
1165
+ assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
1166
+ "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
1167
+ f" but types were {set([type(x) for x in inputs])}"
1168
+ )
1169
+
1170
+ output = {}
1171
+ text = self._collate_text(inputs)
1172
+ wavs = self._collate_wavs(inputs)
1173
+ joint_embeds = self._collate_joint_embeds(inputs)
1174
+
1175
+ assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
1176
+ f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
1177
+ f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
1178
+ )
1179
+
1180
+ for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
1181
+ output[attribute] = self.conditioners[attribute].tokenize(batch)
1182
+ return output
1183
+
1184
+ def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
1185
+ """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
1186
+ The output is for example:
1187
+ {
1188
+ "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
1189
+ "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
1190
+ ...
1191
+ }
1192
+
1193
+ Args:
1194
+ tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
1195
+ """
1196
+ output = {}
1197
+ for attribute, inputs in tokenized.items():
1198
+ condition, mask = self.conditioners[attribute](inputs)
1199
+ output[attribute] = (condition, mask)
1200
+ return output
1201
+
1202
+ def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
1203
+ """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
1204
+ are the attributes and the values are the aggregated input per attribute.
1205
+ For example:
1206
+ Input:
1207
+ [
1208
+ ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
1209
+ ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
1210
+ ]
1211
+ Output:
1212
+ {
1213
+ "genre": ["Rock", "Hip-hop"],
1214
+ "description": ["A rock song with a guitar solo", "A hip-hop verse"]
1215
+ }
1216
+
1217
+ Args:
1218
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
1219
+ Returns:
1220
+ dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
1221
+ """
1222
+ out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
1223
+ texts = [x.text for x in samples]
1224
+ for text in texts:
1225
+ for condition in self.text_conditions:
1226
+ out[condition].append(text[condition])
1227
+ return out
1228
+
1229
+ def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
1230
+ """Generate a dict where the keys are attributes by which we fetch similar wavs,
1231
+ and the values are Tensors of wavs according to said attributes.
1232
+
1233
+ *Note*: by the time the samples reach this function, each sample should have some waveform
1234
+ inside the "wav" attribute. It should be either:
1235
+ 1. A real waveform
1236
+ 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
1237
+ 3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
1238
+
1239
+ Args:
1240
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
1241
+ Returns:
1242
+ dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
1243
+ """
1244
+ wavs = defaultdict(list)
1245
+ lengths = defaultdict(list)
1246
+ sample_rates = defaultdict(list)
1247
+ paths = defaultdict(list)
1248
+ seek_times = defaultdict(list)
1249
+ out: tp.Dict[str, WavCondition] = {}
1250
+
1251
+ for sample in samples:
1252
+ for attribute in self.wav_conditions:
1253
+ wav, length, sample_rate, path, seek_time = sample.wav[attribute]
1254
+ assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
1255
+ assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
1256
+ # mono-channel conditioning
1257
+ wav = wav.mean(1, keepdim=True) # [1, 1, T]
1258
+ wavs[attribute].append(wav.flatten()) # [T]
1259
+ lengths[attribute].append(length)
1260
+ sample_rates[attribute].extend(sample_rate)
1261
+ paths[attribute].extend(path)
1262
+ seek_times[attribute].extend(seek_time)
1263
+
1264
+ # stack all wavs to a single tensor
1265
+ for attribute in self.wav_conditions:
1266
+ stacked_wav, _ = collate(wavs[attribute], dim=0)
1267
+ out[attribute] = WavCondition(
1268
+ stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
1269
+ paths[attribute], seek_times[attribute])
1270
+
1271
+ return out
1272
+
1273
+ def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
1274
+ """Generate a dict where the keys are attributes by which we compute joint embeddings,
1275
+ and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
1276
+
1277
+ Args:
1278
+ samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
1279
+ Returns:
1280
+ A dictionary mapping an attribute name to joint embeddings.
1281
+ """
1282
+ texts = defaultdict(list)
1283
+ wavs = defaultdict(list)
1284
+ lengths = defaultdict(list)
1285
+ sample_rates = defaultdict(list)
1286
+ paths = defaultdict(list)
1287
+ seek_times = defaultdict(list)
1288
+ channels: int = 0
1289
+
1290
+ out = {}
1291
+ for sample in samples:
1292
+ for attribute in self.joint_embed_conditions:
1293
+ wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
1294
+ assert wav.dim() == 3
1295
+ if channels == 0:
1296
+ channels = wav.size(1)
1297
+ else:
1298
+ assert channels == wav.size(1), "not all audio has the same number of channels in the batch"
1299
+ assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
1300
+ wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T]
1301
+ wavs[attribute].append(wav)
1302
+ texts[attribute].extend(text)
1303
+ lengths[attribute].append(length)
1304
+ sample_rates[attribute].extend(sample_rate)
1305
+ paths[attribute].extend(path)
1306
+ seek_times[attribute].extend(seek_time)
1307
+
1308
+ for attribute in self.joint_embed_conditions:
1309
+ stacked_texts = texts[attribute]
1310
+ stacked_paths = paths[attribute]
1311
+ stacked_seek_times = seek_times[attribute]
1312
+ stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
1313
+ stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
1314
+ stacked_sample_rates = sample_rates[attribute]
1315
+ stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
1316
+ assert stacked_lengths.size(0) == stacked_wavs.size(0)
1317
+ assert len(stacked_sample_rates) == stacked_wavs.size(0)
1318
+ assert len(stacked_texts) == stacked_wavs.size(0)
1319
+ out[attribute] = JointEmbedCondition(
1320
+ text=stacked_texts, wav=stacked_wavs,
1321
+ length=stacked_lengths, sample_rate=stacked_sample_rates,
1322
+ path=stacked_paths, seek_time=stacked_seek_times)
1323
+
1324
+ return out
1325
+
1326
+
1327
+ class ConditionFuser(StreamingModule):
1328
+ """Condition fuser handles the logic to combine the different conditions
1329
+ to the actual model input.
1330
+
1331
+ Args:
1332
+ fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
1333
+ each condition. For example:
1334
+ {
1335
+ "prepend": ["description"],
1336
+ "sum": ["genre", "bpm"],
1337
+ "cross": ["description"],
1338
+ }
1339
+ cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
1340
+ cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
1341
+ """
1342
+ FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
1343
+
1344
+ def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
1345
+ cross_attention_pos_emb_scale: float = 1.0):
1346
+ super().__init__()
1347
+ assert all(
1348
+ [k in self.FUSING_METHODS for k in fuse2cond.keys()]
1349
+ ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
1350
+ self.cross_attention_pos_emb = cross_attention_pos_emb
1351
+ self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
1352
+ self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
1353
+ self.cond2fuse: tp.Dict[str, str] = {}
1354
+ for fuse_method, conditions in fuse2cond.items():
1355
+ for condition in conditions:
1356
+ self.cond2fuse[condition] = fuse_method
1357
+
1358
+ def forward(
1359
+ self,
1360
+ input: torch.Tensor,
1361
+ conditions: tp.Dict[str, ConditionType]
1362
+ ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1363
+ """Fuse the conditions to the provided model input.
1364
+
1365
+ Args:
1366
+ input (torch.Tensor): Transformer input.
1367
+ conditions (dict[str, ConditionType]): Dict of conditions.
1368
+ Returns:
1369
+ tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
1370
+ after the conditions have been fused. The second output tensor is the tensor
1371
+ used for cross-attention or None if no cross attention inputs exist.
1372
+ """
1373
+ B, T, _ = input.shape
1374
+
1375
+ if 'offsets' in self._streaming_state:
1376
+ first_step = False
1377
+ offsets = self._streaming_state['offsets']
1378
+ else:
1379
+ first_step = True
1380
+ offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
1381
+
1382
+ assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
1383
+ f"given conditions contain unknown attributes for fuser, " \
1384
+ f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
1385
+ cross_attention_output = None
1386
+ for cond_type, (cond, cond_mask) in conditions.items():
1387
+ op = self.cond2fuse[cond_type]
1388
+ if op == 'sum':
1389
+ input += cond
1390
+ elif op == 'input_interpolate':
1391
+ cond = einops.rearrange(cond, "b t d -> b d t")
1392
+ cond = F.interpolate(cond, size=input.shape[1])
1393
+ input += einops.rearrange(cond, "b d t -> b t d")
1394
+ elif op == 'prepend':
1395
+ if first_step:
1396
+ input = torch.cat([cond, input], dim=1)
1397
+ elif op == 'cross':
1398
+ if cross_attention_output is not None:
1399
+ cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
1400
+ else:
1401
+ cross_attention_output = cond
1402
+ else:
1403
+ raise ValueError(f"unknown op ({op})")
1404
+
1405
+ if self.cross_attention_pos_emb and cross_attention_output is not None:
1406
+ positions = torch.arange(
1407
+ cross_attention_output.shape[1],
1408
+ device=cross_attention_output.device
1409
+ ).view(1, -1, 1)
1410
+ pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
1411
+ cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
1412
+
1413
+ if self._is_streaming:
1414
+ self._streaming_state['offsets'] = offsets + T
1415
+
1416
+ return input, cross_attention_output
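A minimal sketch of the fuser in action, assuming a text condition routed to cross-attention and a wav-derived condition prepended to the sequence (the condition names and dimensions are illustrative):

import torch

fuser = ConditionFuser(fuse2cond={"prepend": ["self_wav"], "cross": ["description"]})
B, T, D = 2, 10, 64
x = torch.randn(B, T, D)
conditions = {
    "description": (torch.randn(B, 5, D), torch.ones(B, 5)),   # goes to cross-attention
    "self_wav": (torch.randn(B, 3, D), torch.ones(B, 3)),      # prepended to the input
}
fused_input, cross_inputs = fuser(x, conditions)
# fused_input: [B, 3 + T, D], cross_inputs: [B, 5, D]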
audiocraft/modules/conv.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+ import warnings
10
+
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+ from torch.nn.utils import spectral_norm, weight_norm
15
+
16
+
17
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
18
+ 'time_group_norm'])
19
+
20
+
21
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
22
+ assert norm in CONV_NORMALIZATIONS
23
+ if norm == 'weight_norm':
24
+ return weight_norm(module)
25
+ elif norm == 'spectral_norm':
26
+ return spectral_norm(module)
27
+ else:
28
+ # We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
29
+ # doesn't need reparametrization.
30
+ return module
31
+
32
+
33
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
34
+ """Return the proper normalization module. If causal is True, this will ensure the returned
35
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
36
+ """
37
+ assert norm in CONV_NORMALIZATIONS
38
+ if norm == 'time_group_norm':
39
+ if causal:
40
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
41
+ assert isinstance(module, nn.modules.conv._ConvNd)
42
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
43
+ else:
44
+ return nn.Identity()
45
+
46
+
47
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
48
+ padding_total: int = 0) -> int:
49
+ """See `pad_for_conv1d`."""
50
+ length = x.shape[-1]
51
+ n_frames = (length - kernel_size + padding_total) / stride + 1
52
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
53
+ return ideal_length - length
54
+
55
+
56
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
57
+ """Pad for a convolution to make sure that the last window is full.
58
+ Extra padding is added at the end. This is required to ensure that we can rebuild
59
+ an output of the same length, as otherwise, even with padding, some time steps
60
+ might get removed.
61
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
62
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
63
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
64
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
65
+ 1 2 3 4 # once you remove the padding, we are missing one time step!
66
+ """
67
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
68
+ return F.pad(x, (0, extra_padding))
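A quick numeric check of the helpers above: with kernel size 4 and stride 2, a length-11 signal would produce 4.5 frames, so one extra sample of right padding is added to make the last window full:

import torch

x = torch.randn(1, 1, 11)
padded = pad_for_conv1d(x, kernel_size=4, stride=2)    # padding_total defaults to 0
assert padded.shape[-1] == 12                          # one extra zero on the right
conv = torch.nn.Conv1d(1, 1, kernel_size=4, stride=2)
assert conv(padded).shape[-1] == 5                     # every window is now full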
69
+
70
+
71
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
72
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
73
+ If this is the case, we insert extra 0 padding to the right before the reflection happens.
74
+ """
75
+ length = x.shape[-1]
76
+ padding_left, padding_right = paddings
77
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
78
+ if mode == 'reflect':
79
+ max_pad = max(padding_left, padding_right)
80
+ extra_pad = 0
81
+ if length <= max_pad:
82
+ extra_pad = max_pad - length + 1
83
+ x = F.pad(x, (0, extra_pad))
84
+ padded = F.pad(x, paddings, mode, value)
85
+ end = padded.shape[-1] - extra_pad
86
+ return padded[..., :end]
87
+ else:
88
+ return F.pad(x, paddings, mode, value)
89
+
90
+
91
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
92
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
93
+ padding_left, padding_right = paddings
94
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
95
+ assert (padding_left + padding_right) <= x.shape[-1]
96
+ end = x.shape[-1] - padding_right
97
+ return x[..., padding_left: end]
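For instance, reflect padding normally requires the input to be longer than the padding; `pad1d` works around this, and `unpad1d` strips the padding back off:

import torch

x = torch.randn(1, 1, 2)               # too short for a reflect pad of 3 on each side
y = pad1d(x, (3, 3), mode='reflect')   # internally zero-pads before reflecting
assert y.shape[-1] == 2 + 3 + 3        # 8, as if the reflect pad had succeeded directly
assert unpad1d(y, (3, 3)).shape[-1] == 2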
98
+
99
+
100
+ class NormConv1d(nn.Module):
101
+ """Wrapper around Conv1d and normalization applied to this conv
102
+ to provide a uniform interface across normalization approaches.
103
+ """
104
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
105
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
106
+ super().__init__()
107
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
108
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
109
+ self.norm_type = norm
110
+
111
+ def forward(self, x):
112
+ x = self.conv(x)
113
+ x = self.norm(x)
114
+ return x
115
+
116
+
117
+ class NormConv2d(nn.Module):
118
+ """Wrapper around Conv2d and normalization applied to this conv
119
+ to provide a uniform interface across normalization approaches.
120
+ """
121
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
122
+ super().__init__()
123
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
124
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
125
+ self.norm_type = norm
126
+
127
+ def forward(self, x):
128
+ x = self.conv(x)
129
+ x = self.norm(x)
130
+ return x
131
+
132
+
133
+ class NormConvTranspose1d(nn.Module):
134
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
135
+ to provide a uniform interface across normalization approaches.
136
+ """
137
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
138
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
139
+ super().__init__()
140
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
141
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
142
+ self.norm_type = norm
143
+
144
+ def forward(self, x):
145
+ x = self.convtr(x)
146
+ x = self.norm(x)
147
+ return x
148
+
149
+
150
+ class NormConvTranspose2d(nn.Module):
151
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
152
+ to provide a uniform interface across normalization approaches.
153
+ """
154
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
155
+ super().__init__()
156
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
157
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
158
+
159
+ def forward(self, x):
160
+ x = self.convtr(x)
161
+ x = self.norm(x)
162
+ return x
163
+
164
+
165
+ class StreamableConv1d(nn.Module):
166
+ """Conv1d with some builtin handling of asymmetric or causal padding
167
+ and normalization.
168
+ """
169
+ def __init__(self, in_channels: int, out_channels: int,
170
+ kernel_size: int, stride: int = 1, dilation: int = 1,
171
+ groups: int = 1, bias: bool = True, causal: bool = False,
172
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
173
+ pad_mode: str = 'reflect'):
174
+ super().__init__()
175
+ # warn user on unusual setup between dilation and stride
176
+ if stride > 1 and dilation > 1:
177
+ warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
178
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
179
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
180
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
181
+ norm=norm, norm_kwargs=norm_kwargs)
182
+ self.causal = causal
183
+ self.pad_mode = pad_mode
184
+
185
+ def forward(self, x):
186
+ B, C, T = x.shape
187
+ kernel_size = self.conv.conv.kernel_size[0]
188
+ stride = self.conv.conv.stride[0]
189
+ dilation = self.conv.conv.dilation[0]
190
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
191
+ padding_total = kernel_size - stride
192
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
193
+ if self.causal:
194
+ # Left padding for causal
195
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
196
+ else:
197
+ # Asymmetric padding required for odd strides
198
+ padding_right = padding_total // 2
199
+ padding_left = padding_total - padding_right
200
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
201
+ return self.conv(x)
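A small sanity check of the causal padding logic: the output length is `ceil(T / stride)` whether or not `T` is a multiple of the stride (values below are arbitrary):

import math
import torch

conv = StreamableConv1d(1, 8, kernel_size=4, stride=2, causal=True, pad_mode='constant')
for T in (15, 16):
    y = conv(torch.randn(1, 1, T))
    assert y.shape[-1] == math.ceil(T / 2)   # 8 frames in both cases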
202
+
203
+
204
+ class StreamableConvTranspose1d(nn.Module):
205
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
206
+ and normalization.
207
+ """
208
+ def __init__(self, in_channels: int, out_channels: int,
209
+ kernel_size: int, stride: int = 1, causal: bool = False,
210
+ norm: str = 'none', trim_right_ratio: float = 1.,
211
+ norm_kwargs: tp.Dict[str, tp.Any] = {}):
212
+ super().__init__()
213
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
214
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
215
+ self.causal = causal
216
+ self.trim_right_ratio = trim_right_ratio
217
+ assert self.causal or self.trim_right_ratio == 1., \
218
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
219
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
220
+
221
+ def forward(self, x):
222
+ kernel_size = self.convtr.convtr.kernel_size[0]
223
+ stride = self.convtr.convtr.stride[0]
224
+ padding_total = kernel_size - stride
225
+
226
+ y = self.convtr(x)
227
+
228
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
229
+ # removed at the very end, when keeping only the right length for the output,
230
+ # as removing it here would require also passing the length at the matching layer
231
+ # in the encoder.
232
+ if self.causal:
233
+ # Trim the padding on the right according to the specified ratio
234
+ # if trim_right_ratio = 1.0, trim everything from right
235
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
236
+ padding_left = padding_total - padding_right
237
+ y = unpad1d(y, (padding_left, padding_right))
238
+ else:
239
+ # Asymmetric padding required for odd strides
240
+ padding_right = padding_total // 2
241
+ padding_left = padding_total - padding_right
242
+ y = unpad1d(y, (padding_left, padding_right))
243
+ return y
audiocraft/modules/diffusion_schedule.py ADDED
@@ -0,0 +1,272 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Functions for the noise schedule: defines the diffusion process, the reverse process and the data processor.
9
+ """
10
+
11
+ from collections import namedtuple
12
+ import random
13
+ import typing as tp
14
+ import julius
15
+ import torch
16
+
17
+ TrainingItem = namedtuple("TrainingItem", "noisy noise step")
18
+
19
+
20
+ def betas_from_alpha_bar(alpha_bar):
21
+ alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
22
+ return 1 - alphas
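The helper inverts the cumulative product that defines the schedule, as the following round trip shows:

import torch

betas = torch.linspace(1e-4, 0.02, steps=10)
alpha_bar = torch.cumprod(1 - betas, dim=0)
recovered = betas_from_alpha_bar(alpha_bar)
assert torch.allclose(recovered, betas, atol=1e-5)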
23
+
24
+
25
+ class SampleProcessor(torch.nn.Module):
26
+ def project_sample(self, x: torch.Tensor):
27
+ """Project the original sample to the 'space' where the diffusion will happen."""
28
+ return x
29
+
30
+ def return_sample(self, z: torch.Tensor):
31
+ """Project back from diffusion space to the actual sample space."""
32
+ return z
33
+
34
+
35
+ class MultiBandProcessor(SampleProcessor):
36
+ """
37
+ MultiBand sample processor. The input audio is splitted across
38
+ frequency bands evenly distributed in mel-scale.
39
+
40
+ Each band will be rescaled to match the power distribution
41
+ of Gaussian noise in that band, using online metrics
42
+ computed on the first few samples.
43
+
44
+ Args:
45
+ n_bands (int): Number of mel-bands to split the signal over.
46
+ sample_rate (int): Sample rate of the audio.
47
+ num_samples (int): Number of samples to use to fit the rescaling
48
+ for each band. The processor won't be stable
49
+ until it has seen that many samples.
50
+ power_std (float or list/tensor): The rescaling factor computed to match the
51
+ power of Gaussian noise in each band is taken to
52
+ that power, i.e. `1.` means full correction of the energy
53
+ in each band, and values less than `1` mean only partial
54
+ correction. Can be used to balance the relative importance
55
+ of low vs. high freq in typical audio signals.
56
+ """
57
+ def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
58
+ num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
59
+ super().__init__()
60
+ self.n_bands = n_bands
61
+ self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
62
+ self.num_samples = num_samples
63
+ self.power_std = power_std
64
+ if isinstance(power_std, list):
65
+ assert len(power_std) == n_bands
66
+ power_std = torch.tensor(power_std)
67
+ self.register_buffer('counts', torch.zeros(1))
68
+ self.register_buffer('sum_x', torch.zeros(n_bands))
69
+ self.register_buffer('sum_x2', torch.zeros(n_bands))
70
+ self.register_buffer('sum_target_x2', torch.zeros(n_bands))
71
+ self.counts: torch.Tensor
72
+ self.sum_x: torch.Tensor
73
+ self.sum_x2: torch.Tensor
74
+ self.sum_target_x2: torch.Tensor
75
+
76
+ @property
77
+ def mean(self):
78
+ mean = self.sum_x / self.counts
79
+ return mean
80
+
81
+ @property
82
+ def std(self):
83
+ std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
84
+ return std
85
+
86
+ @property
87
+ def target_std(self):
88
+ target_std = self.sum_target_x2 / self.counts
89
+ return target_std
90
+
91
+ def project_sample(self, x: torch.Tensor):
92
+ assert x.dim() == 3
93
+ bands = self.split_bands(x)
94
+ if self.counts.item() < self.num_samples:
95
+ ref_bands = self.split_bands(torch.randn_like(x))
96
+ self.counts += len(x)
97
+ self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
98
+ self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
99
+ self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
100
+ rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
101
+ bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
102
+ return bands.sum(dim=0)
103
+
104
+ def return_sample(self, x: torch.Tensor):
105
+ assert x.dim() == 3
106
+ bands = self.split_bands(x)
107
+ rescale = (self.std / self.target_std) ** self.power_std
108
+ bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
109
+ return bands.sum(dim=0)
110
+
111
+
112
+ class NoiseSchedule:
113
+ """Noise schedule for diffusion.
114
+
115
+ Args:
116
+ beta_t0 (float): Variance of the first diffusion step.
117
+ beta_t1 (float): Variance of the last diffusion step.
118
+ beta_exp (float): Power schedule exponent.
119
+ num_steps (int): Number of diffusion steps.
120
+ variance (str): Choice of the sigma value for the denoising equation. Choices: "beta" or "beta_tilde".
121
+ clip (float): Clipping value for the denoising steps.
122
+ rescale (float): Rescaling value to avoid vanishing signals, unused by default (i.e. 1).
123
+ repartition (str): Shape of the schedule; only the power schedule is supported.
124
+ sample_processor (SampleProcessor): Module that normalizes the data to better match the Gaussian distribution.
125
+ noise_scale (float): Scaling factor for the noise.
126
+ """
127
+ def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
128
+ clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
129
+ repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
130
+ sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
131
+
132
+ self.beta_t0 = beta_t0
133
+ self.beta_t1 = beta_t1
134
+ self.variance = variance
135
+ self.num_steps = num_steps
136
+ self.clip = clip
137
+ self.sample_processor = sample_processor
138
+ self.rescale = rescale
139
+ self.n_bands = n_bands
140
+ self.noise_scale = noise_scale
141
+ assert n_bands is None
142
+ if repartition == "power":
143
+ self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
144
+ device=device, dtype=torch.float) ** beta_exp
145
+ else:
146
+ raise RuntimeError('Not implemented')
147
+ self.rng = random.Random(1234)
148
+
149
+ def get_beta(self, step: tp.Union[int, torch.Tensor]):
150
+ if self.n_bands is None:
151
+ return self.betas[step]
152
+ else:
153
+ return self.betas[:, step] # [n_bands, len(step)]
154
+
155
+ def get_initial_noise(self, x: torch.Tensor):
156
+ if self.n_bands is None:
157
+ return torch.randn_like(x)
158
+ return torch.randn((x.size(0), self.n_bands, x.size(2)))
159
+
160
+ def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
161
+ """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
162
+ if step is None:
163
+ return (1 - self.betas).cumprod(dim=-1) # works for single and multi bands
164
+ if type(step) is int:
165
+ return (1 - self.betas[:step + 1]).prod()
166
+ else:
167
+ return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
168
+
169
+ def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
170
+ """Create a noisy data item for diffusion model training:
171
+
172
+ Args:
173
+ x (torch.Tensor): Clean audio data of shape [bs, 1, T].
174
+ tensor_step (bool): If tensor_step is False, only one step t is sampled and
175
+ the whole batch is diffused to that same step, with t an int.
176
+ If tensor_step is True, t is a tensor of size (x.size(0),) and
177
+ every element of the batch is diffused to an independently sampled step.
178
+ """
179
+ step: tp.Union[int, torch.Tensor]
180
+ if tensor_step:
181
+ bs = x.size(0)
182
+ step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
183
+ else:
184
+ step = self.rng.randrange(self.num_steps)
185
+ alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1]
186
+
187
+ x = self.sample_processor.project_sample(x)
188
+ noise = torch.randn_like(x)
189
+ noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
190
+ return TrainingItem(noisy, noise, step)
191
+
192
+ def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
193
+ condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
194
+ """Full ddpm reverse process.
195
+
196
+ Args:
197
+ model (nn.Module): Diffusion model.
198
+ initial (tensor): Initial noise.
199
+ condition (tensor): Input conditioning tensor (e.g. EnCodec compressed representation).
200
+ return_list (bool): Whether to return the whole process or only the sampled point.
201
+ """
202
+ alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
203
+ current = initial
204
+ iterates = [initial]
205
+ for step in range(self.num_steps)[::-1]:
206
+ with torch.no_grad():
207
+ estimate = model(current, step, condition=condition).sample
208
+ alpha = 1 - self.betas[step]
209
+ previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
210
+ previous_alpha_bar = self.get_alpha_bar(step=step - 1)
211
+ if step == 0:
212
+ sigma2 = 0
213
+ elif self.variance == 'beta':
214
+ sigma2 = 1 - alpha
215
+ elif self.variance == 'beta_tilde':
216
+ sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
217
+ elif self.variance == 'none':
218
+ sigma2 = 0
219
+ else:
220
+ raise ValueError(f'Invalid variance type {self.variance}')
221
+
222
+ if sigma2 > 0:
223
+ previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
224
+ if self.clip:
225
+ previous = previous.clamp(-self.clip, self.clip)
226
+ current = previous
227
+ alpha_bar = previous_alpha_bar
228
+ if step == 0:
229
+ previous *= self.rescale
230
+ if return_list:
231
+ iterates.append(previous.cpu())
232
+
233
+ if return_list:
234
+ return iterates
235
+ else:
236
+ return self.sample_processor.return_sample(previous)
237
+
238
+ def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
239
+ condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
240
+ """Reverse process that only goes through Markov chain states in step_list."""
241
+ if step_list is None:
242
+ step_list = list(range(1000))[::-50] + [0]
243
+ alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
244
+ alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
245
+ betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
246
+ current = initial * self.noise_scale
247
+ iterates = [current]
248
+ for idx, step in enumerate(step_list[:-1]):
249
+ with torch.no_grad():
250
+ estimate = model(current, step, condition=condition).sample * self.noise_scale
251
+ alpha = 1 - betas_subsampled[-1 - idx]
252
+ previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
253
+ previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
254
+ if step == step_list[-2]:
255
+ sigma2 = 0
256
+ previous_alpha_bar = torch.tensor(1.0)
257
+ else:
258
+ sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
259
+ if sigma2 > 0:
260
+ previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
261
+ if self.clip:
262
+ previous = previous.clamp(-self.clip, self.clip)
263
+ current = previous
264
+ alpha_bar = previous_alpha_bar
265
+ if step == 0:
266
+ previous *= self.rescale
267
+ if return_list:
268
+ iterates.append(previous.cpu())
269
+ if return_list:
270
+ return iterates
271
+ else:
272
+ return self.sample_processor.return_sample(previous)
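The following is a hedged usage sketch for the schedule above, assuming `NoiseSchedule` is imported from this module. `DummyDiffusionModel` is a stand-in that predicts zero noise; the only contract the sketch relies on is that the model returns an object with a `.sample` field, as the `generate` loop above expects.

```python
# Hedged usage sketch, assuming NoiseSchedule is imported from this module.
# DummyDiffusionModel is illustrative only: it predicts zero noise.
from collections import namedtuple

import torch

Output = namedtuple("Output", "sample")


class DummyDiffusionModel(torch.nn.Module):
    def forward(self, x, step, condition=None):
        return Output(sample=torch.zeros_like(x))


schedule = NoiseSchedule(num_steps=50, device='cpu')
wav = torch.randn(4, 1, 24_000)                 # [batch, channels, time]

# Training: diffuse the whole batch to one randomly drawn step.
item = schedule.get_training_item(wav)
print(item.step, item.noisy.shape)              # e.g. 23 torch.Size([4, 1, 24000])

# Inference: full reverse process starting from Gaussian noise.
sample = schedule.generate(DummyDiffusionModel(), initial=schedule.get_initial_noise(wav))
print(sample.shape)                             # torch.Size([4, 1, 24000])
```

With `tensor_step=True`, `get_training_item` would instead draw one step per batch element.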
audiocraft/modules/lstm.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn
8
+
9
+
10
+ class StreamableLSTM(nn.Module):
11
+ """LSTM that manages the hidden state and the data layout internally.
12
+ Expects input in convolutional layout, i.e. [B, C, T].
13
+ """
14
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
15
+ super().__init__()
16
+ self.skip = skip
17
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
18
+
19
+ def forward(self, x):
20
+ x = x.permute(2, 0, 1)
21
+ y, _ = self.lstm(x)
22
+ if self.skip:
23
+ y = y + x
24
+ y = y.permute(1, 2, 0)
25
+ return y
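A hedged shape check for `StreamableLSTM`, assuming the class is imported from this module; the sizes are illustrative only.

```python
# Hedged shape check: StreamableLSTM keeps the convolutional [B, C, T] layout
# and adds a residual skip connection by default.
import torch

lstm = StreamableLSTM(dimension=64, num_layers=2, skip=True)
x = torch.randn(8, 64, 100)   # [B, C, T]
print(lstm(x).shape)          # torch.Size([8, 64, 100])
```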
audiocraft/modules/rope.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ from torch import nn
10
+ import torch
11
+
12
+
13
+ class XPos(nn.Module):
14
+ """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
15
+ This applies an exponential decay to the RoPE rotation matrix.
16
+
17
+ Args:
18
+ dim (int): Embedding dimension.
19
+ smoothing (float): Smoothing factor applied to the decay rates.
20
+ base_scale (int): Base decay rate, given in terms of scaling time.
21
+ device (torch.device, optional): Device on which to initialize the module.
22
+ dtype (torch.dtype): dtype to use to generate the embedding.
23
+ """
24
+ def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
25
+ device=None, dtype: torch.dtype = torch.float32):
26
+ super().__init__()
27
+ assert dim % 2 == 0
28
+ assert dtype in [torch.float64, torch.float32]
29
+ self.dtype = dtype
30
+ self.base_scale = base_scale
31
+
32
+ half_dim = dim // 2
33
+ adim = torch.arange(half_dim, device=device, dtype=dtype)
34
+ decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
35
+ self.register_buffer("decay_rates", decay_rates)
36
+ self.decay: tp.Optional[torch.Tensor] = None
37
+
38
+ def get_decay(self, start: int, end: int):
39
+ """Create complex decay tensor, cache values for fast computation."""
40
+ if self.decay is None or end > self.decay.shape[0]:
41
+ assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker.
42
+ idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
43
+ power = idx / self.base_scale
44
+ scale = self.decay_rates ** power.unsqueeze(-1)
45
+ self.decay = torch.polar(scale, torch.zeros_like(scale))
46
+ return self.decay[start:end] # [T, C/2]
47
+
48
+
49
+ class RotaryEmbedding(nn.Module):
50
+ """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
51
+
52
+ Args:
53
+ dim (int): Embedding dimension (twice the number of frequencies).
54
+ max_period (float): Maximum period of the rotation frequencies.
55
+ xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
56
+ scale (float): Scale of positional embedding, set to 0 to deactivate.
57
+ device (torch.device, optional): Device on which to initialize the module.
58
+ dtype (torch.dtype): dtype to use to generate the embedding.
59
+ """
60
+ def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
61
+ scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
62
+ super().__init__()
63
+ assert dim % 2 == 0
64
+ self.scale = scale
65
+ assert dtype in [torch.float64, torch.float32]
66
+ self.dtype = dtype
67
+
68
+ adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
69
+ frequencies = 1.0 / (max_period ** (adim / dim))
70
+ self.register_buffer("frequencies", frequencies)
71
+ self.rotation: tp.Optional[torch.Tensor] = None
72
+
73
+ self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
74
+
75
+ def get_rotation(self, start: int, end: int):
76
+ """Create complex rotation tensor, cache values for fast computation."""
77
+ if self.rotation is None or end > self.rotation.shape[0]:
78
+ assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker.
79
+ idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
80
+ angles = torch.outer(idx, self.frequencies)
81
+ self.rotation = torch.polar(torch.ones_like(angles), angles)
82
+ return self.rotation[start:end]
83
+
84
+ def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
85
+ """Apply rope rotation to query or key tensor."""
86
+ T = x.shape[time_dim]
87
+ target_shape = [1] * x.dim()
88
+ target_shape[time_dim] = T
89
+ target_shape[-1] = -1
90
+ rotation = self.get_rotation(start, start + T).view(target_shape)
91
+
92
+ if self.xpos:
93
+ decay = self.xpos.get_decay(start, start + T).view(target_shape)
94
+ else:
95
+ decay = 1.0
96
+
97
+ if invert_decay:
98
+ decay = decay ** -1
99
+
100
+ x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
101
+ scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
102
+ x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
103
+
104
+ return x_out.type_as(x)
105
+
106
+ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
107
+ """ Apply rope rotation to both query and key tensors.
108
+ Supports streaming mode, in which query and key are not expected to have the same shape.
109
+ In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
110
+ query will be [C] (typically C == 1).
111
+
112
+ Args:
113
+ query (torch.Tensor): Query to rotate.
114
+ key (torch.Tensor): Key to rotate.
115
+ start (int): Start index of the sequence for time offset.
116
+ time_dim (int): Which dimension represents the time steps.
117
+ """
118
+ query_timesteps = query.shape[time_dim]
119
+ key_timesteps = key.shape[time_dim]
120
+ streaming_offset = key_timesteps - query_timesteps
121
+
122
+ query_out = self.rotate(query, start + streaming_offset, time_dim)
123
+ key_out = self.rotate(key, start, time_dim, invert_decay=True)
124
+
125
+ return query_out, key_out
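A hedged sketch of `rotate_qk`, assuming `RotaryEmbedding` is imported from this module and that queries and keys use a `[B, T, H, D]` layout (`time_dim=1`). The shapes mimic streaming decoding, where the key holds cached past steps and the query holds only the newest one.

```python
# Hedged sketch of rotate_qk on [B, T, H, D] tensors (time_dim=1). In streaming
# decoding the key carries cached past steps while the query holds only the new
# step; rotate_qk offsets the query rotation by that difference automatically.
import torch

rope = RotaryEmbedding(dim=64)                   # per-head dimension, must be even
B, H, D, past, current = 2, 8, 64, 10, 1

query = torch.randn(B, current, H, D)            # only the new step
key = torch.randn(B, past + current, H, D)       # cached past + new step
q_rot, k_rot = rope.rotate_qk(query, key, start=0, time_dim=1)
print(q_rot.shape, k_rot.shape)                  # shapes are unchanged
```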
audiocraft/modules/seanet.py ADDED
@@ -0,0 +1,258 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import numpy as np
10
+ import torch.nn as nn
11
+
12
+ from .conv import StreamableConv1d, StreamableConvTranspose1d
13
+ from .lstm import StreamableLSTM
14
+
15
+
16
+ class SEANetResnetBlock(nn.Module):
17
+ """Residual block from SEANet model.
18
+
19
+ Args:
20
+ dim (int): Dimension of the input/output.
21
+ kernel_sizes (list): List of kernel sizes for the convolutions.
22
+ dilations (list): List of dilations for the convolutions.
23
+ activation (str): Activation function.
24
+ activation_params (dict): Parameters to provide to the activation function.
25
+ norm (str): Normalization method.
26
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
27
+ causal (bool): Whether to use fully causal convolution.
28
+ pad_mode (str): Padding mode for the convolutions.
29
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
30
+ true_skip (bool): Whether to use true skip connection or a simple
31
+ (streamable) convolution as the skip connection.
32
+ """
33
+ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
34
+ activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
35
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
36
+ pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
37
+ super().__init__()
38
+ assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
39
+ act = getattr(nn, activation)
40
+ hidden = dim // compress
41
+ block = []
42
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
43
+ in_chs = dim if i == 0 else hidden
44
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
45
+ block += [
46
+ act(**activation_params),
47
+ StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
48
+ norm=norm, norm_kwargs=norm_params,
49
+ causal=causal, pad_mode=pad_mode),
50
+ ]
51
+ self.block = nn.Sequential(*block)
52
+ self.shortcut: nn.Module
53
+ if true_skip:
54
+ self.shortcut = nn.Identity()
55
+ else:
56
+ self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
57
+ causal=causal, pad_mode=pad_mode)
58
+
59
+ def forward(self, x):
60
+ return self.shortcut(x) + self.block(x)
61
+
62
+
63
+ class SEANetEncoder(nn.Module):
64
+ """SEANet encoder.
65
+
66
+ Args:
67
+ channels (int): Audio channels.
68
+ dimension (int): Intermediate representation dimension.
69
+ n_filters (int): Base width for the model.
70
+ n_residual_layers (int): nb of residual layers.
71
+ ratios (Sequence[int]): Kernel size and stride ratios. The encoder uses downsampling ratios instead of
72
+ upsampling ratios, hence it applies the ratios in the reverse order of the ones specified here,
73
+ which must match the decoder order. We use the decoder order because some models may only employ the decoder.
74
+ activation (str): Activation function.
75
+ activation_params (dict): Parameters to provide to the activation function.
76
+ norm (str): Normalization method.
77
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
78
+ kernel_size (int): Kernel size for the initial convolution.
79
+ last_kernel_size (int): Kernel size for the last convolution.
80
+ residual_kernel_size (int): Kernel size for the residual layers.
81
+ dilation_base (int): How much to increase the dilation with each layer.
82
+ causal (bool): Whether to use fully causal convolution.
83
+ pad_mode (str): Padding mode for the convolutions.
84
+ true_skip (bool): Whether to use true skip connection or a simple
85
+ (streamable) convolution as the skip connection in the residual network blocks.
86
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
87
+ lstm (int): Number of LSTM layers at the end of the encoder.
88
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
89
+ For the encoder, it corresponds to the N first blocks.
90
+ """
91
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
92
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
93
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
94
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
95
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
96
+ disable_norm_outer_blocks: int = 0):
97
+ super().__init__()
98
+ self.channels = channels
99
+ self.dimension = dimension
100
+ self.n_filters = n_filters
101
+ self.ratios = list(reversed(ratios))
102
+ del ratios
103
+ self.n_residual_layers = n_residual_layers
104
+ self.hop_length = np.prod(self.ratios)
105
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
106
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
107
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
108
+ "Number of blocks for which to disable norm is invalid." \
109
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
110
+
111
+ act = getattr(nn, activation)
112
+ mult = 1
113
+ model: tp.List[nn.Module] = [
114
+ StreamableConv1d(channels, mult * n_filters, kernel_size,
115
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
116
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
117
+ ]
118
+ # Downsample to raw audio scale
119
+ for i, ratio in enumerate(self.ratios):
120
+ block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
121
+ # Add residual layers
122
+ for j in range(n_residual_layers):
123
+ model += [
124
+ SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
125
+ dilations=[dilation_base ** j, 1],
126
+ norm=block_norm, norm_params=norm_params,
127
+ activation=activation, activation_params=activation_params,
128
+ causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
129
+
130
+ # Add downsampling layers
131
+ model += [
132
+ act(**activation_params),
133
+ StreamableConv1d(mult * n_filters, mult * n_filters * 2,
134
+ kernel_size=ratio * 2, stride=ratio,
135
+ norm=block_norm, norm_kwargs=norm_params,
136
+ causal=causal, pad_mode=pad_mode),
137
+ ]
138
+ mult *= 2
139
+
140
+ if lstm:
141
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
142
+
143
+ model += [
144
+ act(**activation_params),
145
+ StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
146
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
147
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
148
+ ]
149
+
150
+ self.model = nn.Sequential(*model)
151
+
152
+ def forward(self, x):
153
+ return self.model(x)
154
+
155
+
156
+ class SEANetDecoder(nn.Module):
157
+ """SEANet decoder.
158
+
159
+ Args:
160
+ channels (int): Audio channels.
161
+ dimension (int): Intermediate representation dimension.
162
+ n_filters (int): Base width for the model.
163
+ n_residual_layers (int): nb of residual layers.
164
+ ratios (Sequence[int]): kernel size and stride ratios.
165
+ activation (str): Activation function.
166
+ activation_params (dict): Parameters to provide to the activation function.
167
+ final_activation (str): Final activation function after all convolutions.
168
+ final_activation_params (dict): Parameters to provide to the activation function.
169
+ norm (str): Normalization method.
170
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
171
+ kernel_size (int): Kernel size for the initial convolution.
172
+ last_kernel_size (int): Kernel size for the last convolution.
173
+ residual_kernel_size (int): Kernel size for the residual layers.
174
+ dilation_base (int): How much to increase the dilation with each layer.
175
+ causal (bool): Whether to use fully causal convolution.
176
+ pad_mode (str): Padding mode for the convolutions.
177
+ true_skip (bool): Whether to use true skip connection or a simple
178
+ (streamable) convolution as the skip connection in the residual network blocks.
179
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
180
+ lstm (int): Number of LSTM layers at the start of the decoder.
181
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
182
+ For the decoder, it corresponds to the N last blocks.
183
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
184
+ If equal to 1.0, it means that all the trimming is done at the right.
185
+ """
186
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
187
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
188
+ final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
189
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
190
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
191
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
192
+ disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
193
+ super().__init__()
194
+ self.dimension = dimension
195
+ self.channels = channels
196
+ self.n_filters = n_filters
197
+ self.ratios = ratios
198
+ del ratios
199
+ self.n_residual_layers = n_residual_layers
200
+ self.hop_length = np.prod(self.ratios)
201
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
202
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
203
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
204
+ "Number of blocks for which to disable norm is invalid." \
205
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
206
+
207
+ act = getattr(nn, activation)
208
+ mult = int(2 ** len(self.ratios))
209
+ model: tp.List[nn.Module] = [
210
+ StreamableConv1d(dimension, mult * n_filters, kernel_size,
211
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
212
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
213
+ ]
214
+
215
+ if lstm:
216
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
217
+
218
+ # Upsample to raw audio scale
219
+ for i, ratio in enumerate(self.ratios):
220
+ block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
221
+ # Add upsampling layers
222
+ model += [
223
+ act(**activation_params),
224
+ StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
225
+ kernel_size=ratio * 2, stride=ratio,
226
+ norm=block_norm, norm_kwargs=norm_params,
227
+ causal=causal, trim_right_ratio=trim_right_ratio),
228
+ ]
229
+ # Add residual layers
230
+ for j in range(n_residual_layers):
231
+ model += [
232
+ SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
233
+ dilations=[dilation_base ** j, 1],
234
+ activation=activation, activation_params=activation_params,
235
+ norm=block_norm, norm_params=norm_params, causal=causal,
236
+ pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
237
+
238
+ mult //= 2
239
+
240
+ # Add final layers
241
+ model += [
242
+ act(**activation_params),
243
+ StreamableConv1d(n_filters, channels, last_kernel_size,
244
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
245
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
246
+ ]
247
+ # Add optional final activation to decoder (eg. tanh)
248
+ if final_activation is not None:
249
+ final_act = getattr(nn, final_activation)
250
+ final_activation_params = final_activation_params or {}
251
+ model += [
252
+ final_act(**final_activation_params)
253
+ ]
254
+ self.model = nn.Sequential(*model)
255
+
256
+ def forward(self, z):
257
+ y = self.model(z)
258
+ return y
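A hedged round-trip sketch, assuming `SEANetEncoder` and `SEANetDecoder` are imported from this module; the sample rate and sizes are illustrative.

```python
# Hedged round-trip sketch: with ratios [8, 5, 4, 2] the hop length is 320, so one
# second of 24 kHz mono audio becomes 75 latent frames of size `dimension`.
import torch

encoder = SEANetEncoder(channels=1, dimension=128, n_filters=32, ratios=[8, 5, 4, 2])
decoder = SEANetDecoder(channels=1, dimension=128, n_filters=32, ratios=[8, 5, 4, 2])

wav = torch.randn(2, 1, 24_000)
latent = encoder(wav)
print(encoder.hop_length, latent.shape)   # 320 torch.Size([2, 128, 75])
print(decoder(latent).shape)              # torch.Size([2, 1, 24000]) for hop-aligned lengths
```

The 75-frame latent follows directly from `hop_length = prod(ratios) = 320`.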
audiocraft/modules/streaming.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Streaming module API that should be implemented by all streaming components.
9
+ """
10
+
11
+ from contextlib import contextmanager
12
+ import typing as tp
13
+ from torch import nn
14
+ import torch
15
+
16
+
17
+ State = tp.Dict[str, torch.Tensor]
18
+
19
+
20
+ class StreamingModule(nn.Module):
21
+ """Common API for streaming components.
22
+
23
+ Each streaming component has a streaming state, which is just a dict[str, Tensor].
24
+ By convention, the first dim of each tensor must be the batch size.
25
+ Don't use dots in the key names, as this would clash with submodules
26
+ (like in state_dict).
27
+
28
+ If `self._is_streaming` is True, the component should use and remember
29
+ the proper state inside `self._streaming_state`.
30
+
31
+ To set a streaming component in streaming state, use
32
+
33
+ with module.streaming():
34
+ ...
35
+
36
+ This will automatically reset the streaming state when exiting the context manager.
37
+ This also automatically propagates to all streaming child modules.
38
+
39
+ Some modules might also implement the `StreamingModule.flush` method, although
40
+ this one is trickier, as all parent modules must be StreamingModule and implement
41
+ it as well for it to work properly. See `StreamingSequential` below.
42
+ """
43
+ def __init__(self) -> None:
44
+ super().__init__()
45
+ self._streaming_state: State = {}
46
+ self._is_streaming = False
47
+
48
+ def _apply_named_streaming(self, fn: tp.Any):
49
+ for name, module in self.named_modules():
50
+ if isinstance(module, StreamingModule):
51
+ fn(name, module)
52
+
53
+ def _set_streaming(self, streaming: bool):
54
+ def _set_streaming(name, module):
55
+ module._is_streaming = streaming
56
+ self._apply_named_streaming(_set_streaming)
57
+
58
+ @contextmanager
59
+ def streaming(self):
60
+ """Context manager to enter streaming mode. Reset streaming state on exit."""
61
+ self._set_streaming(True)
62
+ try:
63
+ yield
64
+ finally:
65
+ self._set_streaming(False)
66
+ self.reset_streaming()
67
+
68
+ def reset_streaming(self):
69
+ """Reset the streaming state."""
70
+ def _reset(name: str, module: StreamingModule):
71
+ module._streaming_state.clear()
72
+
73
+ self._apply_named_streaming(_reset)
74
+
75
+ def get_streaming_state(self) -> State:
76
+ """Return the streaming state, including that of sub-modules."""
77
+ state: State = {}
78
+
79
+ def _add(name: str, module: StreamingModule):
80
+ if name:
81
+ name += "."
82
+ for key, value in module._streaming_state.items():
83
+ state[name + key] = value
84
+
85
+ self._apply_named_streaming(_add)
86
+ return state
87
+
88
+ def set_streaming_state(self, state: State):
89
+ """Set the streaming state, including that of sub-modules."""
90
+ state = dict(state)
91
+
92
+ def _set(name: str, module: StreamingModule):
93
+ if name:
94
+ name += "."
95
+ module._streaming_state.clear()
96
+ for key, value in list(state.items()):
97
+ # complexity is not ideal here, but probably fine.
98
+ if key.startswith(name):
99
+ local_key = key[len(name):]
100
+ if '.' not in local_key:
101
+ module._streaming_state[local_key] = value
102
+ del state[key]
103
+
104
+ self._apply_named_streaming(_set)
105
+ assert len(state) == 0, list(state.keys())
106
+
107
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
108
+ """Flush any remaining outputs that were waiting for completion.
109
+ Typically, for convolutions, this will add the final padding
110
+ and process the last buffer.
111
+
112
+ This should take an optional argument `x`, which will be provided
113
+ if a module before this one in the streaming pipeline has already
114
+ spat out a flushed buffer.
115
+ """
116
+ if x is None:
117
+ return None
118
+ else:
119
+ return self(x)
120
+
121
+
122
+ class StreamingSequential(StreamingModule, nn.Sequential):
123
+ """A streaming compatible alternative of `nn.Sequential`.
124
+ """
125
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
126
+ for module in self:
127
+ if isinstance(module, StreamingModule):
128
+ x = module.flush(x)
129
+ elif x is not None:
130
+ x = module(x)
131
+ return x
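A hedged sketch of a custom `StreamingModule`, assuming the class is imported from this module. `RunningSum` is illustrative only; it shows the intended pattern of reading and writing `_streaming_state` inside the `streaming()` context.

```python
# Hedged sketch of a custom StreamingModule: a running sum over time whose carry
# lives in `_streaming_state` only while inside the `streaming()` context.
import torch


class RunningSum(StreamingModule):
    def forward(self, x: torch.Tensor):          # x: [B, C, T]
        out = x.cumsum(dim=-1)
        if self._is_streaming:
            if 'carry' in self._streaming_state:
                out = out + self._streaming_state['carry']
            # By convention, the first dim of every stored tensor is the batch size.
            self._streaming_state['carry'] = out[..., -1:].detach()
        return out


module = RunningSum()
with module.streaming():
    print(module(torch.ones(1, 1, 4)))           # [[[1., 2., 3., 4.]]]
    print(module(torch.ones(1, 1, 4)))           # [[[5., 6., 7., 8.]]] -- carry persisted
print(module.get_streaming_state())              # {} -- state reset on exit
```

The same pattern is what `StreamingMultiheadAttention` below uses to cache past keys and values.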
audiocraft/modules/transformer.py ADDED
@@ -0,0 +1,755 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Transformer model, with streaming support, xformers attention support
9
+ and easy causal attention with a potentially finite receptive field.
10
+
11
+ See `StreamingTransformer` for more information.
12
+
13
+ Unlike the regular PyTorch Transformer, we make the hard choice that the batch dimension comes first.
14
+ """
15
+
16
+ import typing as tp
17
+
18
+ from einops import rearrange
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn import functional as F
22
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
23
+ from xformers import ops
24
+
25
+ from .rope import RotaryEmbedding
26
+ from .streaming import StreamingModule
27
+
28
+ _efficient_attention_backend: str = 'torch'
29
+
30
+
31
+ def set_efficient_attention_backend(backend: str = 'torch'):
32
+ # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
33
+ global _efficient_attention_backend
34
+ assert backend in ['xformers', 'torch']  # validate the requested backend, not the current one
35
+ _efficient_attention_backend = backend
36
+
37
+
38
+ def _get_attention_time_dimension(memory_efficient: bool) -> int:
39
+ if _efficient_attention_backend == 'torch' and memory_efficient:
40
+ return 2
41
+ else:
42
+ return 1
43
+
44
+
45
+ def _is_profiled() -> bool:
46
+ # Return true if we are currently running with a xformers profiler activated.
47
+ try:
48
+ from xformers.profiler import profiler
49
+ except ImportError:
50
+ return False
51
+ return profiler._Profiler._CURRENT_PROFILER is not None
52
+
53
+
54
+ def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
55
+ """Create normalization module for transformer encoder layer.
56
+
57
+ Args:
58
+ norm_type (str): Normalization method.
59
+ dim (int): Dimension of the normalized layer.
60
+ **kwargs (dict): Additional parameters for normalization layer.
61
+ Returns:
62
+ nn.Module: Normalization module.
63
+ """
64
+ if norm_type == 'layer_norm':
65
+ return nn.LayerNorm(dim, eps=1e-5, **kwargs)
66
+ else:
67
+ raise ValueError(f"Unknown norm type: {norm_type}")
68
+
69
+
70
+ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
71
+ dtype: torch.dtype = torch.float32) -> torch.Tensor:
72
+ """Create sinusoidal positional embedding, with shape `[B, T, C]`.
73
+
74
+ Args:
75
+ positions (torch.Tensor): LongTensor of positions.
76
+ dim (int): Dimension of the embedding.
77
+ max_period (float): Maximum period of the cosine/sine functions.
78
+ dtype (torch.dtype or str): dtype to use to generate the embedding.
79
+ Returns:
80
+ torch.Tensor: Sinusoidal positional embedding.
81
+ """
82
+ # We aim for BTC format
83
+ assert dim % 2 == 0
84
+ half_dim = dim // 2
85
+ positions = positions.to(dtype)
86
+ adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
87
+ max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
88
+ phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
89
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
90
+
91
+
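A hedged usage sketch for `create_sin_embedding`; the batch/time/channel sizes are illustrative.

```python
# Hedged usage sketch for create_sin_embedding: positions come in as a [B, T, 1]
# tensor of (possibly offset) time indices and the result is a [B, T, C] embedding.
import torch

B, T, C = 2, 16, 512
positions = torch.arange(T).view(1, -1, 1).expand(B, -1, -1)   # [B, T, 1]
pos_emb = create_sin_embedding(positions, dim=C)
print(pos_emb.shape)                                           # torch.Size([2, 16, 512])
```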
92
+ def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
93
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
94
+ if n_rep == 1:
95
+ return x
96
+ if _efficient_attention_backend == 'torch' and memory_efficient:
97
+ bs, n_kv_heads, slen, head_dim = x.shape
98
+ return (
99
+ x[:, :, None, :, :]
100
+ .expand(bs, n_kv_heads, n_rep, slen, head_dim)
101
+ .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
102
+ )
103
+ else:
104
+ bs, slen, n_kv_heads, head_dim = x.shape
105
+ return (
106
+ x[:, :, :, None, :]
107
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
108
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
109
+ )
110
+
111
+
112
+ class LayerScale(nn.Module):
113
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
114
+ This rescales diagonally the residual outputs close to 0, with a learnt scale.
115
+
116
+ Args:
117
+ channels (int): Number of channels.
118
+ init (float): Initial scale.
119
+ channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
120
+ device (torch.device or str, optional): Device on which to initialize the module.
121
+ dtype (torch.dtype, optional): dtype to use to initialize the module.
122
+ """
123
+ def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
124
+ device=None, dtype=None):
125
+ super().__init__()
126
+ self.channel_last = channel_last
127
+ self.scale = nn.Parameter(
128
+ torch.full((channels,), init,
129
+ requires_grad=True, device=device, dtype=dtype))
130
+
131
+ def forward(self, x: torch.Tensor):
132
+ if self.channel_last:
133
+ return self.scale * x
134
+ else:
135
+ return self.scale[:, None] * x
136
+
137
+
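A hedged sketch of `LayerScale` in its default channel-last mode; the sizes are illustrative.

```python
# Hedged sketch of LayerScale: a residual branch is multiplied by a learnt
# per-channel scale initialized near zero, so early in training the block stays
# close to the identity plus a tiny perturbation.
import torch

scale = LayerScale(channels=512, init=1e-4)      # channel-last by default
branch = torch.randn(2, 16, 512)                 # [B, T, C]
print(scale(branch).abs().max().item())          # equals 1e-4 * branch.abs().max() at init
```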
138
+ class StreamingMultiheadAttention(StreamingModule):
139
+ """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
140
+
141
+ Args:
142
+ embed_dim (int): Dimension to project to.
143
+ num_heads (int): Number of heads.
144
+ dropout (float): Dropout level.
145
+ bias (bool): Use bias in projections.
146
+ causal (bool): Causal mask applied automatically.
147
+ past_context (int, optional): Receptive field for the causal mask, infinite if None.
148
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
149
+ memory_efficient (bool): Use xformers based memory efficient attention.
150
+ attention_as_float32 (bool): Perform the attention as float32
151
+ (especially important with memory_efficient as autocast won't do this automatically).
152
+ rope (`RotaryEmbedding`, optional): Rope embedding to use.
153
+ cross_attention: Should be true when used as a cross attention.
154
+ All keys and values must be available at once, streaming is only for the queries.
155
+ Cannot be used with `causal` or `rope` (as it wouldn't make sense to
156
+ interpret the time steps in the keys relative to those in the queries).
157
+ safe_streaming (bool): Bug fix, will go away with xformers update.
158
+ qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
159
+ kv_repeat (int): If > 1, will repeat keys and values multiple times (must divide num_heads).
160
+ This will lead to faster decoding time on A100 or other GPUs with tensorcore.
161
+ device (torch.device, optional): Device on which to initialize.
162
+ dtype (torch.dtype, optional): dtype to use.
163
+ """
164
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
165
+ causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
166
+ memory_efficient: bool = False, attention_as_float32: bool = False,
167
+ rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
168
+ safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
169
+ device=None, dtype=None):
170
+ super().__init__()
171
+ factory_kwargs = {'device': device, 'dtype': dtype}
172
+ if past_context is not None:
173
+ assert causal
174
+
175
+ self.embed_dim = embed_dim
176
+ self.causal = causal
177
+ self.past_context = past_context
178
+ self.memory_efficient = memory_efficient
179
+ self.attention_as_float32 = attention_as_float32
180
+ self.rope = rope
181
+ self.cross_attention = cross_attention
182
+ self.safe_streaming = safe_streaming
183
+ self.num_heads = num_heads
184
+ self.dropout = dropout
185
+ self.kv_repeat = kv_repeat
186
+ if cross_attention:
187
+ assert not causal, "Causal cannot work with cross attention."
188
+ assert rope is None, "Rope cannot work with cross attention."
189
+
190
+ if memory_efficient:
191
+ _verify_xformers_memory_efficient_compat()
192
+
193
+ self.custom = _is_custom(custom, memory_efficient)
194
+ if self.custom:
195
+ out_dim = embed_dim
196
+ assert num_heads % kv_repeat == 0
197
+ assert not cross_attention or kv_repeat == 1
198
+ num_kv = num_heads // kv_repeat
199
+ kv_dim = (embed_dim // num_heads) * num_kv
200
+ out_dim += 2 * kv_dim
201
+ in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
202
+ # We try to follow the default PyTorch MHA convention, to easily compare results.
203
+ self.in_proj_weight = in_proj.weight
204
+ self.in_proj_bias = in_proj.bias
205
+ if bias:
206
+ self.in_proj_bias.data.zero_() # Following Pytorch convention
207
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
208
+ if bias:
209
+ self.out_proj.bias.data.zero_()
210
+ else:
211
+ assert not qk_layer_norm
212
+ assert kv_repeat == 1
213
+ self.mha = nn.MultiheadAttention(
214
+ embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
215
+ **factory_kwargs)
216
+ self.qk_layer_norm = qk_layer_norm
217
+ if qk_layer_norm:
218
+ assert self.custom
219
+ assert kv_repeat == 1
220
+ ln_dim = embed_dim
221
+ self.q_layer_norm = nn.LayerNorm(ln_dim)
222
+ self.k_layer_norm = nn.LayerNorm(ln_dim)
223
+
224
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
225
+ if not self.custom:
226
+ # Support compat with regular MHA
227
+ keys = [n for n, _ in self.mha.named_parameters()]
228
+ for key in keys:
229
+ if prefix + key in state_dict:
230
+ state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
231
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
232
+
233
+ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
234
+ # Return a causal mask, accounting for potentially stored past keys/values
235
+ # We actually return a bias for the attention score, as this has the same
236
+ # convention both in the builtin MHA in Pytorch, and Xformers functions.
237
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
238
+ if self.memory_efficient:
239
+ from xformers.ops import LowerTriangularMask
240
+ if current_steps == 1:
241
+ # If we only have one step, then we do not need a mask.
242
+ return None
243
+ elif 'past_keys' in self._streaming_state:
244
+ raise RuntimeError("Not supported at the moment")
245
+ else:
246
+ # Then we can safely use a lower triangular mask
247
+ return LowerTriangularMask()
248
+ if self._streaming_state:
249
+ past_keys = self._streaming_state['past_keys']
250
+ past_steps = past_keys.shape[time_dim]
251
+ else:
252
+ past_steps = 0
253
+
254
+ queries_pos = torch.arange(
255
+ past_steps, current_steps + past_steps, device=device).view(-1, 1)
256
+ keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
257
+ delta = queries_pos - keys_pos
258
+ valid = delta >= 0
259
+ if self.past_context is not None:
260
+ valid &= (delta <= self.past_context)
261
+ return torch.where(
262
+ valid,
263
+ torch.zeros([], device=device, dtype=dtype),
264
+ torch.full([], float('-inf'), device=device, dtype=dtype))
265
+
266
+ def _complete_kv(self, k, v):
267
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
268
+ if self.cross_attention:
269
+ # With cross attention we assume all keys and values
270
+ # are already available, and streaming is with respect
271
+ # to the queries only.
272
+ return k, v
273
+ # Complete the key/value pair using the streaming state.
274
+ if self._streaming_state:
275
+ pk = self._streaming_state['past_keys']
276
+ nk = torch.cat([pk, k], dim=time_dim)
277
+ if v is k:
278
+ nv = nk
279
+ else:
280
+ pv = self._streaming_state['past_values']
281
+ nv = torch.cat([pv, v], dim=time_dim)
282
+ else:
283
+ nk = k
284
+ nv = v
285
+
286
+ assert nk.shape[time_dim] == nv.shape[time_dim]
287
+ offset = 0
288
+ if self.past_context is not None:
289
+ offset = max(0, nk.shape[time_dim] - self.past_context)
290
+ if self._is_streaming:
291
+ self._streaming_state['past_keys'] = nk[:, offset:]
292
+ if v is not k:
293
+ self._streaming_state['past_values'] = nv[:, offset:]
294
+ if 'offset' in self._streaming_state:
295
+ self._streaming_state['offset'] += offset
296
+ else:
297
+ self._streaming_state['offset'] = torch.tensor(0)
298
+ return nk, nv
299
+
300
+ def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
301
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
302
+ # Apply rope embeddings to query and key tensors.
303
+ assert self.rope is not None
304
+ if 'past_keys' in self._streaming_state:
305
+ past_keys_offset = self._streaming_state['past_keys'].shape[1]
306
+ else:
307
+ past_keys_offset = 0
308
+ if 'offset' in self._streaming_state:
309
+ past_context_offset = int(self._streaming_state['offset'].item())
310
+ else:
311
+ past_context_offset = 0
312
+ streaming_offset = past_context_offset + past_keys_offset
313
+ return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)
314
+
315
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
316
+ key_padding_mask=None, need_weights=False, attn_mask=None,
317
+ average_attn_weights=True, is_causal=False):
318
+ assert not is_causal, ("New param added in torch 2.0.1 not supported, "
319
+ "use the causal args in the constructor.")
320
+
321
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
322
+ if time_dim == 2:
323
+ layout = "b h t d"
324
+ else:
325
+ layout = "b t h d"
326
+ dtype = query.dtype
327
+ if self._is_streaming:
328
+ assert self.causal or self.cross_attention, \
329
+ "Streaming only available for causal or cross attention"
330
+
331
+ custom_attn_mask = attn_mask is not None
332
+
333
+ if self.causal:
334
+ assert attn_mask is None
335
+ # At the moment we specialize only for the self-attention case.
336
+ assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
337
+ assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
338
+ attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
339
+
340
+ if self.custom:
341
+ # custom implementation
342
+ assert need_weights is False
343
+ assert key_padding_mask is None
344
+ if self.cross_attention:
345
+ # Different queries, keys, values: we have to split the weights manually
346
+ # before applying the linear.
347
+ dim = self.in_proj_weight.shape[0] // 3
348
+ if self.in_proj_bias is None:
349
+ bias_q, bias_k, bias_v = None, None, None
350
+ else:
351
+ bias_q = self.in_proj_bias[:dim]
352
+ bias_k = self.in_proj_bias[dim: 2 * dim]
353
+ bias_v = self.in_proj_bias[2 * dim:]
354
+ q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
355
+ # todo: when streaming, we could actually save k, v and check the shape actually match.
356
+ k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
357
+ v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
358
+ if self.qk_layer_norm is True:
359
+ q = self.q_layer_norm(q)
360
+ k = self.k_layer_norm(k)
361
+ q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
362
+ else:
363
+ if not _is_profiled():
364
+ # profiling breaks that property somehow.
365
+ assert query is key, "specialized implementation"
366
+ assert value is key, "specialized implementation"
367
+ projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
368
+ if self.kv_repeat == 1:
369
+ if time_dim == 2:
370
+ bound_layout = "b h p t d"
371
+ else:
372
+ bound_layout = "b t p h d"
373
+ packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
374
+ q, k, v = ops.unbind(packed, dim=2)
375
+ else:
376
+ embed_dim = self.embed_dim
377
+ per_head_dim = (embed_dim // self.num_heads)
378
+ kv_heads = self.num_heads // self.kv_repeat
379
+ q = projected[:, :, :embed_dim]
380
+ start = embed_dim
381
+ end = start + per_head_dim * kv_heads
382
+ k = projected[:, :, start: end]
383
+ v = projected[:, :, end:]
384
+ q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
385
+ k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
386
+ v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
387
+
388
+ if self.qk_layer_norm is True:
389
+ assert self.kv_repeat == 1
390
+ q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
391
+ q = self.q_layer_norm(q)
392
+ k = self.k_layer_norm(k)
393
+ q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
394
+ if self.rope:
395
+ q, k = self._apply_rope(q, k)
396
+ k, v = self._complete_kv(k, v)
397
+ if self.kv_repeat > 1:
398
+ k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
399
+ v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
400
+ if self.attention_as_float32:
401
+ q, k, v = [x.float() for x in [q, k, v]]
402
+ if self.memory_efficient:
403
+ if custom_attn_mask:
404
+ # When using a custom attn mask:
405
+ # Move to query's device, repeat for each sample, remove align8 padding
406
+ seq_len = query.shape[1]
407
+ attn_mask = attn_mask.to(q.dtype)
408
+ attn_mask = attn_mask.repeat((q.shape[0], 1, 1, 1))
409
+ attn_mask = attn_mask[..., :seq_len, :seq_len]
410
+
411
+ p = self.dropout if self.training else 0
412
+ if _efficient_attention_backend == 'torch':
413
+ x = torch.nn.functional.scaled_dot_product_attention(
414
+ q, k, v, is_causal=attn_mask is not None, dropout_p=p)
415
+ else:
416
+ x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
417
+ else:
418
+ # We include the dot product as float32, for consistency
419
+ # with the other implementations that include that step
420
+ # as part of the attention. Note that when using `autocast`,
421
+ # the einsums would be done as bfloat16, but the softmax
422
+ # would be done as float32, so `attention_as_float32` will
423
+ # extend a bit the range of operations done in float32,
424
+ # although this should make no difference.
425
+ q = q / q.shape[-1] ** 0.5
426
+ key_layout = layout.replace('t', 'k')
427
+ query_layout = layout
428
+ if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
429
+ with torch.autocast(device_type=q.device.type, dtype=torch.float32):
430
+ pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
431
+ else:
432
+ pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
433
+ if attn_mask is not None:
434
+ pre_w = pre_w + attn_mask
435
+ w = torch.softmax(pre_w, dim=-1)
436
+ w = F.dropout(w, self.dropout, training=self.training).to(v)
437
+ # Key and value have the same format.
438
+ x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
439
+ x = x.to(dtype)
440
+ x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
441
+ x = self.out_proj(x)
442
+ else:
443
+ key, value = self._complete_kv(key, value)
444
+ if self.attention_as_float32:
445
+ query, key, value = [x.float() for x in [query, key, value]]
446
+ x, _ = self.mha(
447
+ query, key, value, key_padding_mask,
448
+ need_weights, attn_mask, average_attn_weights)
449
+ x = x.to(dtype)
450
+
451
+ return x, None
452
+
453
+
454
+ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
455
+ """TransformerLayer with Streaming / Causal support.
456
+ This also integrates cross_attention, when passing `cross_attention=True`,
457
+ rather than having two separate classes like in PyTorch.
458
+
459
+ Args:
460
+ d_model (int): Dimension of the data.
461
+ num_heads (int): Number of heads.
462
+ dim_feedforward (int): Intermediate dimension of FF module.
463
+ dropout (float): Dropout both for MHA and FF.
464
+ bias_ff (bool): Use bias for FF.
465
+ bias_attn (bool): Use bias for MHA.
466
+ causal (bool): Causal mask applied automatically.
467
+ past_context (int, optional): Receptive field for the causal mask, infinite if None.
468
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
469
+ memory_efficient (bool): Use xformers based memory efficient attention.
470
+ attention_as_float32 (bool): Perform the attention as float32
471
+ (especially important with memory_efficient as autocast won't do this automatically).
472
+ qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
473
+ qk_layer_norm_cross (bool): Same for the cross attention.
474
+ cross_attention (bool): If True, expect to get secondary input for cross-attention.
475
+ Cross attention will use the default MHA, as it typically won't require
476
+ special treatment.
477
+ layer_scale (float, optional): If not None, LayerScale will be used with
478
+ the given value as initial scale.
479
+ rope (`RotaryEmbedding`, optional): Rope embedding to use.
480
+ attention_dropout (float, optional): If not None, separate the value of the dimension dropout
481
+ in FFN and of the attention dropout.
482
+ kv_repeat (int): If > 1, will repeat keys and values multiple times (must divide num_heads).
483
+ This will lead to faster decoding time on A100 or other GPUs with tensorcore.
484
+ device (torch.device, optional): Device on which to initialize.
485
+ dtype (torch.dtype, optional): dtype to use.
486
+ **kwargs: See `nn.TransformerEncoderLayer`.
487
+ """
488
+ def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
489
+ bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
490
+ past_context: tp.Optional[int] = None, custom: bool = False,
491
+ memory_efficient: bool = False, attention_as_float32: bool = False,
492
+ qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
493
+ cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
494
+ rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
495
+ kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
496
+ super().__init__(d_model, num_heads, dim_feedforward, dropout,
497
+ device=device, dtype=dtype, batch_first=True, **kwargs)
498
+ factory_kwargs = {'device': device, 'dtype': dtype}
499
+ # Redefine self_attn to our streaming multi-head attention
500
+ attn_kwargs: tp.Dict[str, tp.Any] = {
501
+ 'embed_dim': d_model,
502
+ 'num_heads': num_heads,
503
+ 'dropout': dropout if attention_dropout is None else attention_dropout,
504
+ 'bias': bias_attn,
505
+ 'custom': custom,
506
+ 'memory_efficient': memory_efficient,
507
+ 'attention_as_float32': attention_as_float32,
508
+ }
509
+ self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
510
+ causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
511
+ kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
512
+ # Redefine feedforward layers to expose bias parameter
513
+ self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
514
+ self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
515
+
516
+ self.layer_scale_1: nn.Module
517
+ self.layer_scale_2: nn.Module
518
+ if layer_scale is None:
519
+ self.layer_scale_1 = nn.Identity()
520
+ self.layer_scale_2 = nn.Identity()
521
+ else:
522
+ self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
523
+ self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
524
+
525
+ self.cross_attention: tp.Optional[nn.Module] = None
526
+ if cross_attention:
527
+ self.cross_attention = StreamingMultiheadAttention(
528
+ cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
529
+ **attn_kwargs, **factory_kwargs)
530
+ # Norm and dropout
531
+ self.dropout_cross = nn.Dropout(dropout)
532
+ # eps value matching that used in PyTorch reference implementation.
533
+ self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
534
+ self.layer_scale_cross: nn.Module
535
+ if layer_scale is None:
536
+ self.layer_scale_cross = nn.Identity()
537
+ else:
538
+ self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
539
+ self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
540
+ self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
541
+
542
+ def _cross_attention_block(self, src: torch.Tensor,
543
+ cross_attention_src: torch.Tensor) -> torch.Tensor:
544
+ assert self.cross_attention is not None
545
+ # queries are from src, keys and values from cross_attention_src.
546
+ x = self.cross_attention(
547
+ src, cross_attention_src, cross_attention_src, need_weights=False)[0]
548
+ return self.dropout_cross(x) # type: ignore
549
+
550
+ def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None, # type: ignore
551
+ src_key_padding_mask: tp.Optional[torch.Tensor] = None,
552
+ cross_attention_src: tp.Optional[torch.Tensor] = None):
553
+ if self.cross_attention is None:
554
+ assert cross_attention_src is None
555
+ else:
556
+ assert cross_attention_src is not None
557
+ x = src
558
+ if self.norm_first:
559
+ x = x + self.layer_scale_1(
560
+ self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
561
+ if cross_attention_src is not None:
562
+ x = x + self.layer_scale_cross(
563
+ self._cross_attention_block(
564
+ self.norm_cross(x), cross_attention_src))
565
+ x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
566
+ else:
567
+ x = self.norm1(x + self.layer_scale_1(
568
+ self._sa_block(x, src_mask, src_key_padding_mask)))
569
+ if cross_attention_src is not None:
570
+ x = self.norm_cross(
571
+ x + self.layer_scale_cross(
572
+ self._cross_attention_block(src, cross_attention_src)))
573
+ x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
574
+ return x
575
+
576
+
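A usage sketch for the layer (not part of the diff): pre-norm with cross-attention enabled; `norm_first` is forwarded to `nn.TransformerEncoderLayer` through `**kwargs`, and the shapes below are assumptions for illustration:

import torch

layer = StreamingTransformerLayer(
    d_model=256, num_heads=4, dim_feedforward=1024,
    causal=True, custom=True, cross_attention=True,
    layer_scale=1e-4, norm_first=True)

src = torch.randn(2, 10, 256)     # (batch, time, d_model)
memory = torch.randn(2, 24, 256)  # conditioning stream attended to by cross-attention
out = layer(src, cross_attention_src=memory)  # same shape as src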
577
+ class StreamingTransformer(StreamingModule):
578
+ """Transformer with Streaming / Causal support.
579
+
580
+ Args:
581
+ d_model (int): Dimension of the data.
582
+ num_heads (int): Number of heads.
583
+ dim_feedforward (int): Intermediate dimension of FF module.
584
+ dropout (float): Dropout both for MHA and FF.
585
+ bias_ff (bool): Use bias for FF.
586
+ bias_attn (bool): Use bias for MHA.
587
+ causal (bool): Causal mask applied automatically.
588
+ past_context (int, optional): Receptive field for the causal mask, infinite if None.
589
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
590
+ memory_efficient (bool): Use xformers based memory efficient attention.
591
+ attention_as_float32 (bool): Perform the attention as float32
592
+ (especially important with memory_efficient as autocast won't do this automatically).
593
+ cross_attention (bool): If True, expect to get secondary input for cross-attention.
594
+ layer_scale (float, optional): If not None, LayerScale will be used
595
+ with the given value as initial scale.
596
+ positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
597
+ max_period (float): Maximum period of the time embedding.
598
+ positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
599
+ xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
600
+ lr (float, optional): learning rate override through the `make_optim_group` API.
601
+ weight_decay (float, optional): Weight decay override through the `make_optim_group` API.
602
+ layer_class (subclass of `StreamingTransformerLayer`): class to use
603
+ to initialize the layers, allowing further customization outside of AudioCraft.
604
+ checkpointing (str): Checkpointing strategy to reduce memory usage.
605
+ No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
606
+ if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
607
+ minimal memory usage, but maximal runtime). Finally, `xformers_default` provides
608
+ a policy for opting some operations out of checkpointing, such as
609
+ linear layers and attention, providing a middle ground between speed and memory.
610
+ device (torch.device, optional): Device on which to initialize.
611
+ dtype (torch.dtype, optional): dtype to use.
612
+ **kwargs: See `nn.TransformerEncoderLayer`.
613
+ """
614
+ def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
615
+ dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
616
+ causal: bool = False, past_context: tp.Optional[int] = None,
617
+ custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
618
+ cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
619
+ positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
620
+ xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
621
+ layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
622
+ checkpointing: str = 'none', device=None, dtype=None, **kwargs):
623
+ super().__init__()
624
+ assert d_model % num_heads == 0
625
+
626
+ self.positional_embedding = positional_embedding
627
+ self.max_period = max_period
628
+ self.positional_scale = positional_scale
629
+ self.weight_decay = weight_decay
630
+ self.lr = lr
631
+
632
+ assert positional_embedding in ['sin', 'rope', 'sin_rope']
633
+ self.rope: tp.Optional[RotaryEmbedding] = None
634
+ if self.positional_embedding in ['rope', 'sin_rope']:
635
+ assert _is_custom(custom, memory_efficient)
636
+ self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
637
+ xpos=xpos, scale=positional_scale, device=device)
638
+
639
+ self.checkpointing = checkpointing
640
+
641
+ assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
642
+ if self.checkpointing.startswith('xformers'):
643
+ _verify_xformers_internal_compat()
644
+
645
+ self.layers = nn.ModuleList()
646
+ for idx in range(num_layers):
647
+ self.layers.append(
648
+ layer_class(
649
+ d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
650
+ dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
651
+ causal=causal, past_context=past_context, custom=custom,
652
+ memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
653
+ cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
654
+ device=device, dtype=dtype, **kwargs))
655
+
656
+ if self.checkpointing != 'none':
657
+ for layer in self.layers:
658
+ # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
659
+ # backward hook inside of FSDP...
660
+ layer._magma_checkpointed = True # type: ignore
661
+
662
+ def _apply_layer(self, layer, *args, **kwargs):
663
+ method = self.checkpointing
664
+ if method == 'none':
665
+ return layer(*args, **kwargs)
666
+ elif method == 'torch':
667
+ return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
668
+ elif method.startswith('xformers'):
669
+ from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
670
+ if method == 'xformers_default':
671
+ # those operations will be saved, and not recomputed.
672
+ # According to Francisco we can get smarter policies but this is a good start.
673
+ allow_list = [
674
+ "xformers.efficient_attention_forward_cutlass.default",
675
+ "xformers_flash.flash_fwd.default",
676
+ "aten.addmm.default",
677
+ "aten.mm.default",
678
+ ]
679
+ elif method == 'xformers_mm':
680
+ # those operations will be saved, and not recomputed.
681
+ # According to Francisco we can get smarter policies but this is a good start.
682
+ allow_list = [
683
+ "aten.addmm.default",
684
+ "aten.mm.default",
685
+ ]
686
+ else:
687
+ raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
688
+ policy_fn = _get_default_policy(allow_list)
689
+ return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
690
+ else:
691
+ raise ValueError(f"Checkpointing method {method} is unknown.")
692
+
693
+ def forward(self, x: torch.Tensor, *args, **kwargs):
694
+ B, T, C = x.shape
695
+
696
+ if 'offsets' in self._streaming_state:
697
+ offsets = self._streaming_state['offsets']
698
+ else:
699
+ offsets = torch.zeros(B, dtype=torch.long, device=x.device)
700
+
701
+ if self.positional_embedding in ['sin', 'sin_rope']:
702
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
703
+ positions = positions + offsets.view(-1, 1, 1)
704
+ pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
705
+ x = x + self.positional_scale * pos_emb
706
+
707
+ for layer in self.layers:
708
+ x = self._apply_layer(layer, x, *args, **kwargs)
709
+
710
+ if self._is_streaming:
711
+ self._streaming_state['offsets'] = offsets + T
712
+
713
+ return x
714
+
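The `offsets` bookkeeping above means a frame gets the same sinusoidal embedding whether it is processed in one pass or resumed mid-stream. A small sketch, assuming `create_sin_embedding` as defined or imported earlier in this file (the same helper used in `forward` above):

import torch

pos_full = torch.arange(8).view(1, -1, 1).float()           # positions 0..7 in one pass
pos_resumed = (torch.arange(4) + 4).view(1, -1, 1).float()  # last 4 positions, offset by 4
emb_full = create_sin_embedding(pos_full, 256)
emb_resumed = create_sin_embedding(pos_resumed, 256)
assert torch.allclose(emb_full[:, 4:], emb_resumed)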
715
+ def make_optim_group(self):
716
+ group = {"params": list(self.parameters())}
717
+ if self.lr is not None:
718
+ group["lr"] = self.lr
719
+ if self.weight_decay is not None:
720
+ group["weight_decay"] = self.weight_decay
721
+ return group
722
+
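For `checkpointing='torch'`, `_apply_layer` wraps each layer in PyTorch activation checkpointing. A minimal sketch of the underlying mechanism, using a stand-in block rather than this class:

import torch
from torch.utils.checkpoint import checkpoint as torch_checkpoint

# Activations inside `block` are recomputed during backward instead of being stored.
block = torch.nn.Sequential(
    torch.nn.Linear(256, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 256))
x = torch.randn(2, 10, 256, requires_grad=True)
y = torch_checkpoint(block, x, use_reentrant=False)
y.sum().backward()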
723
+
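And a hedged sketch of the `make_optim_group` API: the transformer keeps its own lr / weight_decay overrides inside a single optimizer (the head module is a stand-in, not from the diff):

import torch

model = StreamingTransformer(d_model=256, num_heads=4, num_layers=2,
                             lr=1e-4, weight_decay=0.01)
head = torch.nn.Linear(256, 1024)  # stand-in for the rest of the model
optimizer = torch.optim.AdamW(
    [{"params": head.parameters()}, model.make_optim_group()],
    lr=3e-4, weight_decay=0.1)  # defaults apply to groups without overrides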
724
+ # special attention related functions
725
+
726
+ def _verify_xformers_memory_efficient_compat():
727
+ try:
728
+ from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa
729
+ except ImportError:
730
+ raise ImportError(
731
+ "xformers is not installed. Please install it and try again.\n"
732
+ "To install on AWS and Azure, run \n"
733
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
734
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
735
+ "To install on FAIR Cluster, run \n"
736
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
737
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
738
+
739
+
740
+ def _verify_xformers_internal_compat():
741
+ try:
742
+ from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa
743
+ except ImportError:
744
+ raise ImportError(
745
+ "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
746
+ "To install on AWS and Azure, run \n"
747
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
748
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
749
+ "To install on FAIR Cluster, run \n"
750
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
751
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
752
+
753
+
754
+ def _is_custom(custom: bool, memory_efficient: bool):
755
+ return custom or memory_efficient
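
Finally, an end-to-end sketch of the streaming behaviour (assuming the `streaming()` context manager from `StreamingModule`, which is not shown in this diff): the same sequence is fed in chunks, and the per-layer KV caches plus the `offsets` state keep the result consistent with a single pass.

import torch

model = StreamingTransformer(
    d_model=256, num_heads=4, num_layers=2,
    causal=True, custom=True, positional_embedding='sin').eval()

x = torch.randn(1, 16, 256)
with torch.no_grad():
    full = model(x)
    with model.streaming():  # assumed StreamingModule API
        outs = [model(chunk) for chunk in x.chunk(4, dim=1)]
    streamed = torch.cat(outs, dim=1)
# With the causal mask and the streaming KV cache, `streamed` should closely match `full`.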