anonymous9a7b commited on
Commit
f032e68
·
1 Parent(s): 6cf6784
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import yaml
4
+ import torch
5
+ import librosa
6
+ from diffusers import DDIMScheduler
7
+ from transformers import AutoProcessor, ClapModel
8
+ from model.udit import UDiT
9
+ from vae_modules.autoencoder_wrapper import Autoencoder
10
+ import numpy as np
11
+
12
+ diffusion_config = './config/SoloAudio.yaml'
13
+ diffusion_ckpt = './pretrained_models/soloaudio_v2.pt'
14
+ autoencoder_path = './pretrained_models/audio-vae.pt'
15
+ uncond_path = './pretrained_models/uncond.npz'
16
+ sample_rate = 24000
17
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
+
19
+ with open(diffusion_config, 'r') as fp:
20
+ diff_config = yaml.safe_load(fp)
21
+
22
+ v_prediction = diff_config["ddim"]["v_prediction"]
23
+
24
+ clapmodel = ClapModel.from_pretrained("laion/larger_clap_general").to(device)
25
+ processor = AutoProcessor.from_pretrained('laion/larger_clap_general')
26
+ autoencoder = Autoencoder(autoencoder_path, 'stable_vae', quantization_first=True)
27
+ autoencoder.eval()
28
+ autoencoder.to(device)
29
+ unet = UDiT(
30
+ **diff_config['diffwrap']['UDiT']
31
+ ).to(device)
32
+ unet.load_state_dict(torch.load(diffusion_ckpt)['model'])
33
+ unet.eval()
34
+
35
+ if v_prediction:
36
+ print('v prediction')
37
+ scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
38
+ else:
39
+ print('noise prediction')
40
+ scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
41
+
42
+ # these steps reset dtype of noise_scheduler params
43
+ latents = torch.randn((1, 128, 128),
44
+ device=device)
45
+ noise = torch.randn(latents.shape).to(latents.device)
46
+ timesteps = torch.randint(0, scheduler.config.num_train_timesteps,
47
+ (noise.shape[0],),
48
+ device=latents.device).long()
49
+ _ = scheduler.add_noise(latents, noise, timesteps)
50
+
51
+
52
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
53
+ """
54
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
55
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
56
+ """
57
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
58
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
59
+ # rescale the results from guidance (fixes overexposure)
60
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
61
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
62
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
63
+ return noise_cfg
64
+
65
+ @spaces.GPU
66
+ def sample_diffusion(mixture, timbre, ddim_steps=50, eta=0, seed=2023, guidance_scale=False, guidance_rescale=0.0,):
67
+ with torch.no_grad():
68
+ scheduler.set_timesteps(ddim_steps)
69
+ generator = torch.Generator(device=device).manual_seed(seed)
70
+ # init noise
71
+ noise = torch.randn(mixture.shape, generator=generator, device=device)
72
+ pred = noise
73
+
74
+ for t in scheduler.timesteps:
75
+ pred = scheduler.scale_model_input(pred, t)
76
+ if guidance_scale:
77
+ uncond = torch.tensor(np.load(uncond_path)['arr_0']).unsqueeze(0).to(device)
78
+ pred_combined = torch.cat([pred, pred], dim=0)
79
+ mixture_combined = torch.cat([mixture, mixture], dim=0)
80
+ timbre_combined = torch.cat([timbre, uncond], dim=0)
81
+ output_combined = unet(x=pred_combined, timesteps=t, mixture=mixture_combined, timbre=timbre_combined)
82
+ output_pos, output_neg = torch.chunk(output_combined, 2, dim=0)
83
+
84
+ model_output = output_neg + guidance_scale * (output_pos - output_neg)
85
+ if guidance_rescale > 0.0:
86
+ # avoid overexposed
87
+ model_output = rescale_noise_cfg(model_output, output_pos,
88
+ guidance_rescale=guidance_rescale)
89
+ else:
90
+ model_output = unet(x=pred, timesteps=t, mixture=mixture, timbre=timbre)
91
+ pred = scheduler.step(model_output=model_output, timestep=t, sample=pred,
92
+ eta=eta, generator=generator).prev_sample
93
+
94
+ pred = autoencoder(embedding=pred).squeeze(1)
95
+
96
+ return pred
97
+
98
+ @spaces.GPU
99
+ def tse(gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale):
100
+ with torch.no_grad():
101
+ mixture, _ = librosa.load(gt_file_input, sr=sample_rate)
102
+ # Check the length of the audio in samples
103
+ current_length = len(mixture)
104
+ target_length = sample_rate * 10
105
+ # Cut or pad the audio to match the target length
106
+ if current_length > target_length:
107
+ # Trim the audio if it's longer than the target length
108
+ mixture = mixture[:target_length]
109
+ elif current_length < target_length:
110
+ # Pad the audio with zeros if it's shorter than the target length
111
+ padding = target_length - current_length
112
+ mixture = np.pad(mixture, (0, padding), mode='constant')
113
+ mixture = torch.tensor(mixture).unsqueeze(0).to(device)
114
+ mixture = autoencoder(audio=mixture.unsqueeze(1))
115
+
116
+ text_inputs = processor(
117
+ text=[text_input],
118
+ max_length=10, # Fixed length for text
119
+ padding='max_length', # Pad text to max length
120
+ truncation=True, # Truncate text if it's longer than max length
121
+ return_tensors="pt"
122
+ )
123
+ inputs = {
124
+ "input_ids": text_inputs["input_ids"][0].unsqueeze(0), # Text input IDs
125
+ "attention_mask": text_inputs["attention_mask"][0].unsqueeze(0), # Attention mask for text
126
+ }
127
+ inputs = {key: value.to(device) for key, value in inputs.items()}
128
+ timbre = clapmodel.get_text_features(**inputs)
129
+
130
+
131
+ pred = sample_diffusion(mixture, timbre, num_infer_steps, eta, seed, guidance_scale, guidance_rescale)
132
+ return sample_rate, pred.squeeze().cpu().numpy()
133
+
134
+
135
+ # CSS styling (optional)
136
+ css = """
137
+ #col-container {
138
+ margin: 0 auto;
139
+ max-width: 1280px;
140
+ }
141
+ """
142
+
143
+ # Gradio Blocks layout
144
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
145
+ with gr.Column(elem_id="col-container"):
146
+ gr.Markdown("""
147
+ # SoloAudio: Target Sound Extraction with Language-oriented Audio Diffusion Transformer.
148
+ Adjust advanced settings for more control. This space only supports a 10-second audio input now.
149
+
150
+ Learn more about 🟣**SoloAudio** on the [SoloAudio Homepage](https://wanghelin1997.github.io/SoloAudio-Demo/).
151
+ """)
152
+
153
+
154
+ with gr.Tab("Target Sound Extraction"):
155
+ # Basic Input: Text prompt
156
+ with gr.Row():
157
+ gt_file_input = gr.Audio(label="Upload Audio to Extract", type="filepath", value="demo/0_mix.wav")
158
+ text_input = gr.Textbox(
159
+ label="Text Prompt",
160
+ show_label=True,
161
+ max_lines=2,
162
+ placeholder="Enter your prompt",
163
+ container=True,
164
+ value="The sound of gunshot",
165
+ scale=4
166
+ )
167
+ # Run button
168
+ run_button = gr.Button("Extract", scale=1)
169
+
170
+ # Output Component
171
+ result = gr.Audio(label="Extracted Audio", type="numpy")
172
+
173
+ # Advanced settings in an Accordion
174
+ with gr.Accordion("Advanced Settings", open=False):
175
+ # Audio Length
176
+ guidance_scale = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=3.0, label="Guidance Scale")
177
+ guidance_rescale = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0., label="Guidance Rescale")
178
+ num_infer_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
179
+ eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Eta")
180
+ seed = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed")
181
+
182
+ # Define the trigger and input-output linking for generation
183
+ run_button.click(
184
+ fn=tse,
185
+ inputs=[gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale],
186
+ outputs=[result]
187
+ )
188
+ text_input.submit(fn=tse,
189
+ inputs=[gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale],
190
+ outputs=[result]
191
+ )
192
+
193
+ # Launch the Gradio demo
194
+ demo.launch()
config/SoloAudio.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 1.0
2
+
3
+ system: "udit_rotary_v_b_1000"
4
+
5
+ ddim:
6
+ v_prediction: true
7
+ diffusers:
8
+ num_train_timesteps: 1000
9
+ beta_schedule: 'scaled_linear'
10
+ beta_start: 0.00085
11
+ beta_end: 0.012
12
+ prediction_type: 'v_prediction'
13
+ rescale_betas_zero_snr: true
14
+ timestep_spacing: 'trailing'
15
+ clip_sample: false
16
+
17
+ diffwrap:
18
+ UDiT:
19
+ input_dim: 256
20
+ output_dim: 128
21
+ pos_method: 'none'
22
+ pos_length: 500
23
+ timbre_dim: 512
24
+ hidden_size: 768
25
+ depth: 12
26
+ num_heads: 12
demo/0_mix.wav ADDED
Binary file (480 kB). View file
 
demo/soloaudio.webp ADDED
model/attention.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ import einops
6
+ from einops import rearrange, repeat
7
+ from inspect import isfunction
8
+ from .rotary import RotaryEmbedding
9
+
10
+
11
+ if hasattr(nn.functional, 'scaled_dot_product_attention'):
12
+ ATTENTION_MODE = 'flash'
13
+ else:
14
+ ATTENTION_MODE = 'math'
15
+ print(f'attention mode is {ATTENTION_MODE}')
16
+
17
+
18
+ def add_mask(sim, mask):
19
+ b, ndim = sim.shape[0], mask.ndim
20
+ if ndim == 3:
21
+ mask = rearrange(mask, "b n m -> b 1 n m")
22
+ if ndim == 2:
23
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
24
+ max_neg_value = -torch.finfo(sim.dtype).max
25
+ sim = sim.masked_fill(~mask, max_neg_value)
26
+ return sim
27
+
28
+
29
+ def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
30
+ def default(val, d):
31
+ return val if val is not None else (d() if isfunction(d) else d)
32
+ b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device
33
+ q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
34
+ k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
35
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
36
+ return attn_mask
37
+
38
+
39
+ class Attention(nn.Module):
40
+ def __init__(self, dim, context_dim=None, num_heads=8,
41
+ qkv_bias=False, qk_scale=None, qk_norm='layernorm',
42
+ attn_drop=0., proj_drop=0., rope_mode='shared'):
43
+ super().__init__()
44
+ self.num_heads = num_heads
45
+ head_dim = dim // num_heads
46
+ self.scale = qk_scale or head_dim ** -0.5
47
+
48
+ if context_dim is None:
49
+ self.cross_attn = False
50
+ else:
51
+ self.cross_attn = True
52
+
53
+ context_dim = dim if context_dim is None else context_dim
54
+
55
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
56
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
57
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
58
+
59
+ if qk_norm is None:
60
+ self.norm_q = nn.Identity()
61
+ self.norm_k = nn.Identity()
62
+ elif qk_norm == 'layernorm':
63
+ self.norm_q = nn.LayerNorm(head_dim)
64
+ self.norm_k = nn.LayerNorm(head_dim)
65
+ else:
66
+ raise NotImplementedError
67
+
68
+ self.attn_drop_p = attn_drop
69
+ self.attn_drop = nn.Dropout(attn_drop)
70
+ self.proj = nn.Linear(dim, dim)
71
+ self.proj_drop = nn.Dropout(proj_drop)
72
+
73
+ if self.cross_attn:
74
+ assert rope_mode == 'none'
75
+ self.rope_mode = rope_mode
76
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
77
+ self.rotary = RotaryEmbedding(dim=head_dim)
78
+ elif self.rope_mode == 'dual':
79
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
80
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
81
+
82
+ def _rotary(self, q, k, extras):
83
+ if self.rope_mode == 'shared':
84
+ q, k = self.rotary(q=q, k=k)
85
+ elif self.rope_mode == 'x_only':
86
+ q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
87
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
88
+ q = torch.cat((q_c, q_x), dim=2)
89
+ k = torch.cat((k_c, k_x), dim=2)
90
+ elif self.rope_mode == 'dual':
91
+ q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
92
+ q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
93
+ q = torch.cat((q_c, q_x), dim=2)
94
+ k = torch.cat((k_c, k_x), dim=2)
95
+ elif self.rope_mode == 'none':
96
+ pass
97
+ else:
98
+ raise NotImplementedError
99
+ return q, k
100
+
101
+ def _attn(self, q, k, v, mask_binary):
102
+ if ATTENTION_MODE == 'flash':
103
+ x = F.scaled_dot_product_attention(q, k, v,
104
+ dropout_p=self.attn_drop_p,
105
+ attn_mask=mask_binary)
106
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
107
+ elif ATTENTION_MODE == 'math':
108
+ attn = (q @ k.transpose(-2, -1)) * self.scale
109
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
110
+ attn = attn.softmax(dim=-1)
111
+ attn = self.attn_drop(attn)
112
+ x = (attn @ v).transpose(1, 2)
113
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
114
+ else:
115
+ raise NotImplementedError
116
+ return x
117
+
118
+ def forward(self, x, context=None, context_mask=None, extras=0):
119
+ B, L, C = x.shape
120
+ if context is None:
121
+ context = x
122
+
123
+ q = self.to_q(x)
124
+ k = self.to_k(context)
125
+ v = self.to_v(context)
126
+
127
+ if context_mask is not None:
128
+ mask_binary = create_mask(x.shape, context.shape,
129
+ x.device, None, context_mask)
130
+ else:
131
+ mask_binary = None
132
+
133
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
134
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
135
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
136
+
137
+ q = self.norm_q(q)
138
+ k = self.norm_k(k)
139
+
140
+ q, k = self._rotary(q, k, extras)
141
+
142
+ x = self._attn(q, k, v, mask_binary)
143
+
144
+ x = self.proj(x)
145
+ x = self.proj_drop(x)
146
+ return x
model/rotary.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ "this rope is faster than llama rope with jit script"
4
+
5
+
6
+ def rotate_half(x):
7
+ x1, x2 = x.chunk(2, dim=-1)
8
+ return torch.cat((-x2, x1), dim=-1)
9
+
10
+
11
+ # disable in checkpoint mode
12
+ # @torch.jit.script
13
+ def apply_rotary_pos_emb(x, cos, sin):
14
+ # NOTE: This could probably be moved to Triton
15
+ # Handle a possible sequence length mismatch in between q and k
16
+ cos = cos[:, :, : x.shape[-2], :]
17
+ sin = sin[:, :, : x.shape[-2], :]
18
+ return (x * cos) + (rotate_half(x) * sin)
19
+
20
+
21
+ class RotaryEmbedding(torch.nn.Module):
22
+ """
23
+ The rotary position embeddings from RoFormer_ (Su et. al).
24
+ A crucial insight from the method is that the query and keys are
25
+ transformed by rotation matrices which depend on the relative positions.
26
+
27
+ Other implementations are available in the Rotary Transformer repo_ and in
28
+ GPT-NeoX_, GPT-NeoX was an inspiration
29
+
30
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
31
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
32
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
33
+
34
+
35
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
36
+ (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
37
+ """
38
+
39
+ def __init__(self, dim: int):
40
+ super().__init__()
41
+ # Generate and save the inverse frequency buffer (non trainable)
42
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
43
+ self.register_buffer("inv_freq", inv_freq)
44
+ self._seq_len_cached = None
45
+ self._cos_cached = None
46
+ self._sin_cached = None
47
+
48
+ def _update_cos_sin_tables(self, x, seq_dimension=-2):
49
+ # expect input: B, H, L, D
50
+ seq_len = x.shape[seq_dimension]
51
+
52
+ # Reset the tables if the sequence length has changed,
53
+ # or if we're on a new device (possibly due to tracing for instance)
54
+ # also make sure dtype wont change
55
+ if (
56
+ seq_len != self._seq_len_cached
57
+ or self._cos_cached.device != x.device
58
+ or self._cos_cached.dtype != x.dtype
59
+ ):
60
+ self._seq_len_cached = seq_len
61
+ t = torch.arange(
62
+ x.shape[seq_dimension], device=x.device, dtype=torch.float32
63
+ )
64
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
65
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
66
+
67
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
68
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
69
+
70
+ return self._cos_cached, self._sin_cached
71
+
72
+ def forward(self, q, k):
73
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
74
+ q.float(), seq_dimension=-2
75
+ )
76
+ if k is not None:
77
+ return (
78
+ apply_rotary_pos_emb(q.float(),
79
+ self._cos_cached,
80
+ self._sin_cached).type_as(q),
81
+ apply_rotary_pos_emb(k.float(),
82
+ self._cos_cached,
83
+ self._sin_cached).type_as(k),
84
+ )
85
+ else:
86
+ return (
87
+ apply_rotary_pos_emb(q.float(),
88
+ self._cos_cached,
89
+ self._sin_cached).type_as(q),
90
+ None
91
+ )
model/udit.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ import math
16
+ import warnings
17
+ import einops
18
+ import torch.utils.checkpoint
19
+ import yaml
20
+ import torch.nn.functional as F
21
+ from .attention import Attention
22
+
23
+
24
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
25
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
26
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
27
+ def norm_cdf(x):
28
+ # Computes standard normal cumulative distribution function
29
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
30
+
31
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
32
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
33
+ "The distribution of values may be incorrect.",
34
+ stacklevel=2)
35
+
36
+ with torch.no_grad():
37
+ # Values are generated by using a truncated uniform distribution and
38
+ # then using the inverse CDF for the normal distribution.
39
+ # Get upper and lower cdf values
40
+ l = norm_cdf((a - mean) / std)
41
+ u = norm_cdf((b - mean) / std)
42
+
43
+ # Uniformly fill tensor with values from [l, u], then translate to
44
+ # [2l-1, 2u-1].
45
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
46
+
47
+ # Use inverse cdf transform for normal distribution to get truncated
48
+ # standard normal
49
+ tensor.erfinv_()
50
+
51
+ # Transform to proper mean, std
52
+ tensor.mul_(std * math.sqrt(2.))
53
+ tensor.add_(mean)
54
+
55
+ # Clamp to ensure it's in the proper range
56
+ tensor.clamp_(min=a, max=b)
57
+ return tensor
58
+
59
+
60
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
61
+ # type: (Tensor, float, float, float, float) -> Tensor
62
+ r"""Fills the input Tensor with values drawn from a truncated
63
+ normal distribution. The values are effectively drawn from the
64
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
65
+ with values outside :math:`[a, b]` redrawn until they are within
66
+ the bounds. The method used for generating the random values works
67
+ best when :math:`a \leq \text{mean} \leq b`.
68
+ Args:
69
+ tensor: an n-dimensional `torch.Tensor`
70
+ mean: the mean of the normal distribution
71
+ std: the standard deviation of the normal distribution
72
+ a: the minimum cutoff value
73
+ b: the maximum cutoff value
74
+ Examples:
75
+ >>> w = torch.empty(3, 5)
76
+ >>> nn.init.trunc_normal_(w)
77
+ """
78
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
79
+
80
+
81
+ class Mlp(nn.Module):
82
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
83
+ super().__init__()
84
+ out_features = out_features or in_features
85
+ hidden_features = hidden_features or in_features
86
+ self.fc1 = nn.Linear(in_features, hidden_features)
87
+ self.act = act_layer()
88
+ self.fc2 = nn.Linear(hidden_features, out_features)
89
+ self.drop = nn.Dropout(drop)
90
+
91
+ def forward(self, x):
92
+ x = self.fc1(x)
93
+ x = self.act(x)
94
+ x = self.drop(x)
95
+ x = self.fc2(x)
96
+ x = self.drop(x)
97
+ return x
98
+
99
+
100
+
101
+ class PositionalConvEmbedding(nn.Module):
102
+ """
103
+ Relative positional embedding used in HuBERT
104
+ """
105
+
106
+ def __init__(self, dim=768, kernel_size=128, groups=16):
107
+ super().__init__()
108
+ self.conv = nn.Conv1d(
109
+ dim,
110
+ dim,
111
+ kernel_size=kernel_size,
112
+ padding=kernel_size // 2,
113
+ groups=groups,
114
+ bias=True
115
+ )
116
+ self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
117
+
118
+ def forward(self, x):
119
+ x = x.transpose(2, 1)
120
+ # B C T
121
+ x = self.conv(x)
122
+ x = F.gelu(x[:, :, :-1])
123
+ x = x.transpose(2, 1)
124
+ return x
125
+
126
+
127
+ class SinusoidalPositionalEncoding(nn.Module):
128
+ def __init__(self, dim, length):
129
+ super(SinusoidalPositionalEncoding, self).__init__()
130
+ self.length = length
131
+ self.dim = dim
132
+ self.register_buffer('pe', self._generate_positional_encoding(length, dim))
133
+
134
+ def _generate_positional_encoding(self, length, dim):
135
+ pe = torch.zeros(length, dim)
136
+ position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
137
+ div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
138
+
139
+ pe[:, 0::2] = torch.sin(position * div_term)
140
+ pe[:, 1::2] = torch.cos(position * div_term)
141
+
142
+ pe = pe.unsqueeze(0)
143
+ return pe
144
+
145
+ def forward(self, x):
146
+ x = x + self.pe[:, :x.size(1)]
147
+ return x
148
+
149
+
150
+ class PE_wrapper(nn.Module):
151
+ def __init__(self, dim=768, method='none', length=None):
152
+ super().__init__()
153
+ self.method = method
154
+ if method == 'abs':
155
+ # init absolute pe like UViT
156
+ self.length = length
157
+ self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
158
+ trunc_normal_(self.abs_pe, std=.02)
159
+ elif method == 'conv':
160
+ self.conv_pe = PositionalConvEmbedding(dim=dim)
161
+ elif method == 'sinu':
162
+ self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
163
+ elif method == 'none':
164
+ # skip pe
165
+ self.id = nn.Identity()
166
+ else:
167
+ raise NotImplementedError
168
+
169
+ def forward(self, x):
170
+ if self.method == 'abs':
171
+ _, L, _ = x.shape
172
+ assert L <= self.length
173
+ x = x + self.abs_pe[:, :L, :]
174
+ elif self.method == 'conv':
175
+ x = x + self.conv_pe(x)
176
+ elif self.method == 'sinu':
177
+ x = self.sinu_pe(x)
178
+ elif self.method == 'none':
179
+ x = self.id(x)
180
+ else:
181
+ raise NotImplementedError
182
+ return x
183
+
184
+
185
+ def modulate(x, shift, scale):
186
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
187
+
188
+
189
+ #################################################################################
190
+ # Embedding Layers for Timesteps and Class Labels #
191
+ #################################################################################
192
+
193
+ class TimestepEmbedder(nn.Module):
194
+ """
195
+ Embeds scalar timesteps into vector representations.
196
+ """
197
+ def __init__(self, hidden_size, frequency_embedding_size=256):
198
+ super().__init__()
199
+ self.mlp = nn.Sequential(
200
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
201
+ nn.SiLU(),
202
+ nn.Linear(hidden_size, hidden_size, bias=True),
203
+ )
204
+ self.frequency_embedding_size = frequency_embedding_size
205
+
206
+ @staticmethod
207
+ def timestep_embedding(t, dim, max_period=10000):
208
+ """
209
+ Create sinusoidal timestep embeddings.
210
+ :param t: a 1-D Tensor of N indices, one per batch element.
211
+ These may be fractional.
212
+ :param dim: the dimension of the output.
213
+ :param max_period: controls the minimum frequency of the embeddings.
214
+ :return: an (N, D) Tensor of positional embeddings.
215
+ """
216
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
217
+ half = dim // 2
218
+ freqs = torch.exp(
219
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
220
+ ).to(device=t.device)
221
+ args = t[:, None].float() * freqs[None]
222
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
223
+ if dim % 2:
224
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
225
+ return embedding
226
+
227
+ def forward(self, t):
228
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
229
+ t_emb = self.mlp(t_freq)
230
+ return t_emb
231
+
232
+
233
+ class LabelEmbedder(nn.Module):
234
+ """
235
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
236
+ """
237
+ def __init__(self, num_classes, hidden_size, dropout_prob):
238
+ super().__init__()
239
+ use_cfg_embedding = dropout_prob > 0
240
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
241
+ self.num_classes = num_classes
242
+ self.dropout_prob = dropout_prob
243
+
244
+ def token_drop(self, labels, force_drop_ids=None):
245
+ """
246
+ Drops labels to enable classifier-free guidance.
247
+ """
248
+ if force_drop_ids is None:
249
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
250
+ else:
251
+ drop_ids = force_drop_ids == 1
252
+ labels = torch.where(drop_ids, self.num_classes, labels)
253
+ return labels
254
+
255
+ def forward(self, labels, train, force_drop_ids=None):
256
+ use_dropout = self.dropout_prob > 0
257
+ if (train and use_dropout) or (force_drop_ids is not None):
258
+ labels = self.token_drop(labels, force_drop_ids)
259
+ embeddings = self.embedding_table(labels)
260
+ return embeddings
261
+
262
+
263
+ #################################################################################
264
+ # Core DiT Model #
265
+ #################################################################################
266
+
267
+ class DiTBlock(nn.Module):
268
+ """
269
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
270
+ """
271
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, skip=False, skip_norm=True, use_checkpoint=True, **block_kwargs):
272
+ super().__init__()
273
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
274
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
275
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
276
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
277
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
278
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
279
+ self.adaLN_modulation = nn.Sequential(
280
+ nn.SiLU(),
281
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
282
+ )
283
+ self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None
284
+ self.skip_norm = nn.LayerNorm(2 * hidden_size, elementwise_affine=False, eps=1e-6) if skip_norm else nn.Identity()
285
+ self.use_checkpoint = use_checkpoint
286
+
287
+ def forward(self, x, c, skip=None):
288
+ if self.use_checkpoint:
289
+ return torch.utils.checkpoint.checkpoint(self._forward, x, c, skip)
290
+ else:
291
+ return self._forward(x, c, skip)
292
+
293
+ def _forward(self, x, c, skip=None):
294
+ if self.skip_linear is not None:
295
+ cat = torch.cat([x, skip], dim=-1)
296
+ cat = self.skip_norm(cat)
297
+ x = self.skip_linear(cat)
298
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
299
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
300
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
301
+ return x
302
+
303
+
304
+ class FinalLayer(nn.Module):
305
+ """
306
+ The final layer of DiT.
307
+ """
308
+ def __init__(self, hidden_size, output_dim):
309
+ super().__init__()
310
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
311
+ self.linear = nn.Linear(hidden_size, output_dim, bias=True)
312
+ self.adaLN_modulation = nn.Sequential(
313
+ nn.SiLU(),
314
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
315
+ )
316
+
317
+ def forward(self, x, c):
318
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
319
+ x = modulate(self.norm_final(x), shift, scale)
320
+ x = self.linear(x)
321
+ return x
322
+
323
+
324
+ class UDiT(nn.Module):
325
+ """
326
+ Diffusion model with a Transformer backbone.
327
+ """
328
+ def __init__(
329
+ self,
330
+ input_dim=256,
331
+ output_dim=128,
332
+ pos_method='none',
333
+ pos_length=500,
334
+ timbre_dim=512,
335
+ hidden_size=1152,
336
+ depth=28,
337
+ num_heads=16,
338
+ mlp_ratio=4.0,
339
+ use_checkpoint=True
340
+ ):
341
+ super().__init__()
342
+ self.num_heads = num_heads
343
+ self.input_proj = nn.Linear(input_dim, hidden_size, bias=True)
344
+ self.t_embedder = TimestepEmbedder(hidden_size)
345
+ self.pos_embed = PE_wrapper(dim=hidden_size, method=pos_method, length=pos_length)
346
+ self.timbre_proj = nn.Linear(timbre_dim, hidden_size, bias=True)
347
+
348
+ self.in_blocks = nn.ModuleList([
349
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, use_checkpoint=use_checkpoint) for _ in range(depth // 2)
350
+ ])
351
+ self.mid_block = DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, use_checkpoint=use_checkpoint)
352
+ self.out_blocks = nn.ModuleList([
353
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, skip=True, use_checkpoint=use_checkpoint) for _ in range(depth // 2)
354
+ ])
355
+
356
+ self.final_layer = FinalLayer(hidden_size, output_dim)
357
+ self.initialize_weights()
358
+
359
+ def initialize_weights(self):
360
+ # Initialize transformer layers:
361
+ def _basic_init(module):
362
+ if isinstance(module, nn.Linear):
363
+ torch.nn.init.xavier_uniform_(module.weight)
364
+ if module.bias is not None:
365
+ nn.init.constant_(module.bias, 0)
366
+ self.apply(_basic_init)
367
+
368
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
369
+ nn.init.normal_(self.input_proj.weight, std=0.02)
370
+ nn.init.normal_(self.timbre_proj.weight, std=0.02)
371
+
372
+ # Initialize timestep embedding MLP:
373
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
374
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
375
+
376
+ # Zero-out adaLN modulation layers in DiT blocks:
377
+ for block in self.in_blocks:
378
+ nn.init.constant_(self.mid_block.adaLN_modulation[-1].weight, 0)
379
+ nn.init.constant_(self.mid_block.adaLN_modulation[-1].bias, 0)
380
+
381
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
382
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
383
+
384
+ for block in self.out_blocks:
385
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
386
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
387
+
388
+ # Zero-out output layers:
389
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
390
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
391
+ nn.init.constant_(self.final_layer.linear.weight, 0)
392
+ nn.init.constant_(self.final_layer.linear.bias, 0)
393
+
394
+ def forward(self, x, timesteps, mixture, timbre):
395
+ """
396
+ Forward pass of DiT.
397
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
398
+ t: (N,) tensor of diffusion timesteps
399
+ y: (N,) tensor of class labels
400
+ """
401
+ x = x.transpose(2,1)
402
+ mixture = mixture.transpose(2,1)
403
+ x = self.input_proj(torch.cat((x, mixture), dim=-1))
404
+ x = self.pos_embed(x)
405
+ if not torch.is_tensor(timesteps):
406
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
407
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
408
+ timesteps = timesteps[None].to(x.device)
409
+ t = self.t_embedder(timesteps) # (N, D)
410
+ timbre = self.timbre_proj(timbre)
411
+ c = t + timbre # (N, D)
412
+
413
+ skips = []
414
+ for blk in self.in_blocks:
415
+ x = blk(x, c)
416
+ skips.append(x)
417
+
418
+ x = self.mid_block(x, c)
419
+
420
+ for blk in self.out_blocks:
421
+ x = blk(x, c, skips.pop())
422
+
423
+ x = self.final_layer(x, c) # (N, T, out_dim)
424
+ x = x.transpose(2, 1)
425
+ return x
426
+
427
+
428
+ #################################################################################
429
+ # DiT Configs #
430
+ #################################################################################
431
+
432
+ def DiT_XL_2(**kwargs):
433
+ return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
434
+
435
+ def DiT_XL_4(**kwargs):
436
+ return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
437
+
438
+ def DiT_XL_8(**kwargs):
439
+ return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
440
+
441
+ def DiT_L_2(**kwargs):
442
+ return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
443
+
444
+ def DiT_L_4(**kwargs):
445
+ return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
446
+
447
+ def DiT_L_8(**kwargs):
448
+ return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
449
+
450
+ def DiT_B_2(**kwargs):
451
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
452
+
453
+ def DiT_B_4(**kwargs):
454
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
455
+
456
+ def DiT_B_8(**kwargs):
457
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
458
+
459
+ def DiT_S_2(**kwargs):
460
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
461
+
462
+ def DiT_S_4(**kwargs):
463
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
464
+
465
+ def DiT_S_8(**kwargs):
466
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
467
+
468
+
469
+ DiT_models = {
470
+ 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
471
+ 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
472
+ 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
473
+ 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
474
+ }
475
+
476
+ if __name__ == "__main__":
477
+ with open('/export/corpora7/HW/DPMTSE-main/src/config/DiffTSE_udit_conv_v_b_1000.yaml', 'r') as fp:
478
+ config = yaml.safe_load(fp)
479
+ device = 'cuda'
480
+
481
+ model = UDiT(
482
+ **config['diffwrap']['UDiT']
483
+ ).to(device)
484
+
485
+ x = torch.rand((1, 128, 150)).to(device)
486
+ t = torch.randint(0, 1000, (1, )).long().to(device)
487
+ mixture = torch.rand((1, 128, 150)).to(device)
488
+ timbre = torch.rand((1, 512)).to(device)
489
+
490
+ y = model(x, t, mixture, timbre)
491
+ print(y.shape)
pretrained_models/config.json ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 12000,
4
+ "sample_rate": 24000,
5
+ "audio_channels": 1,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "oobleck",
9
+ "config": {
10
+ "in_channels": 1,
11
+ "channels": 128,
12
+ "c_mults": [1, 2, 4, 8],
13
+ "strides": [2, 4, 6, 10],
14
+ "latent_dim": 256,
15
+ "use_snake": true
16
+ }
17
+ },
18
+ "decoder": {
19
+ "type": "oobleck",
20
+ "config": {
21
+ "out_channels": 1,
22
+ "channels": 128,
23
+ "c_mults": [1, 2, 4, 8],
24
+ "strides": [2, 4, 6, 10],
25
+ "latent_dim": 128,
26
+ "use_snake": true,
27
+ "final_tanh": false
28
+ }
29
+ },
30
+ "bottleneck": {
31
+ "type": "vae"
32
+ },
33
+ "latent_dim": 128,
34
+ "downsampling_ratio": 480,
35
+ "io_channels": 1
36
+ },
37
+ "training": {
38
+ "learning_rate": 1.5e-4,
39
+ "warmup_steps": 0,
40
+ "use_ema": false,
41
+ "optimizer_configs": {
42
+ "autoencoder": {
43
+ "optimizer": {
44
+ "type": "AdamW",
45
+ "config": {
46
+ "betas": [0.8, 0.99],
47
+ "lr": 1.5e-4,
48
+ "weight_decay": 1e-3
49
+ }
50
+ },
51
+ "scheduler": {
52
+ "type": "InverseLR",
53
+ "config": {
54
+ "inv_gamma": 200000,
55
+ "power": 0.5,
56
+ "warmup": 0.999
57
+ }
58
+ }
59
+ },
60
+ "discriminator": {
61
+ "optimizer": {
62
+ "type": "AdamW",
63
+ "config": {
64
+ "betas": [0.8, 0.99],
65
+ "lr": 3e-4,
66
+ "weight_decay": 1e-3
67
+ }
68
+ },
69
+ "scheduler": {
70
+ "type": "InverseLR",
71
+ "config": {
72
+ "inv_gamma": 200000,
73
+ "power": 0.5,
74
+ "warmup": 0.999
75
+ }
76
+ }
77
+ }
78
+ },
79
+ "loss_configs": {
80
+ "discriminator": {
81
+ "type": "encodec",
82
+ "config": {
83
+ "filters": 64,
84
+ "n_ffts": [1280, 640, 320, 160, 80],
85
+ "hop_lengths": [320, 160, 80, 40, 20],
86
+ "win_lengths": [1280, 640, 320, 160, 80]
87
+ },
88
+ "weights": {
89
+ "adversarial": 0.1,
90
+ "feature_matching": 5.0
91
+ }
92
+ },
93
+ "spectral": {
94
+ "type": "mrstft",
95
+ "config": {
96
+ "fft_sizes": [1280, 640, 320, 160, 80, 40, 20],
97
+ "hop_sizes": [320, 160, 80, 40, 20, 10, 5],
98
+ "win_lengths": [1280, 640, 320, 160, 80, 40, 20],
99
+ "perceptual_weighting": true
100
+ },
101
+ "weights": {
102
+ "mrstft": 1.0
103
+ }
104
+ },
105
+ "time": {
106
+ "type": "l1",
107
+ "weights": {
108
+ "l1": 0.0
109
+ }
110
+ },
111
+ "bottleneck": {
112
+ "type": "kl",
113
+ "weights": {
114
+ "kl": 1e-4
115
+ }
116
+ }
117
+ },
118
+ "demo": {
119
+ "demo_every": 10000
120
+ }
121
+ }
122
+ }
vae_modules/.DS_Store ADDED
Binary file (6.15 kB). View file
 
vae_modules/autoencoder_wrapper.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .dac import DAC
4
+ from .stable_vae import load_vae
5
+
6
+
7
+ class Autoencoder(nn.Module):
8
+ def __init__(self, ckpt_path, model_type='stable_vae', quantization_first=True):
9
+ super(Autoencoder, self).__init__()
10
+ self.model_type = model_type
11
+ if self.model_type == 'dac':
12
+ model = DAC.load(ckpt_path)
13
+ elif self.model_type == 'stable_vae':
14
+ model = load_vae(ckpt_path)
15
+ else:
16
+ raise NotImplementedError(f"Model type not implemented: {self.model_type}")
17
+ self.ae = model.eval()
18
+ self.quantization_first = quantization_first
19
+ print(f'Autoencoder quantization first mode: {quantization_first}')
20
+
21
+ @torch.no_grad()
22
+ def forward(self, audio=None, embedding=None):
23
+ if self.model_type == 'dac':
24
+ return self.process_dac(audio, embedding)
25
+ elif self.model_type == 'encodec':
26
+ return self.process_encodec(audio, embedding)
27
+ elif self.model_type == 'stable_vae':
28
+ return self.process_stable_vae(audio, embedding)
29
+ else:
30
+ raise NotImplementedError(f"Model type not implemented: {self.model_type}")
31
+
32
+ def process_dac(self, audio=None, embedding=None):
33
+ if audio is not None:
34
+ z = self.ae.encoder(audio)
35
+ if self.quantization_first:
36
+ z, *_ = self.ae.quantizer(z, None)
37
+ return z
38
+ elif embedding is not None:
39
+ z = embedding
40
+ if self.quantization_first:
41
+ audio = self.ae.decoder(z)
42
+ else:
43
+ z, *_ = self.ae.quantizer(z, None)
44
+ audio = self.ae.decoder(z)
45
+ return audio
46
+ else:
47
+ raise ValueError("Either audio or embedding must be provided.")
48
+
49
+ def process_encodec(self, audio=None, embedding=None):
50
+ if audio is not None:
51
+ z = self.ae.encoder(audio)
52
+ if self.quantization_first:
53
+ code = self.ae.quantizer.encode(z)
54
+ z = self.ae.quantizer.decode(code)
55
+ return z
56
+ elif embedding is not None:
57
+ z = embedding
58
+ if self.quantization_first:
59
+ audio = self.ae.decoder(z)
60
+ else:
61
+ code = self.ae.quantizer.encode(z)
62
+ z = self.ae.quantizer.decode(code)
63
+ audio = self.ae.decoder(z)
64
+ return audio
65
+ else:
66
+ raise ValueError("Either audio or embedding must be provided.")
67
+
68
+ def process_stable_vae(self, audio=None, embedding=None):
69
+ if audio is not None:
70
+ z = self.ae.encoder(audio)
71
+ if self.quantization_first:
72
+ z = self.ae.bottleneck.encode(z)
73
+ return z
74
+ if embedding is not None:
75
+ z = embedding
76
+ if self.quantization_first:
77
+ audio = self.ae.decoder(z)
78
+ else:
79
+ z = self.ae.bottleneck.encode(z)
80
+ audio = self.ae.decoder(z)
81
+ return audio
82
+ else:
83
+ raise ValueError("Either audio or embedding must be provided.")
vae_modules/clap_wrapper.py ADDED
File without changes
vae_modules/dac/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "1.0.0"
2
+
3
+ # preserved here for legacy reasons
4
+ __model_version__ = "latest"
5
+
6
+ import audiotools
7
+
8
+ audiotools.ml.BaseModel.INTERN += ["dac.**"]
9
+ audiotools.ml.BaseModel.EXTERN += ["einops"]
10
+
11
+
12
+ from . import nn
13
+ from . import model
14
+ from . import utils
15
+ from .model import DAC
16
+ from .model import DACFile
vae_modules/dac/__main__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import argbind
4
+
5
+ from dac.utils import download
6
+ from dac.utils.decode import decode
7
+ from dac.utils.encode import encode
8
+
9
+ STAGES = ["encode", "decode", "download"]
10
+
11
+
12
+ def run(stage: str):
13
+ """Run stages.
14
+
15
+ Parameters
16
+ ----------
17
+ stage : str
18
+ Stage to run
19
+ """
20
+ if stage not in STAGES:
21
+ raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22
+ stage_fn = globals()[stage]
23
+
24
+ if stage == "download":
25
+ stage_fn()
26
+ return
27
+
28
+ stage_fn()
29
+
30
+
31
+ if __name__ == "__main__":
32
+ group = sys.argv.pop(1)
33
+ args = argbind.parse_args(group=group)
34
+
35
+ with argbind.scope(args):
36
+ run(group)
vae_modules/dac/compare/__init__.py ADDED
File without changes
vae_modules/dac/compare/encodec.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from audiotools import AudioSignal
3
+ from audiotools.ml import BaseModel
4
+ from encodec import EncodecModel
5
+
6
+
7
+ class Encodec(BaseModel):
8
+ def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
9
+ super().__init__()
10
+
11
+ if sample_rate == 24000:
12
+ self.model = EncodecModel.encodec_model_24khz()
13
+ else:
14
+ self.model = EncodecModel.encodec_model_48khz()
15
+ self.model.set_target_bandwidth(bandwidth)
16
+ self.sample_rate = 44100
17
+
18
+ def forward(
19
+ self,
20
+ audio_data: torch.Tensor,
21
+ sample_rate: int = 44100,
22
+ n_quantizers: int = None,
23
+ ):
24
+ signal = AudioSignal(audio_data, sample_rate)
25
+ signal.resample(self.model.sample_rate)
26
+ recons = self.model(signal.audio_data)
27
+ recons = AudioSignal(recons, self.model.sample_rate)
28
+ recons.resample(sample_rate)
29
+ return {"audio": recons.audio_data}
30
+
31
+
32
+ if __name__ == "__main__":
33
+ import numpy as np
34
+ from functools import partial
35
+
36
+ model = Encodec()
37
+
38
+ for n, m in model.named_modules():
39
+ o = m.extra_repr()
40
+ p = sum([np.prod(p.size()) for p in m.parameters()])
41
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
42
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
43
+ print(model)
44
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
45
+
46
+ length = 88200 * 2
47
+ x = torch.randn(1, 1, length).to(model.device)
48
+ x.requires_grad_(True)
49
+ x.retain_grad()
50
+
51
+ # Make a forward pass
52
+ out = model(x)["audio"]
53
+
54
+ print(x.shape, out.shape)
vae_modules/dac/model/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base import CodecMixin
2
+ from .base import DACFile
3
+ from .dac import DAC
4
+ from .discriminator import Discriminator
vae_modules/dac/model/base.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(
52
+ f"Given file {path} can't be loaded with this version of descript-audio-codec."
53
+ )
54
+ return cls(codes=codes, **artifacts["metadata"])
55
+
56
+
57
+ class CodecMixin:
58
+ @property
59
+ def padding(self):
60
+ if not hasattr(self, "_padding"):
61
+ self._padding = True
62
+ return self._padding
63
+
64
+ @padding.setter
65
+ def padding(self, value):
66
+ assert isinstance(value, bool)
67
+
68
+ layers = [
69
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70
+ ]
71
+
72
+ for layer in layers:
73
+ if value:
74
+ if hasattr(layer, "original_padding"):
75
+ layer.padding = layer.original_padding
76
+ else:
77
+ layer.original_padding = layer.padding
78
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
79
+
80
+ self._padding = value
81
+
82
+ def get_delay(self):
83
+ # Any number works here, delay is invariant to input length
84
+ l_out = self.get_output_length(0)
85
+ L = l_out
86
+
87
+ layers = []
88
+ for layer in self.modules():
89
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90
+ layers.append(layer)
91
+
92
+ for layer in reversed(layers):
93
+ d = layer.dilation[0]
94
+ k = layer.kernel_size[0]
95
+ s = layer.stride[0]
96
+
97
+ if isinstance(layer, nn.ConvTranspose1d):
98
+ L = ((L - d * (k - 1) - 1) / s) + 1
99
+ elif isinstance(layer, nn.Conv1d):
100
+ L = (L - 1) * s + d * (k - 1) + 1
101
+
102
+ L = math.ceil(L)
103
+
104
+ l_in = L
105
+
106
+ return (l_in - l_out) // 2
107
+
108
+ def get_output_length(self, input_length):
109
+ L = input_length
110
+ # Calculate output length
111
+ for layer in self.modules():
112
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113
+ d = layer.dilation[0]
114
+ k = layer.kernel_size[0]
115
+ s = layer.stride[0]
116
+
117
+ if isinstance(layer, nn.Conv1d):
118
+ L = ((L - d * (k - 1) - 1) / s) + 1
119
+ elif isinstance(layer, nn.ConvTranspose1d):
120
+ L = (L - 1) * s + d * (k - 1) + 1
121
+
122
+ L = math.floor(L)
123
+ return L
124
+
125
+ @torch.no_grad()
126
+ def compress(
127
+ self,
128
+ audio_path_or_signal: Union[str, Path, AudioSignal],
129
+ win_duration: float = 1.0,
130
+ verbose: bool = False,
131
+ normalize_db: float = -16,
132
+ n_quantizers: int = None,
133
+ ) -> DACFile:
134
+ """Processes an audio signal from a file or AudioSignal object into
135
+ discrete codes. This function processes the signal in short windows,
136
+ using constant GPU memory.
137
+
138
+ Parameters
139
+ ----------
140
+ audio_path_or_signal : Union[str, Path, AudioSignal]
141
+ audio signal to reconstruct
142
+ win_duration : float, optional
143
+ window duration in seconds, by default 5.0
144
+ verbose : bool, optional
145
+ by default False
146
+ normalize_db : float, optional
147
+ normalize db, by default -16
148
+
149
+ Returns
150
+ -------
151
+ DACFile
152
+ Object containing compressed codes and metadata
153
+ required for decompression
154
+ """
155
+ audio_signal = audio_path_or_signal
156
+ if isinstance(audio_signal, (str, Path)):
157
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158
+
159
+ self.eval()
160
+ original_padding = self.padding
161
+ original_device = audio_signal.device
162
+
163
+ audio_signal = audio_signal.clone()
164
+ original_sr = audio_signal.sample_rate
165
+
166
+ resample_fn = audio_signal.resample
167
+ loudness_fn = audio_signal.loudness
168
+
169
+ # If audio is > 10 minutes long, use the ffmpeg versions
170
+ if audio_signal.signal_duration >= 10 * 60 * 60:
171
+ resample_fn = audio_signal.ffmpeg_resample
172
+ loudness_fn = audio_signal.ffmpeg_loudness
173
+
174
+ original_length = audio_signal.signal_length
175
+ resample_fn(self.sample_rate)
176
+ input_db = loudness_fn()
177
+
178
+ if normalize_db is not None:
179
+ audio_signal.normalize(normalize_db)
180
+ audio_signal.ensure_max_of_audio()
181
+
182
+ nb, nac, nt = audio_signal.audio_data.shape
183
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
184
+ win_duration = (
185
+ audio_signal.signal_duration if win_duration is None else win_duration
186
+ )
187
+
188
+ if audio_signal.signal_duration <= win_duration:
189
+ # Unchunked compression (used if signal length < win duration)
190
+ self.padding = True
191
+ n_samples = nt
192
+ hop = nt
193
+ else:
194
+ # Chunked inference
195
+ self.padding = False
196
+ # Zero-pad signal on either side by the delay
197
+ audio_signal.zero_pad(self.delay, self.delay)
198
+ n_samples = int(win_duration * self.sample_rate)
199
+ # Round n_samples to nearest hop length multiple
200
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
201
+ hop = self.get_output_length(n_samples)
202
+
203
+ codes = []
204
+ range_fn = range if not verbose else tqdm.trange
205
+
206
+ for i in range_fn(0, nt, hop):
207
+ x = audio_signal[..., i : i + n_samples]
208
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
209
+
210
+ audio_data = x.audio_data.to(self.device)
211
+ audio_data = self.preprocess(audio_data, self.sample_rate)
212
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
213
+ codes.append(c.to(original_device))
214
+ chunk_length = c.shape[-1]
215
+
216
+ codes = torch.cat(codes, dim=-1)
217
+
218
+ dac_file = DACFile(
219
+ codes=codes,
220
+ chunk_length=chunk_length,
221
+ original_length=original_length,
222
+ input_db=input_db,
223
+ channels=nac,
224
+ sample_rate=original_sr,
225
+ padding=self.padding,
226
+ dac_version=SUPPORTED_VERSIONS[-1],
227
+ )
228
+
229
+ if n_quantizers is not None:
230
+ codes = codes[:, :n_quantizers, :]
231
+
232
+ self.padding = original_padding
233
+ return dac_file
234
+
235
+ @torch.no_grad()
236
+ def decompress(
237
+ self,
238
+ obj: Union[str, Path, DACFile],
239
+ verbose: bool = False,
240
+ ) -> AudioSignal:
241
+ """Reconstruct audio from a given .dac file
242
+
243
+ Parameters
244
+ ----------
245
+ obj : Union[str, Path, DACFile]
246
+ .dac file location or corresponding DACFile object.
247
+ verbose : bool, optional
248
+ Prints progress if True, by default False
249
+
250
+ Returns
251
+ -------
252
+ AudioSignal
253
+ Object with the reconstructed audio
254
+ """
255
+ self.eval()
256
+ if isinstance(obj, (str, Path)):
257
+ obj = DACFile.load(obj)
258
+
259
+ original_padding = self.padding
260
+ self.padding = obj.padding
261
+
262
+ range_fn = range if not verbose else tqdm.trange
263
+ codes = obj.codes
264
+ original_device = codes.device
265
+ chunk_length = obj.chunk_length
266
+ recons = []
267
+
268
+ for i in range_fn(0, codes.shape[-1], chunk_length):
269
+ c = codes[..., i : i + chunk_length].to(self.device)
270
+ z = self.quantizer.from_codes(c)[0]
271
+ r = self.decode(z)
272
+ recons.append(r.to(original_device))
273
+
274
+ recons = torch.cat(recons, dim=-1)
275
+ recons = AudioSignal(recons, self.sample_rate)
276
+
277
+ resample_fn = recons.resample
278
+ loudness_fn = recons.loudness
279
+
280
+ # If audio is > 10 minutes long, use the ffmpeg versions
281
+ if recons.signal_duration >= 10 * 60 * 60:
282
+ resample_fn = recons.ffmpeg_resample
283
+ loudness_fn = recons.ffmpeg_loudness
284
+
285
+ recons.normalize(obj.input_db)
286
+ resample_fn(obj.sample_rate)
287
+ recons = recons[..., : obj.original_length]
288
+ loudness_fn()
289
+ recons.audio_data = recons.audio_data.reshape(
290
+ -1, obj.channels, obj.original_length
291
+ )
292
+
293
+ self.padding = original_padding
294
+ return recons
vae_modules/dac/model/dac.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from ..nn.layers import Snake1d
13
+ from ..nn.layers import WNConv1d
14
+ from ..nn.layers import WNConvTranspose1d
15
+ from ..nn.quantize import ResidualVectorQuantize
16
+
17
+
18
+ def init_weights(m):
19
+ if isinstance(m, nn.Conv1d):
20
+ nn.init.trunc_normal_(m.weight, std=0.02)
21
+ nn.init.constant_(m.bias, 0)
22
+
23
+
24
+ class ResidualUnit(nn.Module):
25
+ def __init__(self, dim: int = 16, dilation: int = 1):
26
+ super().__init__()
27
+ pad = ((7 - 1) * dilation) // 2
28
+ self.block = nn.Sequential(
29
+ Snake1d(dim),
30
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
31
+ Snake1d(dim),
32
+ WNConv1d(dim, dim, kernel_size=1),
33
+ )
34
+
35
+ def forward(self, x):
36
+ y = self.block(x)
37
+ pad = (x.shape[-1] - y.shape[-1]) // 2
38
+ if pad > 0:
39
+ x = x[..., pad:-pad]
40
+ return x + y
41
+
42
+
43
+ class EncoderBlock(nn.Module):
44
+ def __init__(self, dim: int = 16, stride: int = 1):
45
+ super().__init__()
46
+ self.block = nn.Sequential(
47
+ ResidualUnit(dim // 2, dilation=1),
48
+ ResidualUnit(dim // 2, dilation=3),
49
+ ResidualUnit(dim // 2, dilation=9),
50
+ Snake1d(dim // 2),
51
+ WNConv1d(
52
+ dim // 2,
53
+ dim,
54
+ kernel_size=2 * stride,
55
+ stride=stride,
56
+ padding=math.ceil(stride / 2),
57
+ ),
58
+ )
59
+
60
+ def forward(self, x):
61
+ return self.block(x)
62
+
63
+
64
+ class Encoder(nn.Module):
65
+ def __init__(
66
+ self,
67
+ d_model: int = 64,
68
+ strides: list = [2, 4, 8, 8],
69
+ d_latent: int = 64,
70
+ ):
71
+ super().__init__()
72
+ # Create first convolution
73
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
74
+
75
+ # Create EncoderBlocks that double channels as they downsample by `stride`
76
+ for stride in strides:
77
+ d_model *= 2
78
+ self.block += [EncoderBlock(d_model, stride=stride)]
79
+
80
+ # Create last convolution
81
+ self.block += [
82
+ Snake1d(d_model),
83
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
84
+ ]
85
+
86
+ # Wrap black into nn.Sequential
87
+ self.block = nn.Sequential(*self.block)
88
+ self.enc_dim = d_model
89
+
90
+ def forward(self, x):
91
+ return self.block(x)
92
+
93
+
94
+ class DecoderBlock(nn.Module):
95
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
96
+ super().__init__()
97
+ self.block = nn.Sequential(
98
+ Snake1d(input_dim),
99
+ WNConvTranspose1d(
100
+ input_dim,
101
+ output_dim,
102
+ kernel_size=2 * stride,
103
+ stride=stride,
104
+ padding=math.ceil(stride / 2),
105
+ ),
106
+ ResidualUnit(output_dim, dilation=1),
107
+ ResidualUnit(output_dim, dilation=3),
108
+ ResidualUnit(output_dim, dilation=9),
109
+ )
110
+
111
+ def forward(self, x):
112
+ return self.block(x)
113
+
114
+
115
+ class Decoder(nn.Module):
116
+ def __init__(
117
+ self,
118
+ input_channel,
119
+ channels,
120
+ rates,
121
+ d_out: int = 1,
122
+ ):
123
+ super().__init__()
124
+
125
+ # Add first conv layer
126
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
127
+
128
+ # Add upsampling + MRF blocks
129
+ for i, stride in enumerate(rates):
130
+ input_dim = channels // 2**i
131
+ output_dim = channels // 2 ** (i + 1)
132
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
133
+
134
+ # Add final conv layer
135
+ layers += [
136
+ Snake1d(output_dim),
137
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
138
+ nn.Tanh(),
139
+ ]
140
+
141
+ self.model = nn.Sequential(*layers)
142
+
143
+ def forward(self, x):
144
+ return self.model(x)
145
+
146
+
147
+ class DAC(BaseModel, CodecMixin):
148
+ def __init__(
149
+ self,
150
+ encoder_dim: int = 64,
151
+ encoder_rates: List[int] = [2, 4, 8, 8],
152
+ latent_dim: int = None,
153
+ decoder_dim: int = 1536,
154
+ decoder_rates: List[int] = [8, 8, 4, 2],
155
+ n_codebooks: int = 9,
156
+ codebook_size: int = 1024,
157
+ codebook_dim: Union[int, list] = 8,
158
+ quantizer_dropout: bool = False,
159
+ sample_rate: int = 44100,
160
+ ):
161
+ super().__init__()
162
+
163
+ self.encoder_dim = encoder_dim
164
+ self.encoder_rates = encoder_rates
165
+ self.decoder_dim = decoder_dim
166
+ self.decoder_rates = decoder_rates
167
+ self.sample_rate = sample_rate
168
+
169
+ if latent_dim is None:
170
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
171
+
172
+ self.latent_dim = latent_dim
173
+
174
+ self.hop_length = np.prod(encoder_rates)
175
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
176
+
177
+ self.n_codebooks = n_codebooks
178
+ self.codebook_size = codebook_size
179
+ self.codebook_dim = codebook_dim
180
+ self.quantizer = ResidualVectorQuantize(
181
+ input_dim=latent_dim,
182
+ n_codebooks=n_codebooks,
183
+ codebook_size=codebook_size,
184
+ codebook_dim=codebook_dim,
185
+ quantizer_dropout=quantizer_dropout,
186
+ )
187
+
188
+ self.decoder = Decoder(
189
+ latent_dim,
190
+ decoder_dim,
191
+ decoder_rates,
192
+ )
193
+ self.sample_rate = sample_rate
194
+ self.apply(init_weights)
195
+
196
+ self.delay = self.get_delay()
197
+
198
+ def preprocess(self, audio_data, sample_rate):
199
+ if sample_rate is None:
200
+ sample_rate = self.sample_rate
201
+ assert sample_rate == self.sample_rate
202
+
203
+ length = audio_data.shape[-1]
204
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
205
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
206
+
207
+ return audio_data
208
+
209
+ def encode(
210
+ self,
211
+ audio_data: torch.Tensor,
212
+ n_quantizers: int = None,
213
+ ):
214
+ """Encode given audio data and return quantized latent codes
215
+
216
+ Parameters
217
+ ----------
218
+ audio_data : Tensor[B x 1 x T]
219
+ Audio data to encode
220
+ n_quantizers : int, optional
221
+ Number of quantizers to use, by default None
222
+ If None, all quantizers are used.
223
+
224
+ Returns
225
+ -------
226
+ dict
227
+ A dictionary with the following keys:
228
+ "z" : Tensor[B x D x T]
229
+ Quantized continuous representation of input
230
+ "codes" : Tensor[B x N x T]
231
+ Codebook indices for each codebook
232
+ (quantized discrete representation of input)
233
+ "latents" : Tensor[B x N*D x T]
234
+ Projected latents (continuous representation of input before quantization)
235
+ "vq/commitment_loss" : Tensor[1]
236
+ Commitment loss to train encoder to predict vectors closer to codebook
237
+ entries
238
+ "vq/codebook_loss" : Tensor[1]
239
+ Codebook loss to update the codebook
240
+ "length" : int
241
+ Number of samples in input audio
242
+ """
243
+ z = self.encoder(audio_data)
244
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
245
+ z, n_quantizers
246
+ )
247
+ return z, codes, latents, commitment_loss, codebook_loss
248
+
249
+ def decode(self, z: torch.Tensor):
250
+ """Decode given latent codes and return audio data
251
+
252
+ Parameters
253
+ ----------
254
+ z : Tensor[B x D x T]
255
+ Quantized continuous representation of input
256
+ length : int, optional
257
+ Number of samples in output audio, by default None
258
+
259
+ Returns
260
+ -------
261
+ dict
262
+ A dictionary with the following keys:
263
+ "audio" : Tensor[B x 1 x length]
264
+ Decoded audio data.
265
+ """
266
+ return self.decoder(z)
267
+
268
+ def forward(
269
+ self,
270
+ audio_data: torch.Tensor,
271
+ sample_rate: int = None,
272
+ n_quantizers: int = None,
273
+ ):
274
+ """Model forward pass
275
+
276
+ Parameters
277
+ ----------
278
+ audio_data : Tensor[B x 1 x T]
279
+ Audio data to encode
280
+ sample_rate : int, optional
281
+ Sample rate of audio data in Hz, by default None
282
+ If None, defaults to `self.sample_rate`
283
+ n_quantizers : int, optional
284
+ Number of quantizers to use, by default None.
285
+ If None, all quantizers are used.
286
+
287
+ Returns
288
+ -------
289
+ dict
290
+ A dictionary with the following keys:
291
+ "z" : Tensor[B x D x T]
292
+ Quantized continuous representation of input
293
+ "codes" : Tensor[B x N x T]
294
+ Codebook indices for each codebook
295
+ (quantized discrete representation of input)
296
+ "latents" : Tensor[B x N*D x T]
297
+ Projected latents (continuous representation of input before quantization)
298
+ "vq/commitment_loss" : Tensor[1]
299
+ Commitment loss to train encoder to predict vectors closer to codebook
300
+ entries
301
+ "vq/codebook_loss" : Tensor[1]
302
+ Codebook loss to update the codebook
303
+ "length" : int
304
+ Number of samples in input audio
305
+ "audio" : Tensor[B x 1 x length]
306
+ Decoded audio data.
307
+ """
308
+ length = audio_data.shape[-1]
309
+ audio_data = self.preprocess(audio_data, sample_rate)
310
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(
311
+ audio_data, n_quantizers
312
+ )
313
+
314
+ x = self.decode(z)
315
+ return {
316
+ "audio": x[..., :length],
317
+ "z": z,
318
+ "codes": codes,
319
+ "latents": latents,
320
+ "vq/commitment_loss": commitment_loss,
321
+ "vq/codebook_loss": codebook_loss,
322
+ }
323
+
324
+
325
+ if __name__ == "__main__":
326
+ import numpy as np
327
+ from functools import partial
328
+
329
+ model = DAC().to("cpu")
330
+
331
+ for n, m in model.named_modules():
332
+ o = m.extra_repr()
333
+ p = sum([np.prod(p.size()) for p in m.parameters()])
334
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
335
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
336
+ print(model)
337
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
338
+
339
+ length = 88200 * 2
340
+ x = torch.randn(1, 1, length).to(model.device)
341
+ x.requires_grad_(True)
342
+ x.retain_grad()
343
+
344
+ # Make a forward pass
345
+ out = model(x)["audio"]
346
+ print("Input shape:", x.shape)
347
+ print("Output shape:", out.shape)
348
+
349
+ # Create gradient variable
350
+ grad = torch.zeros_like(out)
351
+ grad[:, :, grad.shape[-1] // 2] = 1
352
+
353
+ # Make a backward pass
354
+ out.backward(grad)
355
+
356
+ # Check non-zero values
357
+ gradmap = x.grad.squeeze(0)
358
+ gradmap = (gradmap != 0).sum(0) # sum across features
359
+ rf = (gradmap != 0).sum()
360
+
361
+ print(f"Receptive field: {rf.item()}")
362
+
363
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
364
+ model.decompress(model.compress(x, verbose=True), verbose=True)
vae_modules/dac/model/discriminator.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from audiotools import AudioSignal
5
+ from audiotools import ml
6
+ from audiotools import STFTParams
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
+ def WNConv1d(*args, **kwargs):
12
+ act = kwargs.pop("act", True)
13
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
14
+ if not act:
15
+ return conv
16
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
17
+
18
+
19
+ def WNConv2d(*args, **kwargs):
20
+ act = kwargs.pop("act", True)
21
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
22
+ if not act:
23
+ return conv
24
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
25
+
26
+
27
+ class MPD(nn.Module):
28
+ def __init__(self, period):
29
+ super().__init__()
30
+ self.period = period
31
+ self.convs = nn.ModuleList(
32
+ [
33
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38
+ ]
39
+ )
40
+ self.conv_post = WNConv2d(
41
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42
+ )
43
+
44
+ def pad_to_period(self, x):
45
+ t = x.shape[-1]
46
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47
+ return x
48
+
49
+ def forward(self, x):
50
+ fmap = []
51
+
52
+ x = self.pad_to_period(x)
53
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54
+
55
+ for layer in self.convs:
56
+ x = layer(x)
57
+ fmap.append(x)
58
+
59
+ x = self.conv_post(x)
60
+ fmap.append(x)
61
+
62
+ return fmap
63
+
64
+
65
+ class MSD(nn.Module):
66
+ def __init__(self, rate: int = 1, sample_rate: int = 44100):
67
+ super().__init__()
68
+ self.convs = nn.ModuleList(
69
+ [
70
+ WNConv1d(1, 16, 15, 1, padding=7),
71
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75
+ WNConv1d(1024, 1024, 5, 1, padding=2),
76
+ ]
77
+ )
78
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79
+ self.sample_rate = sample_rate
80
+ self.rate = rate
81
+
82
+ def forward(self, x):
83
+ x = AudioSignal(x, self.sample_rate)
84
+ x.resample(self.sample_rate // self.rate)
85
+ x = x.audio_data
86
+
87
+ fmap = []
88
+
89
+ for l in self.convs:
90
+ x = l(x)
91
+ fmap.append(x)
92
+ x = self.conv_post(x)
93
+ fmap.append(x)
94
+
95
+ return fmap
96
+
97
+
98
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99
+
100
+
101
+ class MRD(nn.Module):
102
+ def __init__(
103
+ self,
104
+ window_length: int,
105
+ hop_factor: float = 0.25,
106
+ sample_rate: int = 44100,
107
+ bands: list = BANDS,
108
+ ):
109
+ """Complex multi-band spectrogram discriminator.
110
+ Parameters
111
+ ----------
112
+ window_length : int
113
+ Window length of STFT.
114
+ hop_factor : float, optional
115
+ Hop factor of the STFT, defaults to ``0.25 * window_length``.
116
+ sample_rate : int, optional
117
+ Sampling rate of audio in Hz, by default 44100
118
+ bands : list, optional
119
+ Bands to run discriminator over.
120
+ """
121
+ super().__init__()
122
+
123
+ self.window_length = window_length
124
+ self.hop_factor = hop_factor
125
+ self.sample_rate = sample_rate
126
+ self.stft_params = STFTParams(
127
+ window_length=window_length,
128
+ hop_length=int(window_length * hop_factor),
129
+ match_stride=True,
130
+ )
131
+
132
+ n_fft = window_length // 2 + 1
133
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134
+ self.bands = bands
135
+
136
+ ch = 32
137
+ convs = lambda: nn.ModuleList(
138
+ [
139
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144
+ ]
145
+ )
146
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148
+
149
+ def spectrogram(self, x):
150
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151
+ x = torch.view_as_real(x.stft())
152
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153
+ # Split into bands
154
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155
+ return x_bands
156
+
157
+ def forward(self, x):
158
+ x_bands = self.spectrogram(x)
159
+ fmap = []
160
+
161
+ x = []
162
+ for band, stack in zip(x_bands, self.band_convs):
163
+ for layer in stack:
164
+ band = layer(band)
165
+ fmap.append(band)
166
+ x.append(band)
167
+
168
+ x = torch.cat(x, dim=-1)
169
+ x = self.conv_post(x)
170
+ fmap.append(x)
171
+
172
+ return fmap
173
+
174
+
175
+ class Discriminator(ml.BaseModel):
176
+ def __init__(
177
+ self,
178
+ rates: list = [],
179
+ periods: list = [2, 3, 5, 7, 11],
180
+ fft_sizes: list = [2048, 1024, 512],
181
+ sample_rate: int = 44100,
182
+ bands: list = BANDS,
183
+ ):
184
+ """Discriminator that combines multiple discriminators.
185
+
186
+ Parameters
187
+ ----------
188
+ rates : list, optional
189
+ sampling rates (in Hz) to run MSD at, by default []
190
+ If empty, MSD is not used.
191
+ periods : list, optional
192
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193
+ fft_sizes : list, optional
194
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195
+ sample_rate : int, optional
196
+ Sampling rate of audio in Hz, by default 44100
197
+ bands : list, optional
198
+ Bands to run MRD at, by default `BANDS`
199
+ """
200
+ super().__init__()
201
+ discs = []
202
+ discs += [MPD(p) for p in periods]
203
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205
+ self.discriminators = nn.ModuleList(discs)
206
+
207
+ def preprocess(self, y):
208
+ # Remove DC offset
209
+ y = y - y.mean(dim=-1, keepdims=True)
210
+ # Peak normalize the volume of input audio
211
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212
+ return y
213
+
214
+ def forward(self, x):
215
+ x = self.preprocess(x)
216
+ fmaps = [d(x) for d in self.discriminators]
217
+ return fmaps
218
+
219
+
220
+ if __name__ == "__main__":
221
+ disc = Discriminator()
222
+ x = torch.zeros(1, 1, 44100)
223
+ results = disc(x)
224
+ for i, result in enumerate(results):
225
+ print(f"disc{i}")
226
+ for i, r in enumerate(result):
227
+ print(r.shape, r.mean(), r.min(), r.max())
228
+ print()
vae_modules/dac/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import layers
2
+ from . import loss
3
+ from . import quantize
vae_modules/dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ # Scripting this brings model speed up 1.4x
18
+ @torch.jit.script
19
+ def snake(x, alpha):
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
vae_modules/dac/nn/loss.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from audiotools import AudioSignal
7
+ from audiotools import STFTParams
8
+ from torch import nn
9
+
10
+
11
+ class L1Loss(nn.L1Loss):
12
+ """L1 Loss between AudioSignals. Defaults
13
+ to comparing ``audio_data``, but any
14
+ attribute of an AudioSignal can be used.
15
+
16
+ Parameters
17
+ ----------
18
+ attribute : str, optional
19
+ Attribute of signal to compare, defaults to ``audio_data``.
20
+ weight : float, optional
21
+ Weight of this loss, defaults to 1.0.
22
+
23
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24
+ """
25
+
26
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27
+ self.attribute = attribute
28
+ self.weight = weight
29
+ super().__init__(**kwargs)
30
+
31
+ def forward(self, x: AudioSignal, y: AudioSignal):
32
+ """
33
+ Parameters
34
+ ----------
35
+ x : AudioSignal
36
+ Estimate AudioSignal
37
+ y : AudioSignal
38
+ Reference AudioSignal
39
+
40
+ Returns
41
+ -------
42
+ torch.Tensor
43
+ L1 loss between AudioSignal attributes.
44
+ """
45
+ if isinstance(x, AudioSignal):
46
+ x = getattr(x, self.attribute)
47
+ y = getattr(y, self.attribute)
48
+ return super().forward(x, y)
49
+
50
+
51
+ class SISDRLoss(nn.Module):
52
+ """
53
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54
+ of estimated and reference audio signals or aligned features.
55
+
56
+ Parameters
57
+ ----------
58
+ scaling : int, optional
59
+ Whether to use scale-invariant (True) or
60
+ signal-to-noise ratio (False), by default True
61
+ reduction : str, optional
62
+ How to reduce across the batch (either 'mean',
63
+ 'sum', or none).], by default ' mean'
64
+ zero_mean : int, optional
65
+ Zero mean the references and estimates before
66
+ computing the loss, by default True
67
+ clip_min : int, optional
68
+ The minimum possible loss value. Helps network
69
+ to not focus on making already good examples better, by default None
70
+ weight : float, optional
71
+ Weight of this loss, defaults to 1.0.
72
+
73
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ scaling: int = True,
79
+ reduction: str = "mean",
80
+ zero_mean: int = True,
81
+ clip_min: int = None,
82
+ weight: float = 1.0,
83
+ ):
84
+ self.scaling = scaling
85
+ self.reduction = reduction
86
+ self.zero_mean = zero_mean
87
+ self.clip_min = clip_min
88
+ self.weight = weight
89
+ super().__init__()
90
+
91
+ def forward(self, x: AudioSignal, y: AudioSignal):
92
+ eps = 1e-8
93
+ # nb, nc, nt
94
+ if isinstance(x, AudioSignal):
95
+ references = x.audio_data
96
+ estimates = y.audio_data
97
+ else:
98
+ references = x
99
+ estimates = y
100
+
101
+ nb = references.shape[0]
102
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104
+
105
+ # samples now on axis 1
106
+ if self.zero_mean:
107
+ mean_reference = references.mean(dim=1, keepdim=True)
108
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
109
+ else:
110
+ mean_reference = 0
111
+ mean_estimate = 0
112
+
113
+ _references = references - mean_reference
114
+ _estimates = estimates - mean_estimate
115
+
116
+ references_projection = (_references**2).sum(dim=-2) + eps
117
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118
+
119
+ scale = (
120
+ (references_on_estimates / references_projection).unsqueeze(1)
121
+ if self.scaling
122
+ else 1
123
+ )
124
+
125
+ e_true = scale * _references
126
+ e_res = _estimates - e_true
127
+
128
+ signal = (e_true**2).sum(dim=1)
129
+ noise = (e_res**2).sum(dim=1)
130
+ sdr = -10 * torch.log10(signal / noise + eps)
131
+
132
+ if self.clip_min is not None:
133
+ sdr = torch.clamp(sdr, min=self.clip_min)
134
+
135
+ if self.reduction == "mean":
136
+ sdr = sdr.mean()
137
+ elif self.reduction == "sum":
138
+ sdr = sdr.sum()
139
+ return sdr
140
+
141
+
142
+ class MultiScaleSTFTLoss(nn.Module):
143
+ """Computes the multi-scale STFT loss from [1].
144
+
145
+ Parameters
146
+ ----------
147
+ window_lengths : List[int], optional
148
+ Length of each window of each STFT, by default [2048, 512]
149
+ loss_fn : typing.Callable, optional
150
+ How to compare each loss, by default nn.L1Loss()
151
+ clamp_eps : float, optional
152
+ Clamp on the log magnitude, below, by default 1e-5
153
+ mag_weight : float, optional
154
+ Weight of raw magnitude portion of loss, by default 1.0
155
+ log_weight : float, optional
156
+ Weight of log magnitude portion of loss, by default 1.0
157
+ pow : float, optional
158
+ Power to raise magnitude to before taking log, by default 2.0
159
+ weight : float, optional
160
+ Weight of this loss, by default 1.0
161
+ match_stride : bool, optional
162
+ Whether to match the stride of convolutional layers, by default False
163
+
164
+ References
165
+ ----------
166
+
167
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168
+ "DDSP: Differentiable Digital Signal Processing."
169
+ International Conference on Learning Representations. 2019.
170
+
171
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ window_lengths: List[int] = [2048, 512],
177
+ loss_fn: typing.Callable = nn.L1Loss(),
178
+ clamp_eps: float = 1e-5,
179
+ mag_weight: float = 1.0,
180
+ log_weight: float = 1.0,
181
+ pow: float = 2.0,
182
+ weight: float = 1.0,
183
+ match_stride: bool = False,
184
+ window_type: str = None,
185
+ ):
186
+ super().__init__()
187
+ self.stft_params = [
188
+ STFTParams(
189
+ window_length=w,
190
+ hop_length=w // 4,
191
+ match_stride=match_stride,
192
+ window_type=window_type,
193
+ )
194
+ for w in window_lengths
195
+ ]
196
+ self.loss_fn = loss_fn
197
+ self.log_weight = log_weight
198
+ self.mag_weight = mag_weight
199
+ self.clamp_eps = clamp_eps
200
+ self.weight = weight
201
+ self.pow = pow
202
+
203
+ def forward(self, x: AudioSignal, y: AudioSignal):
204
+ """Computes multi-scale STFT between an estimate and a reference
205
+ signal.
206
+
207
+ Parameters
208
+ ----------
209
+ x : AudioSignal
210
+ Estimate signal
211
+ y : AudioSignal
212
+ Reference signal
213
+
214
+ Returns
215
+ -------
216
+ torch.Tensor
217
+ Multi-scale STFT loss.
218
+ """
219
+ loss = 0.0
220
+ for s in self.stft_params:
221
+ x.stft(s.window_length, s.hop_length, s.window_type)
222
+ y.stft(s.window_length, s.hop_length, s.window_type)
223
+ loss += self.log_weight * self.loss_fn(
224
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226
+ )
227
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228
+ return loss
229
+
230
+
231
+ class MelSpectrogramLoss(nn.Module):
232
+ """Compute distance between mel spectrograms. Can be used
233
+ in a multi-scale way.
234
+
235
+ Parameters
236
+ ----------
237
+ n_mels : List[int]
238
+ Number of mels per STFT, by default [150, 80],
239
+ window_lengths : List[int], optional
240
+ Length of each window of each STFT, by default [2048, 512]
241
+ loss_fn : typing.Callable, optional
242
+ How to compare each loss, by default nn.L1Loss()
243
+ clamp_eps : float, optional
244
+ Clamp on the log magnitude, below, by default 1e-5
245
+ mag_weight : float, optional
246
+ Weight of raw magnitude portion of loss, by default 1.0
247
+ log_weight : float, optional
248
+ Weight of log magnitude portion of loss, by default 1.0
249
+ pow : float, optional
250
+ Power to raise magnitude to before taking log, by default 2.0
251
+ weight : float, optional
252
+ Weight of this loss, by default 1.0
253
+ match_stride : bool, optional
254
+ Whether to match the stride of convolutional layers, by default False
255
+
256
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ n_mels: List[int] = [150, 80],
262
+ window_lengths: List[int] = [2048, 512],
263
+ loss_fn: typing.Callable = nn.L1Loss(),
264
+ clamp_eps: float = 1e-5,
265
+ mag_weight: float = 1.0,
266
+ log_weight: float = 1.0,
267
+ pow: float = 2.0,
268
+ weight: float = 1.0,
269
+ match_stride: bool = False,
270
+ mel_fmin: List[float] = [0.0, 0.0],
271
+ mel_fmax: List[float] = [None, None],
272
+ window_type: str = None,
273
+ ):
274
+ super().__init__()
275
+ self.stft_params = [
276
+ STFTParams(
277
+ window_length=w,
278
+ hop_length=w // 4,
279
+ match_stride=match_stride,
280
+ window_type=window_type,
281
+ )
282
+ for w in window_lengths
283
+ ]
284
+ self.n_mels = n_mels
285
+ self.loss_fn = loss_fn
286
+ self.clamp_eps = clamp_eps
287
+ self.log_weight = log_weight
288
+ self.mag_weight = mag_weight
289
+ self.weight = weight
290
+ self.mel_fmin = mel_fmin
291
+ self.mel_fmax = mel_fmax
292
+ self.pow = pow
293
+
294
+ def forward(self, x: AudioSignal, y: AudioSignal):
295
+ """Computes mel loss between an estimate and a reference
296
+ signal.
297
+
298
+ Parameters
299
+ ----------
300
+ x : AudioSignal
301
+ Estimate signal
302
+ y : AudioSignal
303
+ Reference signal
304
+
305
+ Returns
306
+ -------
307
+ torch.Tensor
308
+ Mel loss.
309
+ """
310
+ loss = 0.0
311
+ for n_mels, fmin, fmax, s in zip(
312
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313
+ ):
314
+ kwargs = {
315
+ "window_length": s.window_length,
316
+ "hop_length": s.hop_length,
317
+ "window_type": s.window_type,
318
+ }
319
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321
+
322
+ loss += self.log_weight * self.loss_fn(
323
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325
+ )
326
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327
+ return loss
328
+
329
+
330
+ class GANLoss(nn.Module):
331
+ """
332
+ Computes a discriminator loss, given a discriminator on
333
+ generated waveforms/spectrograms compared to ground truth
334
+ waveforms/spectrograms. Computes the loss for both the
335
+ discriminator and the generator in separate functions.
336
+ """
337
+
338
+ def __init__(self, discriminator):
339
+ super().__init__()
340
+ self.discriminator = discriminator
341
+
342
+ def forward(self, fake, real):
343
+ d_fake = self.discriminator(fake.audio_data)
344
+ d_real = self.discriminator(real.audio_data)
345
+ return d_fake, d_real
346
+
347
+ def discriminator_loss(self, fake, real):
348
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
349
+
350
+ loss_d = 0
351
+ for x_fake, x_real in zip(d_fake, d_real):
352
+ loss_d += torch.mean(x_fake[-1] ** 2)
353
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
354
+ return loss_d
355
+
356
+ def generator_loss(self, fake, real):
357
+ d_fake, d_real = self.forward(fake, real)
358
+
359
+ loss_g = 0
360
+ for x_fake in d_fake:
361
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362
+
363
+ loss_feature = 0
364
+
365
+ for i in range(len(d_fake)):
366
+ for j in range(len(d_fake[i]) - 1):
367
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368
+ return loss_g, loss_feature
vae_modules/dac/nn/quantize.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from .layers import WNConv1d
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+
34
+ def forward(self, z):
35
+ """Quantized the input tensor using a fixed codebook and returns
36
+ the corresponding codebook vectors
37
+
38
+ Parameters
39
+ ----------
40
+ z : Tensor[B x D x T]
41
+
42
+ Returns
43
+ -------
44
+ Tensor[B x D x T]
45
+ Quantized continuous representation of input
46
+ Tensor[1]
47
+ Commitment loss to train encoder to predict vectors closer to codebook
48
+ entries
49
+ Tensor[1]
50
+ Codebook loss to update the codebook
51
+ Tensor[B x T]
52
+ Codebook indices (quantized discrete representation of input)
53
+ Tensor[B x D x T]
54
+ Projected latents (continuous representation of input before quantization)
55
+ """
56
+
57
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
+ z_e = self.in_proj(z) # z_e : (B x D x T)
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
+
64
+ z_q = (
65
+ z_e + (z_q - z_e).detach()
66
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
67
+
68
+ z_q = self.out_proj(z_q)
69
+
70
+ return z_q, commitment_loss, codebook_loss, indices, z_e
71
+
72
+ def embed_code(self, embed_id):
73
+ return F.embedding(embed_id, self.codebook.weight)
74
+
75
+ def decode_code(self, embed_id):
76
+ return self.embed_code(embed_id).transpose(1, 2)
77
+
78
+ def decode_latents(self, latents):
79
+ encodings = rearrange(latents, "b d t -> (b t) d")
80
+ codebook = self.codebook.weight # codebook: (N x D)
81
+
82
+ # L2 normalize encodings and codebook (ViT-VQGAN)
83
+ encodings = F.normalize(encodings)
84
+ codebook = F.normalize(codebook)
85
+
86
+ # Compute euclidean distance with codebook
87
+ dist = (
88
+ encodings.pow(2).sum(1, keepdim=True)
89
+ - 2 * encodings @ codebook.t()
90
+ + codebook.pow(2).sum(1, keepdim=True).t()
91
+ )
92
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
93
+ z_q = self.decode_code(indices)
94
+ return z_q, indices
95
+
96
+
97
+ class ResidualVectorQuantize(nn.Module):
98
+ """
99
+ Introduced in SoundStream: An end2end neural audio codec
100
+ https://arxiv.org/abs/2107.03312
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ input_dim: int = 512,
106
+ n_codebooks: int = 9,
107
+ codebook_size: int = 1024,
108
+ codebook_dim: Union[int, list] = 8,
109
+ quantizer_dropout: float = 0.0,
110
+ ):
111
+ super().__init__()
112
+ if isinstance(codebook_dim, int):
113
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
114
+
115
+ self.n_codebooks = n_codebooks
116
+ self.codebook_dim = codebook_dim
117
+ self.codebook_size = codebook_size
118
+
119
+ self.quantizers = nn.ModuleList(
120
+ [
121
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
122
+ for i in range(n_codebooks)
123
+ ]
124
+ )
125
+ self.quantizer_dropout = quantizer_dropout
126
+
127
+ def forward(self, z, n_quantizers: int = None):
128
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
129
+ the corresponding codebook vectors
130
+ Parameters
131
+ ----------
132
+ z : Tensor[B x D x T]
133
+ n_quantizers : int, optional
134
+ No. of quantizers to use
135
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
136
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
137
+ when in training mode, and a random number of quantizers is used.
138
+ Returns
139
+ -------
140
+ dict
141
+ A dictionary with the following keys:
142
+
143
+ "z" : Tensor[B x D x T]
144
+ Quantized continuous representation of input
145
+ "codes" : Tensor[B x N x T]
146
+ Codebook indices for each codebook
147
+ (quantized discrete representation of input)
148
+ "latents" : Tensor[B x N*D x T]
149
+ Projected latents (continuous representation of input before quantization)
150
+ "vq/commitment_loss" : Tensor[1]
151
+ Commitment loss to train encoder to predict vectors closer to codebook
152
+ entries
153
+ "vq/codebook_loss" : Tensor[1]
154
+ Codebook loss to update the codebook
155
+ """
156
+ z_q = 0
157
+ residual = z
158
+ commitment_loss = 0
159
+ codebook_loss = 0
160
+
161
+ codebook_indices = []
162
+ latents = []
163
+
164
+ if n_quantizers is None:
165
+ n_quantizers = self.n_codebooks
166
+ if self.training:
167
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
168
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
169
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
170
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
171
+ n_quantizers = n_quantizers.to(z.device)
172
+
173
+ for i, quantizer in enumerate(self.quantizers):
174
+ if self.training is False and i >= n_quantizers:
175
+ break
176
+
177
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
178
+ residual
179
+ )
180
+
181
+ # Create mask to apply quantizer dropout
182
+ mask = (
183
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
184
+ )
185
+ z_q = z_q + z_q_i * mask[:, None, None]
186
+ residual = residual - z_q_i
187
+
188
+ # Sum losses
189
+ commitment_loss += (commitment_loss_i * mask).mean()
190
+ codebook_loss += (codebook_loss_i * mask).mean()
191
+
192
+ codebook_indices.append(indices_i)
193
+ latents.append(z_e_i)
194
+
195
+ codes = torch.stack(codebook_indices, dim=1)
196
+ latents = torch.cat(latents, dim=1)
197
+
198
+ return z_q, codes, latents, commitment_loss, codebook_loss
199
+
200
+ def from_codes(self, codes: torch.Tensor):
201
+ """Given the quantized codes, reconstruct the continuous representation
202
+ Parameters
203
+ ----------
204
+ codes : Tensor[B x N x T]
205
+ Quantized discrete representation of input
206
+ Returns
207
+ -------
208
+ Tensor[B x D x T]
209
+ Quantized continuous representation of input
210
+ """
211
+ z_q = 0.0
212
+ z_p = []
213
+ n_codebooks = codes.shape[1]
214
+ for i in range(n_codebooks):
215
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
216
+ z_p.append(z_p_i)
217
+
218
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
219
+ z_q = z_q + z_q_i
220
+ return z_q, torch.cat(z_p, dim=1), codes
221
+
222
+ def from_latents(self, latents: torch.Tensor):
223
+ """Given the unquantized latents, reconstruct the
224
+ continuous representation after quantization.
225
+
226
+ Parameters
227
+ ----------
228
+ latents : Tensor[B x N x T]
229
+ Continuous representation of input after projection
230
+
231
+ Returns
232
+ -------
233
+ Tensor[B x D x T]
234
+ Quantized representation of full-projected space
235
+ Tensor[B x D x T]
236
+ Quantized representation of latent space
237
+ """
238
+ z_q = 0
239
+ z_p = []
240
+ codes = []
241
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
242
+
243
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
244
+ 0
245
+ ]
246
+ for i in range(n_codebooks):
247
+ j, k = dims[i], dims[i + 1]
248
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
249
+ z_p.append(z_p_i)
250
+ codes.append(codes_i)
251
+
252
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
253
+ z_q = z_q + z_q_i
254
+
255
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
256
+
257
+
258
+ if __name__ == "__main__":
259
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
260
+ x = torch.randn(16, 512, 80)
261
+ y = rvq(x)
262
+ print(y["latents"].shape)
vae_modules/dac/utils/__init__.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import argbind
4
+ from audiotools import ml
5
+
6
+ from ..model import DAC
7
+
8
+ Accelerator = ml.Accelerator
9
+
10
+ __MODEL_LATEST_TAGS__ = {
11
+ ("44khz", "8kbps"): "0.0.1",
12
+ ("24khz", "8kbps"): "0.0.4",
13
+ ("16khz", "8kbps"): "0.0.5",
14
+ ("44khz", "16kbps"): "1.0.0",
15
+ }
16
+
17
+ __MODEL_URLS__ = {
18
+ (
19
+ "44khz",
20
+ "0.0.1",
21
+ "8kbps",
22
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
23
+ (
24
+ "24khz",
25
+ "0.0.4",
26
+ "8kbps",
27
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
28
+ (
29
+ "16khz",
30
+ "0.0.5",
31
+ "8kbps",
32
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
33
+ (
34
+ "44khz",
35
+ "1.0.0",
36
+ "16kbps",
37
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
38
+ }
39
+
40
+
41
+ @argbind.bind(group="download", positional=True, without_prefix=True)
42
+ def download(
43
+ model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
44
+ ):
45
+ """
46
+ Function that downloads the weights file from URL if a local cache is not found.
47
+
48
+ Parameters
49
+ ----------
50
+ model_type : str
51
+ The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
52
+ model_bitrate: str
53
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
54
+ Only 44khz model supports 16kbps.
55
+ tag : str
56
+ The tag of the model to download. Defaults to "latest".
57
+
58
+ Returns
59
+ -------
60
+ Path
61
+ Directory path required to load model via audiotools.
62
+ """
63
+ model_type = model_type.lower()
64
+ tag = tag.lower()
65
+
66
+ assert model_type in [
67
+ "44khz",
68
+ "24khz",
69
+ "16khz",
70
+ ], "model_type must be one of '44khz', '24khz', or '16khz'"
71
+
72
+ assert model_bitrate in [
73
+ "8kbps",
74
+ "16kbps",
75
+ ], "model_bitrate must be one of '8kbps', or '16kbps'"
76
+
77
+ if tag == "latest":
78
+ tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
79
+
80
+ download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
81
+
82
+ if download_link is None:
83
+ raise ValueError(
84
+ f"Could not find model with tag {tag} and model type {model_type}"
85
+ )
86
+
87
+ local_path = (
88
+ Path.home()
89
+ / ".cache"
90
+ / "descript"
91
+ / "dac"
92
+ / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
93
+ )
94
+ if not local_path.exists():
95
+ local_path.parent.mkdir(parents=True, exist_ok=True)
96
+
97
+ # Download the model
98
+ import requests
99
+
100
+ response = requests.get(download_link)
101
+
102
+ if response.status_code != 200:
103
+ raise ValueError(
104
+ f"Could not download model. Received response code {response.status_code}"
105
+ )
106
+ local_path.write_bytes(response.content)
107
+
108
+ return local_path
109
+
110
+
111
+ def load_model(
112
+ model_type: str = "44khz",
113
+ model_bitrate: str = "8kbps",
114
+ tag: str = "latest",
115
+ load_path: str = None,
116
+ ):
117
+ if not load_path:
118
+ load_path = download(
119
+ model_type=model_type, model_bitrate=model_bitrate, tag=tag
120
+ )
121
+ generator = DAC.load(load_path)
122
+ return generator
vae_modules/dac/utils/decode.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from pathlib import Path
3
+
4
+ import argbind
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from tqdm import tqdm
9
+
10
+ from dac import DACFile
11
+ from dac.utils import load_model
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning)
14
+
15
+
16
+ @argbind.bind(group="decode", positional=True, without_prefix=True)
17
+ @torch.inference_mode()
18
+ @torch.no_grad()
19
+ def decode(
20
+ input: str,
21
+ output: str = "",
22
+ weights_path: str = "",
23
+ model_tag: str = "latest",
24
+ model_bitrate: str = "8kbps",
25
+ device: str = "cuda",
26
+ model_type: str = "44khz",
27
+ verbose: bool = False,
28
+ ):
29
+ """Decode audio from codes.
30
+
31
+ Parameters
32
+ ----------
33
+ input : str
34
+ Path to input directory or file
35
+ output : str, optional
36
+ Path to output directory, by default "".
37
+ If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38
+ weights_path : str, optional
39
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
40
+ model_tag and model_type.
41
+ model_tag : str, optional
42
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43
+ model_bitrate: str
44
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45
+ device : str, optional
46
+ Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
47
+ model_type : str, optional
48
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
49
+ """
50
+ generator = load_model(
51
+ model_type=model_type,
52
+ model_bitrate=model_bitrate,
53
+ tag=model_tag,
54
+ load_path=weights_path,
55
+ )
56
+ generator.to(device)
57
+ generator.eval()
58
+
59
+ # Find all .dac files in input directory
60
+ _input = Path(input)
61
+ input_files = list(_input.glob("**/*.dac"))
62
+
63
+ # If input is a .dac file, add it to the list
64
+ if _input.suffix == ".dac":
65
+ input_files.append(_input)
66
+
67
+ # Create output directory
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(input_files)), desc=f"Decoding files"):
72
+ # Load file
73
+ artifact = DACFile.load(input_files[i])
74
+
75
+ # Reconstruct audio from codes
76
+ recons = generator.decompress(artifact, verbose=verbose)
77
+
78
+ # Compute output path
79
+ relative_path = input_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = input_files[i]
84
+ output_name = relative_path.with_suffix(".wav").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ # Write to file
89
+ recons.write(output_path)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ args = argbind.parse_args()
94
+ with argbind.scope(args):
95
+ decode()
vae_modules/dac/utils/encode.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from pathlib import Path
4
+
5
+ import argbind
6
+ import numpy as np
7
+ import torch
8
+ from audiotools import AudioSignal
9
+ from audiotools.core import util
10
+ from tqdm import tqdm
11
+
12
+ from dac.utils import load_model
13
+
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+
16
+
17
+ @argbind.bind(group="encode", positional=True, without_prefix=True)
18
+ @torch.inference_mode()
19
+ @torch.no_grad()
20
+ def encode(
21
+ input: str,
22
+ output: str = "",
23
+ weights_path: str = "",
24
+ model_tag: str = "latest",
25
+ model_bitrate: str = "8kbps",
26
+ n_quantizers: int = None,
27
+ device: str = "cuda",
28
+ model_type: str = "44khz",
29
+ win_duration: float = 5.0,
30
+ verbose: bool = False,
31
+ ):
32
+ """Encode audio files in input path to .dac format.
33
+
34
+ Parameters
35
+ ----------
36
+ input : str
37
+ Path to input audio file or directory
38
+ output : str, optional
39
+ Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
40
+ weights_path : str, optional
41
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
42
+ model_tag and model_type.
43
+ model_tag : str, optional
44
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
45
+ model_bitrate: str
46
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
47
+ n_quantizers : int, optional
48
+ Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
49
+ device : str, optional
50
+ Device to use, by default "cuda"
51
+ model_type : str, optional
52
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
53
+ """
54
+ generator = load_model(
55
+ model_type=model_type,
56
+ model_bitrate=model_bitrate,
57
+ tag=model_tag,
58
+ load_path=weights_path,
59
+ )
60
+ generator.to(device)
61
+ generator.eval()
62
+ kwargs = {"n_quantizers": n_quantizers}
63
+
64
+ # Find all audio files in input path
65
+ input = Path(input)
66
+ audio_files = util.find_audio(input)
67
+
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(audio_files)), desc="Encoding files"):
72
+ # Load file
73
+ signal = AudioSignal(audio_files[i])
74
+
75
+ # Encode audio to .dac format
76
+ artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
77
+
78
+ # Compute output path
79
+ relative_path = audio_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = audio_files[i]
84
+ output_name = relative_path.with_suffix(".dac").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ artifact.save(output_path)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ args = argbind.parse_args()
93
+ with argbind.scope(args):
94
+ encode()
vae_modules/stable_vae/__init__.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .models.autoencoders import create_autoencoder_from_config
2
+ import os
3
+ import json
4
+ import torch
5
+ from torch.nn.utils import remove_weight_norm
6
+
7
+
8
+ def remove_all_weight_norm(model):
9
+ for name, module in model.named_modules():
10
+ if hasattr(module, 'weight_g'):
11
+ remove_weight_norm(module)
12
+
13
+
14
+ def load_vae(ckpt_path, remove_weight_norm=False):
15
+ config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json')
16
+
17
+ # Load the model configuration
18
+ with open(config_file) as f:
19
+ model_config = json.load(f)
20
+
21
+ # Create the model from the configuration
22
+ model = create_autoencoder_from_config(model_config)
23
+
24
+ # Load the state dictionary from the checkpoint
25
+ model_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
26
+
27
+ # Strip the "autoencoder." prefix from the keys
28
+ model_dict = {key[len("autoencoder."):]: value for key, value in model_dict.items() if key.startswith("autoencoder.")}
29
+
30
+ # Load the state dictionary into the model
31
+ model.load_state_dict(model_dict)
32
+
33
+ # Remove weight normalization
34
+ if remove_weight_norm:
35
+ remove_all_weight_norm(model)
36
+
37
+ # Set the model to evaluation mode
38
+ model.eval()
39
+
40
+ return model
vae_modules/stable_vae/models/autoencoders.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from torchaudio import transforms as T
8
+ from alias_free_torch import Activation1d
9
+ from .nn.layers import WNConv1d, WNConvTranspose1d
10
+ from typing import Literal, Dict, Any
11
+
12
+ # from .inference.sampling import sample
13
+ from .utils import prepare_audio
14
+ from .blocks import SnakeBeta
15
+ from .bottleneck import Bottleneck, DiscreteBottleneck
16
+ from .factory import create_pretransform_from_config, create_bottleneck_from_config
17
+ from .pretransforms import Pretransform
18
+
19
+ def checkpoint(function, *args, **kwargs):
20
+ kwargs.setdefault("use_reentrant", False)
21
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
22
+
23
+ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
24
+ if activation == "elu":
25
+ act = nn.ELU()
26
+ elif activation == "snake":
27
+ act = SnakeBeta(channels)
28
+ elif activation == "none":
29
+ act = nn.Identity()
30
+ else:
31
+ raise ValueError(f"Unknown activation {activation}")
32
+
33
+ if antialias:
34
+ act = Activation1d(act)
35
+
36
+ return act
37
+
38
+ class ResidualUnit(nn.Module):
39
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
40
+ super().__init__()
41
+
42
+ self.dilation = dilation
43
+
44
+ padding = (dilation * (7-1)) // 2
45
+
46
+ self.layers = nn.Sequential(
47
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
48
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
49
+ kernel_size=7, dilation=dilation, padding=padding),
50
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
51
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
52
+ kernel_size=1)
53
+ )
54
+
55
+ def forward(self, x):
56
+ res = x
57
+
58
+ #x = checkpoint(self.layers, x)
59
+ x = self.layers(x)
60
+
61
+ return x + res
62
+
63
+ class EncoderBlock(nn.Module):
64
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
65
+ super().__init__()
66
+
67
+ self.layers = nn.Sequential(
68
+ ResidualUnit(in_channels=in_channels,
69
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
70
+ ResidualUnit(in_channels=in_channels,
71
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
72
+ ResidualUnit(in_channels=in_channels,
73
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
74
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
75
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
76
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
77
+ )
78
+
79
+ def forward(self, x):
80
+ return self.layers(x)
81
+
82
+ class DecoderBlock(nn.Module):
83
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
84
+ super().__init__()
85
+
86
+ if use_nearest_upsample:
87
+ upsample_layer = nn.Sequential(
88
+ nn.Upsample(scale_factor=stride, mode="nearest"),
89
+ WNConv1d(in_channels=in_channels,
90
+ out_channels=out_channels,
91
+ kernel_size=2*stride,
92
+ stride=1,
93
+ bias=False,
94
+ padding='same')
95
+ )
96
+ else:
97
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
98
+ out_channels=out_channels,
99
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
100
+
101
+ self.layers = nn.Sequential(
102
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
103
+ upsample_layer,
104
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
105
+ dilation=1, use_snake=use_snake),
106
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
107
+ dilation=3, use_snake=use_snake),
108
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
109
+ dilation=9, use_snake=use_snake),
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.layers(x)
114
+
115
+ class OobleckEncoder(nn.Module):
116
+ def __init__(self,
117
+ in_channels=2,
118
+ channels=128,
119
+ latent_dim=32,
120
+ c_mults = [1, 2, 4, 8],
121
+ strides = [2, 4, 8, 8],
122
+ use_snake=False,
123
+ antialias_activation=False
124
+ ):
125
+ super().__init__()
126
+
127
+ c_mults = [1] + c_mults
128
+
129
+ self.depth = len(c_mults)
130
+
131
+ layers = [
132
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
133
+ ]
134
+
135
+ for i in range(self.depth-1):
136
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
137
+
138
+ layers += [
139
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
140
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
141
+ ]
142
+
143
+ self.layers = nn.Sequential(*layers)
144
+
145
+ def forward(self, x):
146
+ return self.layers(x)
147
+
148
+
149
+ class OobleckDecoder(nn.Module):
150
+ def __init__(self,
151
+ out_channels=2,
152
+ channels=128,
153
+ latent_dim=32,
154
+ c_mults = [1, 2, 4, 8],
155
+ strides = [2, 4, 8, 8],
156
+ use_snake=False,
157
+ antialias_activation=False,
158
+ use_nearest_upsample=False,
159
+ final_tanh=True):
160
+ super().__init__()
161
+
162
+ c_mults = [1] + c_mults
163
+
164
+ self.depth = len(c_mults)
165
+
166
+ layers = [
167
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
168
+ ]
169
+
170
+ for i in range(self.depth-1, 0, -1):
171
+ layers += [DecoderBlock(
172
+ in_channels=c_mults[i]*channels,
173
+ out_channels=c_mults[i-1]*channels,
174
+ stride=strides[i-1],
175
+ use_snake=use_snake,
176
+ antialias_activation=antialias_activation,
177
+ use_nearest_upsample=use_nearest_upsample
178
+ )
179
+ ]
180
+
181
+ layers += [
182
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
183
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
184
+ nn.Tanh() if final_tanh else nn.Identity()
185
+ ]
186
+
187
+ self.layers = nn.Sequential(*layers)
188
+
189
+ def forward(self, x):
190
+ return self.layers(x)
191
+
192
+
193
+ class DACEncoderWrapper(nn.Module):
194
+ def __init__(self, in_channels=1, **kwargs):
195
+ super().__init__()
196
+
197
+ from dac.model.dac import Encoder as DACEncoder
198
+
199
+ latent_dim = kwargs.pop("latent_dim", None)
200
+
201
+ encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
202
+ self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
203
+ self.latent_dim = latent_dim
204
+
205
+ # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
206
+ self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
207
+
208
+ if in_channels != 1:
209
+ self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
210
+
211
+ def forward(self, x):
212
+ x = self.encoder(x)
213
+ x = self.proj_out(x)
214
+ return x
215
+
216
+ class DACDecoderWrapper(nn.Module):
217
+ def __init__(self, latent_dim, out_channels=1, **kwargs):
218
+ super().__init__()
219
+
220
+ from dac.model.dac import Decoder as DACDecoder
221
+
222
+ self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
223
+
224
+ self.latent_dim = latent_dim
225
+
226
+ def forward(self, x):
227
+ return self.decoder(x)
228
+
229
+ class AudioAutoencoder(nn.Module):
230
+ def __init__(
231
+ self,
232
+ encoder,
233
+ decoder,
234
+ latent_dim,
235
+ downsampling_ratio,
236
+ sample_rate,
237
+ io_channels=2,
238
+ bottleneck: Bottleneck = None,
239
+ pretransform: Pretransform = None,
240
+ in_channels = None,
241
+ out_channels = None,
242
+ soft_clip = False
243
+ ):
244
+ super().__init__()
245
+
246
+ self.downsampling_ratio = downsampling_ratio
247
+ self.sample_rate = sample_rate
248
+
249
+ self.latent_dim = latent_dim
250
+ self.io_channels = io_channels
251
+ self.in_channels = io_channels
252
+ self.out_channels = io_channels
253
+
254
+ self.min_length = self.downsampling_ratio
255
+
256
+ if in_channels is not None:
257
+ self.in_channels = in_channels
258
+
259
+ if out_channels is not None:
260
+ self.out_channels = out_channels
261
+
262
+ self.bottleneck = bottleneck
263
+
264
+ self.encoder = encoder
265
+
266
+ self.decoder = decoder
267
+
268
+ self.pretransform = pretransform
269
+
270
+ self.soft_clip = soft_clip
271
+
272
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
273
+
274
+ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
275
+
276
+ info = {}
277
+
278
+ if self.pretransform is not None and not skip_pretransform:
279
+ if self.pretransform.enable_grad:
280
+ if iterate_batch:
281
+ audios = []
282
+ for i in range(audio.shape[0]):
283
+ audios.append(self.pretransform.encode(audio[i:i+1]))
284
+ audio = torch.cat(audios, dim=0)
285
+ else:
286
+ audio = self.pretransform.encode(audio)
287
+ else:
288
+ with torch.no_grad():
289
+ if iterate_batch:
290
+ audios = []
291
+ for i in range(audio.shape[0]):
292
+ audios.append(self.pretransform.encode(audio[i:i+1]))
293
+ audio = torch.cat(audios, dim=0)
294
+ else:
295
+ audio = self.pretransform.encode(audio)
296
+
297
+ if self.encoder is not None:
298
+ if iterate_batch:
299
+ latents = []
300
+ for i in range(audio.shape[0]):
301
+ latents.append(self.encoder(audio[i:i+1]))
302
+ latents = torch.cat(latents, dim=0)
303
+ else:
304
+ latents = self.encoder(audio)
305
+ else:
306
+ latents = audio
307
+
308
+ if self.bottleneck is not None:
309
+ # TODO: Add iterate batch logic, needs to merge the info dicts
310
+ latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
311
+
312
+ info.update(bottleneck_info)
313
+
314
+ if return_info:
315
+ return latents, info
316
+
317
+ return latents
318
+
319
+ def decode(self, latents, iterate_batch=False, **kwargs):
320
+
321
+ if self.bottleneck is not None:
322
+ if iterate_batch:
323
+ decoded = []
324
+ for i in range(latents.shape[0]):
325
+ decoded.append(self.bottleneck.decode(latents[i:i+1]))
326
+ decoded = torch.cat(decoded, dim=0)
327
+ else:
328
+ latents = self.bottleneck.decode(latents)
329
+
330
+ if iterate_batch:
331
+ decoded = []
332
+ for i in range(latents.shape[0]):
333
+ decoded.append(self.decoder(latents[i:i+1]))
334
+ decoded = torch.cat(decoded, dim=0)
335
+ else:
336
+ decoded = self.decoder(latents, **kwargs)
337
+
338
+ if self.pretransform is not None:
339
+ if self.pretransform.enable_grad:
340
+ if iterate_batch:
341
+ decodeds = []
342
+ for i in range(decoded.shape[0]):
343
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
344
+ decoded = torch.cat(decodeds, dim=0)
345
+ else:
346
+ decoded = self.pretransform.decode(decoded)
347
+ else:
348
+ with torch.no_grad():
349
+ if iterate_batch:
350
+ decodeds = []
351
+ for i in range(latents.shape[0]):
352
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
353
+ decoded = torch.cat(decodeds, dim=0)
354
+ else:
355
+ decoded = self.pretransform.decode(decoded)
356
+
357
+ if self.soft_clip:
358
+ decoded = torch.tanh(decoded)
359
+
360
+ return decoded
361
+
362
+ def decode_tokens(self, tokens, **kwargs):
363
+ '''
364
+ Decode discrete tokens to audio
365
+ Only works with discrete autoencoders
366
+ '''
367
+
368
+ assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
369
+
370
+ latents = self.bottleneck.decode_tokens(tokens, **kwargs)
371
+
372
+ return self.decode(latents, **kwargs)
373
+
374
+
375
+ def preprocess_audio_for_encoder(self, audio, in_sr):
376
+ '''
377
+ Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
378
+ If the model is mono, stereo audio will be converted to mono.
379
+ Audio will be silence-padded to be a multiple of the model's downsampling ratio.
380
+ Audio will be resampled to the model's sample rate.
381
+ The output will have batch size 1 and be shape (1 x Channels x Length)
382
+ '''
383
+ return self.preprocess_audio_list_for_encoder([audio], [in_sr])
384
+
385
+ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
386
+ '''
387
+ Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder.
388
+ The audio in that list can be of different lengths and channels.
389
+ in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
390
+ All audio will be resampled to the model's sample rate.
391
+ Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
392
+ If the model is mono, all audio will be converted to mono.
393
+ The output will be a tensor of shape (Batch x Channels x Length)
394
+ '''
395
+ batch_size = len(audio_list)
396
+ if isinstance(in_sr_list, int):
397
+ in_sr_list = [in_sr_list]*batch_size
398
+ assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
399
+ new_audio = []
400
+ max_length = 0
401
+ # resample & find the max length
402
+ for i in range(batch_size):
403
+ audio = audio_list[i]
404
+ in_sr = in_sr_list[i]
405
+ if len(audio.shape) == 3 and audio.shape[0] == 1:
406
+ # batchsize 1 was given by accident. Just squeeze it.
407
+ audio = audio.squeeze(0)
408
+ elif len(audio.shape) == 1:
409
+ # Mono signal, channel dimension is missing, unsqueeze it in
410
+ audio = audio.unsqueeze(0)
411
+ assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
412
+ # Resample audio
413
+ if in_sr != self.sample_rate:
414
+ resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
415
+ audio = resample_tf(audio)
416
+ new_audio.append(audio)
417
+ if audio.shape[-1] > max_length:
418
+ max_length = audio.shape[-1]
419
+ # Pad every audio to the same length, multiple of model's downsampling ratio
420
+ padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
421
+ for i in range(batch_size):
422
+ # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
423
+ new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
424
+ target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
425
+ # convert to tensor
426
+ return torch.stack(new_audio)
427
+
428
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
429
+ '''
430
+ Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder.
431
+ If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
432
+ Overlap and chunk_size params are both measured in number of latents (not audio samples)
433
+ # and therefore you likely could use the same values with decode_audio.
434
+ A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
435
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
436
+ You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
437
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
438
+ Smaller chunk_size uses less memory, but more compute.
439
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
440
+ For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
441
+ '''
442
+ if not chunked:
443
+ # default behavior. Encode the entire audio in parallel
444
+ return self.encode(audio, **kwargs)
445
+ else:
446
+ # CHUNKED ENCODING
447
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
448
+ samples_per_latent = self.downsampling_ratio
449
+ total_size = audio.shape[2] # in samples
450
+ batch_size = audio.shape[0]
451
+ chunk_size *= samples_per_latent # converting metric in latents to samples
452
+ overlap *= samples_per_latent # converting metric in latents to samples
453
+ hop_size = chunk_size - overlap
454
+ chunks = []
455
+ for i in range(0, total_size - chunk_size + 1, hop_size):
456
+ chunk = audio[:,:,i:i+chunk_size]
457
+ chunks.append(chunk)
458
+ if i+chunk_size != total_size:
459
+ # Final chunk
460
+ chunk = audio[:,:,-chunk_size:]
461
+ chunks.append(chunk)
462
+ chunks = torch.stack(chunks)
463
+ num_chunks = chunks.shape[0]
464
+ # Note: y_size might be a different value from the latent length used in diffusion training
465
+ # because we can encode audio of varying lengths
466
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
467
+ y_size = total_size // samples_per_latent
468
+ # Create an empty latent, we will populate it with chunks as we encode them
469
+ y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
470
+ for i in range(num_chunks):
471
+ x_chunk = chunks[i,:]
472
+ # encode the chunk
473
+ y_chunk = self.encode(x_chunk)
474
+ # figure out where to put the audio along the time domain
475
+ if i == num_chunks-1:
476
+ # final chunk always goes at the end
477
+ t_end = y_size
478
+ t_start = t_end - y_chunk.shape[2]
479
+ else:
480
+ t_start = i * hop_size // samples_per_latent
481
+ t_end = t_start + chunk_size // samples_per_latent
482
+ # remove the edges of the overlaps
483
+ ol = overlap//samples_per_latent//2
484
+ chunk_start = 0
485
+ chunk_end = y_chunk.shape[2]
486
+ if i > 0:
487
+ # no overlap for the start of the first chunk
488
+ t_start += ol
489
+ chunk_start += ol
490
+ if i < num_chunks-1:
491
+ # no overlap for the end of the last chunk
492
+ t_end -= ol
493
+ chunk_end -= ol
494
+ # paste the chunked audio into our y_final output audio
495
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
496
+ return y_final
497
+
498
+ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
499
+ '''
500
+ Decode latents to audio.
501
+ If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
502
+ A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
503
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
504
+ You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
505
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
506
+ Smaller chunk_size uses less memory, but more compute.
507
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
508
+ For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
509
+ '''
510
+ if not chunked:
511
+ # default behavior. Decode the entire latent in parallel
512
+ return self.decode(latents, **kwargs)
513
+ else:
514
+ # chunked decoding
515
+ hop_size = chunk_size - overlap
516
+ total_size = latents.shape[2]
517
+ batch_size = latents.shape[0]
518
+ chunks = []
519
+ for i in range(0, total_size - chunk_size + 1, hop_size):
520
+ chunk = latents[:,:,i:i+chunk_size]
521
+ chunks.append(chunk)
522
+ if i+chunk_size != total_size:
523
+ # Final chunk
524
+ chunk = latents[:,:,-chunk_size:]
525
+ chunks.append(chunk)
526
+ chunks = torch.stack(chunks)
527
+ num_chunks = chunks.shape[0]
528
+ # samples_per_latent is just the downsampling ratio
529
+ samples_per_latent = self.downsampling_ratio
530
+ # Create an empty waveform, we will populate it with chunks as decode them
531
+ y_size = total_size * samples_per_latent
532
+ y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
533
+ for i in range(num_chunks):
534
+ x_chunk = chunks[i,:]
535
+ # decode the chunk
536
+ y_chunk = self.decode(x_chunk)
537
+ # figure out where to put the audio along the time domain
538
+ if i == num_chunks-1:
539
+ # final chunk always goes at the end
540
+ t_end = y_size
541
+ t_start = t_end - y_chunk.shape[2]
542
+ else:
543
+ t_start = i * hop_size * samples_per_latent
544
+ t_end = t_start + chunk_size * samples_per_latent
545
+ # remove the edges of the overlaps
546
+ ol = (overlap//2) * samples_per_latent
547
+ chunk_start = 0
548
+ chunk_end = y_chunk.shape[2]
549
+ if i > 0:
550
+ # no overlap for the start of the first chunk
551
+ t_start += ol
552
+ chunk_start += ol
553
+ if i < num_chunks-1:
554
+ # no overlap for the end of the last chunk
555
+ t_end -= ol
556
+ chunk_end -= ol
557
+ # paste the chunked audio into our y_final output audio
558
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
559
+ return y_final
560
+
561
+
562
+ # AE factories
563
+
564
+ def create_encoder_from_config(encoder_config: Dict[str, Any]):
565
+ encoder_type = encoder_config.get("type", None)
566
+ assert encoder_type is not None, "Encoder type must be specified"
567
+
568
+ if encoder_type == "oobleck":
569
+ encoder = OobleckEncoder(
570
+ **encoder_config["config"]
571
+ )
572
+
573
+ elif encoder_type == "seanet":
574
+ from encodec.modules import SEANetEncoder
575
+ seanet_encoder_config = encoder_config["config"]
576
+
577
+ #SEANet encoder expects strides in reverse order
578
+ seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
579
+ encoder = SEANetEncoder(
580
+ **seanet_encoder_config
581
+ )
582
+ elif encoder_type == "dac":
583
+ dac_config = encoder_config["config"]
584
+
585
+ encoder = DACEncoderWrapper(**dac_config)
586
+ elif encoder_type == "local_attn":
587
+ from .local_attention import TransformerEncoder1D
588
+
589
+ local_attn_config = encoder_config["config"]
590
+
591
+ encoder = TransformerEncoder1D(
592
+ **local_attn_config
593
+ )
594
+ else:
595
+ raise ValueError(f"Unknown encoder type {encoder_type}")
596
+
597
+ requires_grad = encoder_config.get("requires_grad", True)
598
+ if not requires_grad:
599
+ for param in encoder.parameters():
600
+ param.requires_grad = False
601
+
602
+ return encoder
603
+
604
+ def create_decoder_from_config(decoder_config: Dict[str, Any]):
605
+ decoder_type = decoder_config.get("type", None)
606
+ assert decoder_type is not None, "Decoder type must be specified"
607
+
608
+ if decoder_type == "oobleck":
609
+ decoder = OobleckDecoder(
610
+ **decoder_config["config"]
611
+ )
612
+ elif decoder_type == "seanet":
613
+ from encodec.modules import SEANetDecoder
614
+
615
+ decoder = SEANetDecoder(
616
+ **decoder_config["config"]
617
+ )
618
+ elif decoder_type == "dac":
619
+ dac_config = decoder_config["config"]
620
+
621
+ decoder = DACDecoderWrapper(**dac_config)
622
+ elif decoder_type == "local_attn":
623
+ from .local_attention import TransformerDecoder1D
624
+
625
+ local_attn_config = decoder_config["config"]
626
+
627
+ decoder = TransformerDecoder1D(
628
+ **local_attn_config
629
+ )
630
+ else:
631
+ raise ValueError(f"Unknown decoder type {decoder_type}")
632
+
633
+ requires_grad = decoder_config.get("requires_grad", True)
634
+ if not requires_grad:
635
+ for param in decoder.parameters():
636
+ param.requires_grad = False
637
+
638
+ return decoder
639
+
640
+ def create_autoencoder_from_config(config: Dict[str, Any]):
641
+
642
+ ae_config = config["model"]
643
+
644
+ encoder = create_encoder_from_config(ae_config["encoder"])
645
+ decoder = create_decoder_from_config(ae_config["decoder"])
646
+
647
+ bottleneck = ae_config.get("bottleneck", None)
648
+
649
+ latent_dim = ae_config.get("latent_dim", None)
650
+ assert latent_dim is not None, "latent_dim must be specified in model config"
651
+ downsampling_ratio = ae_config.get("downsampling_ratio", None)
652
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
653
+ io_channels = ae_config.get("io_channels", None)
654
+ assert io_channels is not None, "io_channels must be specified in model config"
655
+ sample_rate = config.get("sample_rate", None)
656
+ assert sample_rate is not None, "sample_rate must be specified in model config"
657
+
658
+ in_channels = ae_config.get("in_channels", None)
659
+ out_channels = ae_config.get("out_channels", None)
660
+
661
+ pretransform = ae_config.get("pretransform", None)
662
+
663
+ if pretransform is not None:
664
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
665
+
666
+ if bottleneck is not None:
667
+ bottleneck = create_bottleneck_from_config(bottleneck)
668
+
669
+ soft_clip = ae_config["decoder"].get("soft_clip", False)
670
+
671
+ return AudioAutoencoder(
672
+ encoder,
673
+ decoder,
674
+ io_channels=io_channels,
675
+ latent_dim=latent_dim,
676
+ downsampling_ratio=downsampling_ratio,
677
+ sample_rate=sample_rate,
678
+ bottleneck=bottleneck,
679
+ pretransform=pretransform,
680
+ in_channels=in_channels,
681
+ out_channels=out_channels,
682
+ soft_clip=soft_clip
683
+ )
vae_modules/stable_vae/models/blocks.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from torch.backends.cuda import sdp_kernel
9
+ from packaging import version
10
+
11
+ from .nn.layers import Snake1d
12
+
13
+
14
+ class ResidualBlock(nn.Module):
15
+ def __init__(self, main, skip=None):
16
+ super().__init__()
17
+ self.main = nn.Sequential(*main)
18
+ self.skip = skip if skip else nn.Identity()
19
+
20
+ def forward(self, input):
21
+ return self.main(input) + self.skip(input)
22
+
23
+
24
+ class ResConvBlock(ResidualBlock):
25
+ def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
26
+ skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
27
+ super().__init__([
28
+ nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
29
+ nn.GroupNorm(1, c_mid),
30
+ Snake1d(c_mid) if use_snake else nn.GELU(),
31
+ nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
32
+ nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
33
+ (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
34
+ ], skip)
35
+
36
+
37
+ class SelfAttention1d(nn.Module):
38
+ def __init__(self, c_in, n_head=1, dropout_rate=0.):
39
+ super().__init__()
40
+ assert c_in % n_head == 0
41
+ self.norm = nn.GroupNorm(1, c_in)
42
+ self.n_head = n_head
43
+ self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
44
+ self.out_proj = nn.Conv1d(c_in, c_in, 1)
45
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
46
+
47
+ self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
48
+
49
+ if not self.use_flash:
50
+ return
51
+
52
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
53
+
54
+ if device_properties.major == 8 and device_properties.minor == 0:
55
+ # Use flash attention for A100 GPUs
56
+ self.sdp_kernel_config = (True, False, False)
57
+ else:
58
+ # Don't use flash attention for other GPUs
59
+ self.sdp_kernel_config = (False, True, True)
60
+
61
+ def forward(self, input):
62
+ n, c, s = input.shape
63
+ qkv = self.qkv_proj(self.norm(input))
64
+ qkv = qkv.view(
65
+ [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
66
+ q, k, v = qkv.chunk(3, dim=1)
67
+ scale = k.shape[3]**-0.25
68
+
69
+ if self.use_flash:
70
+ with sdp_kernel(*self.sdp_kernel_config):
71
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
72
+ else:
73
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
74
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
75
+
76
+
77
+ return input + self.dropout(self.out_proj(y))
78
+
79
+
80
+ class SkipBlock(nn.Module):
81
+ def __init__(self, *main):
82
+ super().__init__()
83
+ self.main = nn.Sequential(*main)
84
+
85
+ def forward(self, input):
86
+ return torch.cat([self.main(input), input], dim=1)
87
+
88
+
89
+ class FourierFeatures(nn.Module):
90
+ def __init__(self, in_features, out_features, std=1.):
91
+ super().__init__()
92
+ assert out_features % 2 == 0
93
+ self.weight = nn.Parameter(torch.randn(
94
+ [out_features // 2, in_features]) * std)
95
+
96
+ def forward(self, input):
97
+ f = 2 * math.pi * input @ self.weight.T
98
+ return torch.cat([f.cos(), f.sin()], dim=-1)
99
+
100
+
101
+ def expand_to_planes(input, shape):
102
+ return input[..., None].repeat([1, 1, shape[2]])
103
+
104
+ _kernels = {
105
+ 'linear':
106
+ [1 / 8, 3 / 8, 3 / 8, 1 / 8],
107
+ 'cubic':
108
+ [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
109
+ 0.43359375, 0.11328125, -0.03515625, -0.01171875],
110
+ 'lanczos3':
111
+ [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
112
+ -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
113
+ 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
114
+ -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
115
+ }
116
+
117
+
118
+ class Downsample1d(nn.Module):
119
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
120
+ super().__init__()
121
+ self.pad_mode = pad_mode
122
+ kernel_1d = torch.tensor(_kernels[kernel])
123
+ self.pad = kernel_1d.shape[0] // 2 - 1
124
+ self.register_buffer('kernel', kernel_1d)
125
+ self.channels_last = channels_last
126
+
127
+ def forward(self, x):
128
+ if self.channels_last:
129
+ x = x.permute(0, 2, 1)
130
+ x = F.pad(x, (self.pad,) * 2, self.pad_mode)
131
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
132
+ indices = torch.arange(x.shape[1], device=x.device)
133
+ weight[indices, indices] = self.kernel.to(weight)
134
+ x = F.conv1d(x, weight, stride=2)
135
+ if self.channels_last:
136
+ x = x.permute(0, 2, 1)
137
+ return x
138
+
139
+
140
+ class Upsample1d(nn.Module):
141
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
142
+ super().__init__()
143
+ self.pad_mode = pad_mode
144
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
145
+ self.pad = kernel_1d.shape[0] // 2 - 1
146
+ self.register_buffer('kernel', kernel_1d)
147
+ self.channels_last = channels_last
148
+
149
+ def forward(self, x):
150
+ if self.channels_last:
151
+ x = x.permute(0, 2, 1)
152
+ x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
153
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
154
+ indices = torch.arange(x.shape[1], device=x.device)
155
+ weight[indices, indices] = self.kernel.to(weight)
156
+ x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
157
+ if self.channels_last:
158
+ x = x.permute(0, 2, 1)
159
+ return x
160
+
161
+
162
+ def Downsample1d_2(
163
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
164
+ ) -> nn.Module:
165
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
166
+
167
+ return nn.Conv1d(
168
+ in_channels=in_channels,
169
+ out_channels=out_channels,
170
+ kernel_size=factor * kernel_multiplier + 1,
171
+ stride=factor,
172
+ padding=factor * (kernel_multiplier // 2),
173
+ )
174
+
175
+
176
+ def Upsample1d_2(
177
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
178
+ ) -> nn.Module:
179
+
180
+ if factor == 1:
181
+ return nn.Conv1d(
182
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
183
+ )
184
+
185
+ if use_nearest:
186
+ return nn.Sequential(
187
+ nn.Upsample(scale_factor=factor, mode="nearest"),
188
+ nn.Conv1d(
189
+ in_channels=in_channels,
190
+ out_channels=out_channels,
191
+ kernel_size=3,
192
+ padding=1,
193
+ ),
194
+ )
195
+ else:
196
+ return nn.ConvTranspose1d(
197
+ in_channels=in_channels,
198
+ out_channels=out_channels,
199
+ kernel_size=factor * 2,
200
+ stride=factor,
201
+ padding=factor // 2 + factor % 2,
202
+ output_padding=factor % 2,
203
+ )
204
+
205
+
206
+ def zero_init(layer):
207
+ nn.init.zeros_(layer.weight)
208
+ if layer.bias is not None:
209
+ nn.init.zeros_(layer.bias)
210
+ return layer
211
+
212
+
213
+ def rms_norm(x, scale, eps):
214
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
215
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
216
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
217
+ return x * scale.to(x.dtype)
218
+
219
+ #rms_norm = torch.compile(rms_norm)
220
+
221
+ class AdaRMSNorm(nn.Module):
222
+ def __init__(self, features, cond_features, eps=1e-6):
223
+ super().__init__()
224
+ self.eps = eps
225
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
226
+
227
+ def extra_repr(self):
228
+ return f"eps={self.eps},"
229
+
230
+ def forward(self, x, cond):
231
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
232
+
233
+
234
+ def normalize(x, eps=1e-4):
235
+ dim = list(range(1, x.ndim))
236
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
237
+ alpha = np.sqrt(n.numel() / x.numel())
238
+ return x / torch.add(eps, n, alpha=alpha)
239
+
240
+
241
+ class ForcedWNConv1d(nn.Module):
242
+ def __init__(self, in_channels, out_channels, kernel_size=1):
243
+ super().__init__()
244
+ self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
245
+
246
+ def forward(self, x):
247
+ if self.training:
248
+ with torch.no_grad():
249
+ self.weight.copy_(normalize(self.weight))
250
+
251
+ fan_in = self.weight[0].numel()
252
+
253
+ w = normalize(self.weight) / math.sqrt(fan_in)
254
+
255
+ return F.conv1d(x, w, padding='same')
256
+
257
+ # Kernels
258
+
259
+ use_compile = True
260
+
261
+ def compile(function, *args, **kwargs):
262
+ if not use_compile:
263
+ return function
264
+ try:
265
+ return torch.compile(function, *args, **kwargs)
266
+ except RuntimeError:
267
+ return function
268
+
269
+
270
+ @compile
271
+ def linear_geglu(x, weight, bias=None):
272
+ x = x @ weight.mT
273
+ if bias is not None:
274
+ x = x + bias
275
+ x, gate = x.chunk(2, dim=-1)
276
+ return x * F.gelu(gate)
277
+
278
+
279
+ @compile
280
+ def rms_norm(x, scale, eps):
281
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
282
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
283
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
284
+ return x * scale.to(x.dtype)
285
+
286
+ # Layers
287
+
288
+
289
+ class LinearGEGLU(nn.Linear):
290
+ def __init__(self, in_features, out_features, bias=True):
291
+ super().__init__(in_features, out_features * 2, bias=bias)
292
+ self.out_features = out_features
293
+
294
+ def forward(self, x):
295
+ return linear_geglu(x, self.weight, self.bias)
296
+
297
+
298
+ class RMSNorm(nn.Module):
299
+ def __init__(self, shape, fix_scale = False, eps=1e-6):
300
+ super().__init__()
301
+ self.eps = eps
302
+
303
+ if fix_scale:
304
+ self.register_buffer("scale", torch.ones(shape))
305
+ else:
306
+ self.scale = nn.Parameter(torch.ones(shape))
307
+
308
+ def extra_repr(self):
309
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
310
+
311
+ def forward(self, x):
312
+ return rms_norm(x, self.scale, self.eps)
313
+
314
+
315
+ # jit script make it 1.4x faster and save GPU memory
316
+ @torch.jit.script
317
+ def snake_beta(x, alpha, beta):
318
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
319
+
320
+ # try:
321
+ # snake_beta = torch.compile(snake_beta)
322
+ # except RuntimeError:
323
+ # pass
324
+
325
+
326
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
327
+ # License available in LICENSES/LICENSE_NVIDIA.txt
328
+ class SnakeBeta(nn.Module):
329
+
330
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
331
+ super(SnakeBeta, self).__init__()
332
+ self.in_features = in_features
333
+
334
+ # initialize alpha
335
+ self.alpha_logscale = alpha_logscale
336
+ if self.alpha_logscale:
337
+ # log scale alphas initialized to zeros
338
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
339
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
340
+ else:
341
+ # linear scale alphas initialized to ones
342
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
343
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
344
+
345
+ self.alpha.requires_grad = alpha_trainable
346
+ self.beta.requires_grad = alpha_trainable
347
+
348
+ # self.no_div_by_zero = 0.000000001
349
+
350
+ def forward(self, x):
351
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
352
+ # line up with x to [B, C, T]
353
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
354
+ if self.alpha_logscale:
355
+ alpha = torch.exp(alpha)
356
+ beta = torch.exp(beta)
357
+ x = snake_beta(x, alpha, beta)
358
+
359
+ return x
vae_modules/stable_vae/models/bottleneck.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from einops import rearrange
6
+ from vector_quantize_pytorch import ResidualVQ, FSQ
7
+ from .nn.quantize import ResidualVectorQuantize as DACResidualVQ
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ def __init__(self, is_discrete: bool = False):
12
+ super().__init__()
13
+
14
+ self.is_discrete = is_discrete
15
+
16
+ def encode(self, x, return_info=False, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, x):
20
+ raise NotImplementedError
21
+
22
+
23
+ class DiscreteBottleneck(Bottleneck):
24
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
25
+ super().__init__(is_discrete=True)
26
+
27
+ self.num_quantizers = num_quantizers
28
+ self.codebook_size = codebook_size
29
+ self.tokens_id = tokens_id
30
+
31
+ def decode_tokens(self, codes, **kwargs):
32
+ raise NotImplementedError
33
+
34
+
35
+ class TanhBottleneck(Bottleneck):
36
+ def __init__(self):
37
+ super().__init__(is_discrete=False)
38
+ self.tanh = nn.Tanh()
39
+
40
+ def encode(self, x, return_info=False):
41
+ info = {}
42
+
43
+ x = torch.tanh(x)
44
+
45
+ if return_info:
46
+ return x, info
47
+ else:
48
+ return x
49
+
50
+ def decode(self, x):
51
+ return x
52
+
53
+
54
+ @torch.jit.script
55
+ def vae_sample_kl(mean, scale):
56
+ stdev = nn.functional.softplus(scale) + 1e-4
57
+ var = stdev * stdev
58
+ logvar = torch.log(var)
59
+ latents = torch.randn_like(mean) * stdev + mean
60
+
61
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
62
+
63
+ return latents, kl
64
+
65
+
66
+ @torch.jit.script
67
+ def vae_sample(mean, scale):
68
+ stdev = nn.functional.softplus(scale) + 1e-4
69
+ latents = torch.randn_like(mean) * stdev + mean
70
+ return latents
71
+
72
+
73
+ class VAEBottleneck(Bottleneck):
74
+ def __init__(self):
75
+ super().__init__(is_discrete=False)
76
+
77
+ def encode(self, x, return_info=False, **kwargs):
78
+ mean, scale = x.chunk(2, dim=1)
79
+
80
+ if return_info:
81
+ info = {}
82
+ x, kl = vae_sample_kl(mean, scale)
83
+ info["kl"] = kl
84
+ return x, info
85
+ else:
86
+ x = vae_sample(mean, scale)
87
+ return x
88
+
89
+ def decode(self, x):
90
+ return x
91
+
92
+
93
+ def compute_mean_kernel(x, y):
94
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
95
+ return torch.exp(-kernel_input).mean()
96
+
97
+
98
+ def compute_mmd(latents):
99
+ latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
100
+ noise = torch.randn_like(latents_reshaped)
101
+
102
+ latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
103
+ noise_kernel = compute_mean_kernel(noise, noise)
104
+ latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
105
+
106
+ mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
107
+ return mmd.mean()
108
+
109
+
110
+ class WassersteinBottleneck(Bottleneck):
111
+ def __init__(self, noise_augment_dim: int = 0):
112
+ super().__init__(is_discrete=False)
113
+
114
+ self.noise_augment_dim = noise_augment_dim
115
+
116
+ def encode(self, x, return_info=False):
117
+ info = {}
118
+
119
+ if self.training and return_info:
120
+ mmd = compute_mmd(x)
121
+ info["mmd"] = mmd
122
+
123
+ if return_info:
124
+ return x, info
125
+
126
+ return x
127
+
128
+ def decode(self, x):
129
+
130
+ if self.noise_augment_dim > 0:
131
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
132
+ x.shape[-1]).type_as(x)
133
+ x = torch.cat([x, noise], dim=1)
134
+
135
+ return x
136
+
137
+
138
+ class L2Bottleneck(Bottleneck):
139
+ def __init__(self):
140
+ super().__init__(is_discrete=False)
141
+
142
+ def encode(self, x, return_info=False):
143
+ info = {}
144
+
145
+ x = F.normalize(x, dim=1)
146
+
147
+ if return_info:
148
+ return x, info
149
+ else:
150
+ return x
151
+
152
+ def decode(self, x):
153
+ return F.normalize(x, dim=1)
154
+
155
+
156
+ class RVQBottleneck(DiscreteBottleneck):
157
+ def __init__(self, **quantizer_kwargs):
158
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
159
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
160
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
161
+
162
+ def encode(self, x, return_info=False, **kwargs):
163
+ info = {}
164
+
165
+ x = rearrange(x, "b c n -> b n c")
166
+ x, indices, loss = self.quantizer(x)
167
+ x = rearrange(x, "b n c -> b c n")
168
+
169
+ info["quantizer_indices"] = indices
170
+ info["quantizer_loss"] = loss.mean()
171
+
172
+ if return_info:
173
+ return x, info
174
+ else:
175
+ return x
176
+
177
+ def decode(self, x):
178
+ return x
179
+
180
+ def decode_tokens(self, codes, **kwargs):
181
+ latents = self.quantizer.get_outputs_from_indices(codes)
182
+
183
+ return self.decode(latents, **kwargs)
184
+
185
+
186
+ class RVQVAEBottleneck(DiscreteBottleneck):
187
+ def __init__(self, **quantizer_kwargs):
188
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
189
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
190
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
191
+
192
+ def encode(self, x, return_info=False):
193
+ info = {}
194
+
195
+ x, kl = vae_sample(*x.chunk(2, dim=1))
196
+
197
+ info["kl"] = kl
198
+
199
+ x = rearrange(x, "b c n -> b n c")
200
+ x, indices, loss = self.quantizer(x)
201
+ x = rearrange(x, "b n c -> b c n")
202
+
203
+ info["quantizer_indices"] = indices
204
+ info["quantizer_loss"] = loss.mean()
205
+
206
+ if return_info:
207
+ return x, info
208
+ else:
209
+ return x
210
+
211
+ def decode(self, x):
212
+ return x
213
+
214
+ def decode_tokens(self, codes, **kwargs):
215
+ latents = self.quantizer.get_outputs_from_indices(codes)
216
+
217
+ return self.decode(latents, **kwargs)
218
+
219
+
220
+ class DACRVQBottleneck(DiscreteBottleneck):
221
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
222
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
223
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
224
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
225
+ self.quantize_on_decode = quantize_on_decode
226
+
227
+ def encode(self, x, return_info=False, **kwargs):
228
+ info = {}
229
+
230
+ info["pre_quantizer"] = x
231
+
232
+ if self.quantize_on_decode:
233
+ return x, info if return_info else x
234
+
235
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
236
+
237
+ output = {
238
+ "z": z,
239
+ "codes": codes,
240
+ "latents": latents,
241
+ "vq/commitment_loss": commitment_loss,
242
+ "vq/codebook_loss": codebook_loss,
243
+ }
244
+
245
+ output["vq/commitment_loss"] /= self.num_quantizers
246
+ output["vq/codebook_loss"] /= self.num_quantizers
247
+
248
+ info.update(output)
249
+
250
+ if return_info:
251
+ return output["z"], info
252
+
253
+ return output["z"]
254
+
255
+ def decode(self, x):
256
+
257
+ if self.quantize_on_decode:
258
+ x = self.quantizer(x)[0]
259
+
260
+ return x
261
+
262
+ def decode_tokens(self, codes, **kwargs):
263
+ latents, _, _ = self.quantizer.from_codes(codes)
264
+
265
+ return self.decode(latents, **kwargs)
266
+
267
+
268
+ class DACRVQVAEBottleneck(DiscreteBottleneck):
269
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
270
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
271
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
272
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
273
+ self.quantize_on_decode = quantize_on_decode
274
+
275
+ def encode(self, x, return_info=False, n_quantizers: int = None):
276
+ info = {}
277
+
278
+ mean, scale = x.chunk(2, dim=1)
279
+
280
+ x, kl = vae_sample(mean, scale)
281
+
282
+ info["pre_quantizer"] = x
283
+ info["kl"] = kl
284
+
285
+ if self.quantize_on_decode:
286
+ return x, info if return_info else x
287
+
288
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
289
+
290
+ output = {
291
+ "z": z,
292
+ "codes": codes,
293
+ "latents": latents,
294
+ "vq/commitment_loss": commitment_loss,
295
+ "vq/codebook_loss": codebook_loss,
296
+ }
297
+
298
+ output["vq/commitment_loss"] /= self.num_quantizers
299
+ output["vq/codebook_loss"] /= self.num_quantizers
300
+
301
+ info.update(output)
302
+
303
+ if return_info:
304
+ return output["z"], info
305
+
306
+ return output["z"]
307
+
308
+ def decode(self, x):
309
+
310
+ if self.quantize_on_decode:
311
+ x = self.quantizer(x)[0]
312
+
313
+ return x
314
+
315
+ def decode_tokens(self, codes, **kwargs):
316
+ latents, _, _ = self.quantizer.from_codes(codes)
317
+
318
+ return self.decode(latents, **kwargs)
319
+
320
+
321
+ class FSQBottleneck(DiscreteBottleneck):
322
+ def __init__(self, dim, levels):
323
+ super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices")
324
+ self.quantizer = FSQ(levels=[levels] * dim)
325
+
326
+ def encode(self, x, return_info=False):
327
+ info = {}
328
+
329
+ x = rearrange(x, "b c n -> b n c")
330
+ x, indices = self.quantizer(x)
331
+ x = rearrange(x, "b n c -> b c n")
332
+
333
+ info["quantizer_indices"] = indices
334
+
335
+ if return_info:
336
+ return x, info
337
+ else:
338
+ return x
339
+
340
+ def decode(self, x):
341
+ return x
342
+
343
+ def decode_tokens(self, tokens, **kwargs):
344
+ latents = self.quantizer.indices_to_codes(tokens)
345
+
346
+ return self.decode(latents, **kwargs)
vae_modules/stable_vae/models/factory.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ def create_model_from_config(model_config):
4
+ model_type = model_config.get('model_type', None)
5
+
6
+ assert model_type is not None, 'model_type must be specified in model config'
7
+
8
+ if model_type == 'autoencoder':
9
+ from .autoencoders import create_autoencoder_from_config
10
+ return create_autoencoder_from_config(model_config)
11
+ elif model_type == 'diffusion_uncond':
12
+ from .diffusion import create_diffusion_uncond_from_config
13
+ return create_diffusion_uncond_from_config(model_config)
14
+ elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior":
15
+ from .diffusion import create_diffusion_cond_from_config
16
+ return create_diffusion_cond_from_config(model_config)
17
+ elif model_type == 'diffusion_autoencoder':
18
+ from .autoencoders import create_diffAE_from_config
19
+ return create_diffAE_from_config(model_config)
20
+ elif model_type == 'lm':
21
+ from .lm import create_audio_lm_from_config
22
+ return create_audio_lm_from_config(model_config)
23
+ else:
24
+ raise NotImplementedError(f'Unknown model type: {model_type}')
25
+
26
+ def create_model_from_config_path(model_config_path):
27
+ with open(model_config_path) as f:
28
+ model_config = json.load(f)
29
+
30
+ return create_model_from_config(model_config)
31
+
32
+ def create_pretransform_from_config(pretransform_config, sample_rate):
33
+ pretransform_type = pretransform_config.get('type', None)
34
+
35
+ assert pretransform_type is not None, 'type must be specified in pretransform config'
36
+
37
+ if pretransform_type == 'autoencoder':
38
+ from .autoencoders import create_autoencoder_from_config
39
+ from .pretransforms import AutoencoderPretransform
40
+
41
+ # Create fake top-level config to pass sample rate to autoencoder constructor
42
+ # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
43
+ autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
44
+ autoencoder = create_autoencoder_from_config(autoencoder_config)
45
+
46
+ scale = pretransform_config.get("scale", 1.0)
47
+ model_half = pretransform_config.get("model_half", False)
48
+ iterate_batch = pretransform_config.get("iterate_batch", False)
49
+ chunked = pretransform_config.get("chunked", False)
50
+
51
+ pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
52
+ elif pretransform_type == 'wavelet':
53
+ from .pretransforms import WaveletPretransform
54
+
55
+ wavelet_config = pretransform_config["config"]
56
+ channels = wavelet_config["channels"]
57
+ levels = wavelet_config["levels"]
58
+ wavelet = wavelet_config["wavelet"]
59
+
60
+ pretransform = WaveletPretransform(channels, levels, wavelet)
61
+ elif pretransform_type == 'pqmf':
62
+ from .pretransforms import PQMFPretransform
63
+ pqmf_config = pretransform_config["config"]
64
+ pretransform = PQMFPretransform(**pqmf_config)
65
+ elif pretransform_type == 'dac_pretrained':
66
+ from .pretransforms import PretrainedDACPretransform
67
+ pretrained_dac_config = pretransform_config["config"]
68
+ pretransform = PretrainedDACPretransform(**pretrained_dac_config)
69
+ elif pretransform_type == "audiocraft_pretrained":
70
+ from .pretransforms import AudiocraftCompressionPretransform
71
+
72
+ audiocraft_config = pretransform_config["config"]
73
+ pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
74
+ else:
75
+ raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
76
+
77
+ enable_grad = pretransform_config.get('enable_grad', False)
78
+ pretransform.enable_grad = enable_grad
79
+
80
+ pretransform.eval().requires_grad_(pretransform.enable_grad)
81
+
82
+ return pretransform
83
+
84
+ def create_bottleneck_from_config(bottleneck_config):
85
+ bottleneck_type = bottleneck_config.get('type', None)
86
+
87
+ assert bottleneck_type is not None, 'type must be specified in bottleneck config'
88
+
89
+ if bottleneck_type == 'tanh':
90
+ from .bottleneck import TanhBottleneck
91
+ bottleneck = TanhBottleneck()
92
+ elif bottleneck_type == 'vae':
93
+ from .bottleneck import VAEBottleneck
94
+ bottleneck = VAEBottleneck()
95
+ elif bottleneck_type == 'rvq':
96
+ from .bottleneck import RVQBottleneck
97
+
98
+ quantizer_params = {
99
+ "dim": 128,
100
+ "codebook_size": 1024,
101
+ "num_quantizers": 8,
102
+ "decay": 0.99,
103
+ "kmeans_init": True,
104
+ "kmeans_iters": 50,
105
+ "threshold_ema_dead_code": 2,
106
+ }
107
+
108
+ quantizer_params.update(bottleneck_config["config"])
109
+
110
+ bottleneck = RVQBottleneck(**quantizer_params)
111
+ elif bottleneck_type == "dac_rvq":
112
+ from .bottleneck import DACRVQBottleneck
113
+
114
+ bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
115
+
116
+ elif bottleneck_type == 'rvq_vae':
117
+ from .bottleneck import RVQVAEBottleneck
118
+
119
+ quantizer_params = {
120
+ "dim": 128,
121
+ "codebook_size": 1024,
122
+ "num_quantizers": 8,
123
+ "decay": 0.99,
124
+ "kmeans_init": True,
125
+ "kmeans_iters": 50,
126
+ "threshold_ema_dead_code": 2,
127
+ }
128
+
129
+ quantizer_params.update(bottleneck_config["config"])
130
+
131
+ bottleneck = RVQVAEBottleneck(**quantizer_params)
132
+
133
+ elif bottleneck_type == 'dac_rvq_vae':
134
+ from .bottleneck import DACRVQVAEBottleneck
135
+ bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
136
+ elif bottleneck_type == 'l2_norm':
137
+ from .bottleneck import L2Bottleneck
138
+ bottleneck = L2Bottleneck()
139
+ elif bottleneck_type == "wasserstein":
140
+ from .bottleneck import WassersteinBottleneck
141
+ bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
142
+ elif bottleneck_type == "fsq":
143
+ from .bottleneck import FSQBottleneck
144
+ bottleneck = FSQBottleneck(**bottleneck_config["config"])
145
+ else:
146
+ raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
147
+
148
+ requires_grad = bottleneck_config.get('requires_grad', True)
149
+ if not requires_grad:
150
+ for param in bottleneck.parameters():
151
+ param.requires_grad = False
152
+
153
+ return bottleneck
vae_modules/stable_vae/models/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import layers
2
+ from . import loss
3
+ from . import quantize
vae_modules/stable_vae/models/nn/layers.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ # Scripting this brings model speed up 1.4x
18
+ @torch.jit.script
19
+ def snake(x, alpha):
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
vae_modules/stable_vae/models/nn/loss.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from audiotools import AudioSignal
7
+ from audiotools import STFTParams
8
+ from torch import nn
9
+
10
+
11
+ class L1Loss(nn.L1Loss):
12
+ """L1 Loss between AudioSignals. Defaults
13
+ to comparing ``audio_data``, but any
14
+ attribute of an AudioSignal can be used.
15
+
16
+ Parameters
17
+ ----------
18
+ attribute : str, optional
19
+ Attribute of signal to compare, defaults to ``audio_data``.
20
+ weight : float, optional
21
+ Weight of this loss, defaults to 1.0.
22
+
23
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24
+ """
25
+
26
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27
+ self.attribute = attribute
28
+ self.weight = weight
29
+ super().__init__(**kwargs)
30
+
31
+ def forward(self, x: AudioSignal, y: AudioSignal):
32
+ """
33
+ Parameters
34
+ ----------
35
+ x : AudioSignal
36
+ Estimate AudioSignal
37
+ y : AudioSignal
38
+ Reference AudioSignal
39
+
40
+ Returns
41
+ -------
42
+ torch.Tensor
43
+ L1 loss between AudioSignal attributes.
44
+ """
45
+ if isinstance(x, AudioSignal):
46
+ x = getattr(x, self.attribute)
47
+ y = getattr(y, self.attribute)
48
+ return super().forward(x, y)
49
+
50
+
51
+ class SISDRLoss(nn.Module):
52
+ """
53
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54
+ of estimated and reference audio signals or aligned features.
55
+
56
+ Parameters
57
+ ----------
58
+ scaling : int, optional
59
+ Whether to use scale-invariant (True) or
60
+ signal-to-noise ratio (False), by default True
61
+ reduction : str, optional
62
+ How to reduce across the batch (either 'mean',
63
+ 'sum', or none).], by default ' mean'
64
+ zero_mean : int, optional
65
+ Zero mean the references and estimates before
66
+ computing the loss, by default True
67
+ clip_min : int, optional
68
+ The minimum possible loss value. Helps network
69
+ to not focus on making already good examples better, by default None
70
+ weight : float, optional
71
+ Weight of this loss, defaults to 1.0.
72
+
73
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ scaling: int = True,
79
+ reduction: str = "mean",
80
+ zero_mean: int = True,
81
+ clip_min: int = None,
82
+ weight: float = 1.0,
83
+ ):
84
+ self.scaling = scaling
85
+ self.reduction = reduction
86
+ self.zero_mean = zero_mean
87
+ self.clip_min = clip_min
88
+ self.weight = weight
89
+ super().__init__()
90
+
91
+ def forward(self, x: AudioSignal, y: AudioSignal):
92
+ eps = 1e-8
93
+ # nb, nc, nt
94
+ if isinstance(x, AudioSignal):
95
+ references = x.audio_data
96
+ estimates = y.audio_data
97
+ else:
98
+ references = x
99
+ estimates = y
100
+
101
+ nb = references.shape[0]
102
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104
+
105
+ # samples now on axis 1
106
+ if self.zero_mean:
107
+ mean_reference = references.mean(dim=1, keepdim=True)
108
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
109
+ else:
110
+ mean_reference = 0
111
+ mean_estimate = 0
112
+
113
+ _references = references - mean_reference
114
+ _estimates = estimates - mean_estimate
115
+
116
+ references_projection = (_references**2).sum(dim=-2) + eps
117
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118
+
119
+ scale = (
120
+ (references_on_estimates / references_projection).unsqueeze(1)
121
+ if self.scaling
122
+ else 1
123
+ )
124
+
125
+ e_true = scale * _references
126
+ e_res = _estimates - e_true
127
+
128
+ signal = (e_true**2).sum(dim=1)
129
+ noise = (e_res**2).sum(dim=1)
130
+ sdr = -10 * torch.log10(signal / noise + eps)
131
+
132
+ if self.clip_min is not None:
133
+ sdr = torch.clamp(sdr, min=self.clip_min)
134
+
135
+ if self.reduction == "mean":
136
+ sdr = sdr.mean()
137
+ elif self.reduction == "sum":
138
+ sdr = sdr.sum()
139
+ return sdr
140
+
141
+
142
+ class MultiScaleSTFTLoss(nn.Module):
143
+ """Computes the multi-scale STFT loss from [1].
144
+
145
+ Parameters
146
+ ----------
147
+ window_lengths : List[int], optional
148
+ Length of each window of each STFT, by default [2048, 512]
149
+ loss_fn : typing.Callable, optional
150
+ How to compare each loss, by default nn.L1Loss()
151
+ clamp_eps : float, optional
152
+ Clamp on the log magnitude, below, by default 1e-5
153
+ mag_weight : float, optional
154
+ Weight of raw magnitude portion of loss, by default 1.0
155
+ log_weight : float, optional
156
+ Weight of log magnitude portion of loss, by default 1.0
157
+ pow : float, optional
158
+ Power to raise magnitude to before taking log, by default 2.0
159
+ weight : float, optional
160
+ Weight of this loss, by default 1.0
161
+ match_stride : bool, optional
162
+ Whether to match the stride of convolutional layers, by default False
163
+
164
+ References
165
+ ----------
166
+
167
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168
+ "DDSP: Differentiable Digital Signal Processing."
169
+ International Conference on Learning Representations. 2019.
170
+
171
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ window_lengths: List[int] = [2048, 512],
177
+ loss_fn: typing.Callable = nn.L1Loss(),
178
+ clamp_eps: float = 1e-5,
179
+ mag_weight: float = 1.0,
180
+ log_weight: float = 1.0,
181
+ pow: float = 2.0,
182
+ weight: float = 1.0,
183
+ match_stride: bool = False,
184
+ window_type: str = None,
185
+ ):
186
+ super().__init__()
187
+ self.stft_params = [
188
+ STFTParams(
189
+ window_length=w,
190
+ hop_length=w // 4,
191
+ match_stride=match_stride,
192
+ window_type=window_type,
193
+ )
194
+ for w in window_lengths
195
+ ]
196
+ self.loss_fn = loss_fn
197
+ self.log_weight = log_weight
198
+ self.mag_weight = mag_weight
199
+ self.clamp_eps = clamp_eps
200
+ self.weight = weight
201
+ self.pow = pow
202
+
203
+ def forward(self, x: AudioSignal, y: AudioSignal):
204
+ """Computes multi-scale STFT between an estimate and a reference
205
+ signal.
206
+
207
+ Parameters
208
+ ----------
209
+ x : AudioSignal
210
+ Estimate signal
211
+ y : AudioSignal
212
+ Reference signal
213
+
214
+ Returns
215
+ -------
216
+ torch.Tensor
217
+ Multi-scale STFT loss.
218
+ """
219
+ loss = 0.0
220
+ for s in self.stft_params:
221
+ x.stft(s.window_length, s.hop_length, s.window_type)
222
+ y.stft(s.window_length, s.hop_length, s.window_type)
223
+ loss += self.log_weight * self.loss_fn(
224
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226
+ )
227
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228
+ return loss
229
+
230
+
231
+ class MelSpectrogramLoss(nn.Module):
232
+ """Compute distance between mel spectrograms. Can be used
233
+ in a multi-scale way.
234
+
235
+ Parameters
236
+ ----------
237
+ n_mels : List[int]
238
+ Number of mels per STFT, by default [150, 80],
239
+ window_lengths : List[int], optional
240
+ Length of each window of each STFT, by default [2048, 512]
241
+ loss_fn : typing.Callable, optional
242
+ How to compare each loss, by default nn.L1Loss()
243
+ clamp_eps : float, optional
244
+ Clamp on the log magnitude, below, by default 1e-5
245
+ mag_weight : float, optional
246
+ Weight of raw magnitude portion of loss, by default 1.0
247
+ log_weight : float, optional
248
+ Weight of log magnitude portion of loss, by default 1.0
249
+ pow : float, optional
250
+ Power to raise magnitude to before taking log, by default 2.0
251
+ weight : float, optional
252
+ Weight of this loss, by default 1.0
253
+ match_stride : bool, optional
254
+ Whether to match the stride of convolutional layers, by default False
255
+
256
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ n_mels: List[int] = [150, 80],
262
+ window_lengths: List[int] = [2048, 512],
263
+ loss_fn: typing.Callable = nn.L1Loss(),
264
+ clamp_eps: float = 1e-5,
265
+ mag_weight: float = 1.0,
266
+ log_weight: float = 1.0,
267
+ pow: float = 2.0,
268
+ weight: float = 1.0,
269
+ match_stride: bool = False,
270
+ mel_fmin: List[float] = [0.0, 0.0],
271
+ mel_fmax: List[float] = [None, None],
272
+ window_type: str = None,
273
+ ):
274
+ super().__init__()
275
+ self.stft_params = [
276
+ STFTParams(
277
+ window_length=w,
278
+ hop_length=w // 4,
279
+ match_stride=match_stride,
280
+ window_type=window_type,
281
+ )
282
+ for w in window_lengths
283
+ ]
284
+ self.n_mels = n_mels
285
+ self.loss_fn = loss_fn
286
+ self.clamp_eps = clamp_eps
287
+ self.log_weight = log_weight
288
+ self.mag_weight = mag_weight
289
+ self.weight = weight
290
+ self.mel_fmin = mel_fmin
291
+ self.mel_fmax = mel_fmax
292
+ self.pow = pow
293
+
294
+ def forward(self, x: AudioSignal, y: AudioSignal):
295
+ """Computes mel loss between an estimate and a reference
296
+ signal.
297
+
298
+ Parameters
299
+ ----------
300
+ x : AudioSignal
301
+ Estimate signal
302
+ y : AudioSignal
303
+ Reference signal
304
+
305
+ Returns
306
+ -------
307
+ torch.Tensor
308
+ Mel loss.
309
+ """
310
+ loss = 0.0
311
+ for n_mels, fmin, fmax, s in zip(
312
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313
+ ):
314
+ kwargs = {
315
+ "window_length": s.window_length,
316
+ "hop_length": s.hop_length,
317
+ "window_type": s.window_type,
318
+ }
319
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321
+
322
+ loss += self.log_weight * self.loss_fn(
323
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325
+ )
326
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327
+ return loss
328
+
329
+
330
+ class GANLoss(nn.Module):
331
+ """
332
+ Computes a discriminator loss, given a discriminator on
333
+ generated waveforms/spectrograms compared to ground truth
334
+ waveforms/spectrograms. Computes the loss for both the
335
+ discriminator and the generator in separate functions.
336
+ """
337
+
338
+ def __init__(self, discriminator):
339
+ super().__init__()
340
+ self.discriminator = discriminator
341
+
342
+ def forward(self, fake, real):
343
+ d_fake = self.discriminator(fake.audio_data)
344
+ d_real = self.discriminator(real.audio_data)
345
+ return d_fake, d_real
346
+
347
+ def discriminator_loss(self, fake, real):
348
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
349
+
350
+ loss_d = 0
351
+ for x_fake, x_real in zip(d_fake, d_real):
352
+ loss_d += torch.mean(x_fake[-1] ** 2)
353
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
354
+ return loss_d
355
+
356
+ def generator_loss(self, fake, real):
357
+ d_fake, d_real = self.forward(fake, real)
358
+
359
+ loss_g = 0
360
+ for x_fake in d_fake:
361
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362
+
363
+ loss_feature = 0
364
+
365
+ for i in range(len(d_fake)):
366
+ for j in range(len(d_fake[i]) - 1):
367
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368
+ return loss_g, loss_feature
vae_modules/stable_vae/models/nn/quantize.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from .layers import WNConv1d
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+
34
+ def forward(self, z):
35
+ """Quantized the input tensor using a fixed codebook and returns
36
+ the corresponding codebook vectors
37
+
38
+ Parameters
39
+ ----------
40
+ z : Tensor[B x D x T]
41
+
42
+ Returns
43
+ -------
44
+ Tensor[B x D x T]
45
+ Quantized continuous representation of input
46
+ Tensor[1]
47
+ Commitment loss to train encoder to predict vectors closer to codebook
48
+ entries
49
+ Tensor[1]
50
+ Codebook loss to update the codebook
51
+ Tensor[B x T]
52
+ Codebook indices (quantized discrete representation of input)
53
+ Tensor[B x D x T]
54
+ Projected latents (continuous representation of input before quantization)
55
+ """
56
+
57
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
+ z_e = self.in_proj(z) # z_e : (B x D x T)
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
+
64
+ z_q = (
65
+ z_e + (z_q - z_e).detach()
66
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
67
+
68
+ z_q = self.out_proj(z_q)
69
+
70
+ return z_q, commitment_loss, codebook_loss, indices, z_e
71
+
72
+ def embed_code(self, embed_id):
73
+ return F.embedding(embed_id, self.codebook.weight)
74
+
75
+ def decode_code(self, embed_id):
76
+ return self.embed_code(embed_id).transpose(1, 2)
77
+
78
+ def decode_latents(self, latents):
79
+ encodings = rearrange(latents, "b d t -> (b t) d")
80
+ codebook = self.codebook.weight # codebook: (N x D)
81
+
82
+ # L2 normalize encodings and codebook (ViT-VQGAN)
83
+ encodings = F.normalize(encodings)
84
+ codebook = F.normalize(codebook)
85
+
86
+ # Compute euclidean distance with codebook
87
+ dist = (
88
+ encodings.pow(2).sum(1, keepdim=True)
89
+ - 2 * encodings @ codebook.t()
90
+ + codebook.pow(2).sum(1, keepdim=True).t()
91
+ )
92
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
93
+ z_q = self.decode_code(indices)
94
+ return z_q, indices
95
+
96
+
97
+ class ResidualVectorQuantize(nn.Module):
98
+ """
99
+ Introduced in SoundStream: An end2end neural audio codec
100
+ https://arxiv.org/abs/2107.03312
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ input_dim: int = 512,
106
+ n_codebooks: int = 9,
107
+ codebook_size: int = 1024,
108
+ codebook_dim: Union[int, list] = 8,
109
+ quantizer_dropout: float = 0.0,
110
+ ):
111
+ super().__init__()
112
+ if isinstance(codebook_dim, int):
113
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
114
+
115
+ self.n_codebooks = n_codebooks
116
+ self.codebook_dim = codebook_dim
117
+ self.codebook_size = codebook_size
118
+
119
+ self.quantizers = nn.ModuleList(
120
+ [
121
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
122
+ for i in range(n_codebooks)
123
+ ]
124
+ )
125
+ self.quantizer_dropout = quantizer_dropout
126
+
127
+ def forward(self, z, n_quantizers: int = None):
128
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
129
+ the corresponding codebook vectors
130
+ Parameters
131
+ ----------
132
+ z : Tensor[B x D x T]
133
+ n_quantizers : int, optional
134
+ No. of quantizers to use
135
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
136
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
137
+ when in training mode, and a random number of quantizers is used.
138
+ Returns
139
+ -------
140
+ dict
141
+ A dictionary with the following keys:
142
+
143
+ "z" : Tensor[B x D x T]
144
+ Quantized continuous representation of input
145
+ "codes" : Tensor[B x N x T]
146
+ Codebook indices for each codebook
147
+ (quantized discrete representation of input)
148
+ "latents" : Tensor[B x N*D x T]
149
+ Projected latents (continuous representation of input before quantization)
150
+ "vq/commitment_loss" : Tensor[1]
151
+ Commitment loss to train encoder to predict vectors closer to codebook
152
+ entries
153
+ "vq/codebook_loss" : Tensor[1]
154
+ Codebook loss to update the codebook
155
+ """
156
+ z_q = 0
157
+ residual = z
158
+ commitment_loss = 0
159
+ codebook_loss = 0
160
+
161
+ codebook_indices = []
162
+ latents = []
163
+
164
+ if n_quantizers is None:
165
+ n_quantizers = self.n_codebooks
166
+ if self.training:
167
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
168
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
169
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
170
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
171
+ n_quantizers = n_quantizers.to(z.device)
172
+
173
+ for i, quantizer in enumerate(self.quantizers):
174
+ if self.training is False and i >= n_quantizers:
175
+ break
176
+
177
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
178
+ residual
179
+ )
180
+
181
+ # Create mask to apply quantizer dropout
182
+ mask = (
183
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
184
+ )
185
+ z_q = z_q + z_q_i * mask[:, None, None]
186
+ residual = residual - z_q_i
187
+
188
+ # Sum losses
189
+ commitment_loss += (commitment_loss_i * mask).mean()
190
+ codebook_loss += (codebook_loss_i * mask).mean()
191
+
192
+ codebook_indices.append(indices_i)
193
+ latents.append(z_e_i)
194
+
195
+ codes = torch.stack(codebook_indices, dim=1)
196
+ latents = torch.cat(latents, dim=1)
197
+
198
+ return z_q, codes, latents, commitment_loss, codebook_loss
199
+
200
+ def from_codes(self, codes: torch.Tensor):
201
+ """Given the quantized codes, reconstruct the continuous representation
202
+ Parameters
203
+ ----------
204
+ codes : Tensor[B x N x T]
205
+ Quantized discrete representation of input
206
+ Returns
207
+ -------
208
+ Tensor[B x D x T]
209
+ Quantized continuous representation of input
210
+ """
211
+ z_q = 0.0
212
+ z_p = []
213
+ n_codebooks = codes.shape[1]
214
+ for i in range(n_codebooks):
215
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
216
+ z_p.append(z_p_i)
217
+
218
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
219
+ z_q = z_q + z_q_i
220
+ return z_q, torch.cat(z_p, dim=1), codes
221
+
222
+ def from_latents(self, latents: torch.Tensor):
223
+ """Given the unquantized latents, reconstruct the
224
+ continuous representation after quantization.
225
+
226
+ Parameters
227
+ ----------
228
+ latents : Tensor[B x N x T]
229
+ Continuous representation of input after projection
230
+
231
+ Returns
232
+ -------
233
+ Tensor[B x D x T]
234
+ Quantized representation of full-projected space
235
+ Tensor[B x D x T]
236
+ Quantized representation of latent space
237
+ """
238
+ z_q = 0
239
+ z_p = []
240
+ codes = []
241
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
242
+
243
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
244
+ 0
245
+ ]
246
+ for i in range(n_codebooks):
247
+ j, k = dims[i], dims[i + 1]
248
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
249
+ z_p.append(z_p_i)
250
+ codes.append(codes_i)
251
+
252
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
253
+ z_q = z_q + z_q_i
254
+
255
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
256
+
257
+
258
+ if __name__ == "__main__":
259
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
260
+ x = torch.randn(16, 512, 80)
261
+ y = rvq(x)
262
+ print(y["latents"].shape)
vae_modules/stable_vae/models/pretransforms.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import nn
4
+
5
+ class Pretransform(nn.Module):
6
+ def __init__(self, enable_grad, io_channels, is_discrete):
7
+ super().__init__()
8
+
9
+ self.is_discrete = is_discrete
10
+ self.io_channels = io_channels
11
+ self.encoded_channels = None
12
+ self.downsampling_ratio = None
13
+
14
+ self.enable_grad = enable_grad
15
+
16
+ def encode(self, x):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, z):
20
+ raise NotImplementedError
21
+
22
+ def tokenize(self, x):
23
+ raise NotImplementedError
24
+
25
+ def decode_tokens(self, tokens):
26
+ raise NotImplementedError
27
+
28
+ class AutoencoderPretransform(Pretransform):
29
+ def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
30
+ super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
31
+ self.model = model
32
+ self.model.requires_grad_(False).eval()
33
+ self.scale=scale
34
+ self.downsampling_ratio = model.downsampling_ratio
35
+ self.io_channels = model.io_channels
36
+ self.sample_rate = model.sample_rate
37
+
38
+ self.model_half = model_half
39
+ self.iterate_batch = iterate_batch
40
+
41
+ self.encoded_channels = model.latent_dim
42
+
43
+ self.chunked = chunked
44
+ self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
45
+ self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
46
+
47
+ if self.model_half:
48
+ self.model.half()
49
+
50
+ def encode(self, x, **kwargs):
51
+
52
+ if self.model_half:
53
+ x = x.half()
54
+ self.model.to(torch.float16)
55
+
56
+ encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
57
+
58
+ if self.model_half:
59
+ encoded = encoded.float()
60
+
61
+ return encoded / self.scale
62
+
63
+ def decode(self, z, **kwargs):
64
+ z = z * self.scale
65
+
66
+ if self.model_half:
67
+ z = z.half()
68
+ self.model.to(torch.float16)
69
+
70
+ decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
71
+
72
+ if self.model_half:
73
+ decoded = decoded.float()
74
+
75
+ return decoded
76
+
77
+ def tokenize(self, x, **kwargs):
78
+ assert self.model.is_discrete, "Cannot tokenize with a continuous model"
79
+
80
+ _, info = self.model.encode(x, return_info = True, **kwargs)
81
+
82
+ return info[self.model.bottleneck.tokens_id]
83
+
84
+ def decode_tokens(self, tokens, **kwargs):
85
+ assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
86
+
87
+ return self.model.decode_tokens(tokens, **kwargs)
88
+
89
+ def load_state_dict(self, state_dict, strict=True):
90
+ self.model.load_state_dict(state_dict, strict=strict)
91
+
92
+ class WaveletPretransform(Pretransform):
93
+ def __init__(self, channels, levels, wavelet):
94
+ super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
95
+
96
+ from .wavelets import WaveletEncode1d, WaveletDecode1d
97
+
98
+ self.encoder = WaveletEncode1d(channels, levels, wavelet)
99
+ self.decoder = WaveletDecode1d(channels, levels, wavelet)
100
+
101
+ self.downsampling_ratio = 2 ** levels
102
+ self.io_channels = channels
103
+ self.encoded_channels = channels * self.downsampling_ratio
104
+
105
+ def encode(self, x):
106
+ return self.encoder(x)
107
+
108
+ def decode(self, z):
109
+ return self.decoder(z)
110
+
111
+ class PQMFPretransform(Pretransform):
112
+ def __init__(self, attenuation=100, num_bands=16):
113
+ # TODO: Fix PQMF to take in in-channels
114
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
115
+ from .pqmf import PQMF
116
+ self.pqmf = PQMF(attenuation, num_bands)
117
+
118
+
119
+ def encode(self, x):
120
+ # x is (Batch x Channels x Time)
121
+ x = self.pqmf.forward(x)
122
+ # pqmf.forward returns (Batch x Channels x Bands x Time)
123
+ # but Pretransform needs Batch x Channels x Time
124
+ # so concatenate channels and bands into one axis
125
+ return rearrange(x, "b c n t -> b (c n) t")
126
+
127
+ def decode(self, x):
128
+ # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
129
+ x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
130
+ # returns (Batch x Channels x Time)
131
+ return self.pqmf.inverse(x)
132
+
133
+ class PretrainedDACPretransform(Pretransform):
134
+ def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
135
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
136
+
137
+ import dac
138
+
139
+ model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
140
+
141
+ self.model = dac.DAC.load(model_path)
142
+
143
+ self.quantize_on_decode = quantize_on_decode
144
+
145
+ if model_type == "44khz":
146
+ self.downsampling_ratio = 512
147
+ else:
148
+ self.downsampling_ratio = 320
149
+
150
+ self.io_channels = 1
151
+
152
+ self.scale = scale
153
+
154
+ self.chunked = chunked
155
+
156
+ self.encoded_channels = self.model.latent_dim
157
+
158
+ self.num_quantizers = self.model.n_codebooks
159
+
160
+ self.codebook_size = self.model.codebook_size
161
+
162
+ def encode(self, x):
163
+
164
+ latents = self.model.encoder(x)
165
+
166
+ if self.quantize_on_decode:
167
+ output = latents
168
+ else:
169
+ z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
170
+ output = z
171
+
172
+ if self.scale != 1.0:
173
+ output = output / self.scale
174
+
175
+ return output
176
+
177
+ def decode(self, z):
178
+
179
+ if self.scale != 1.0:
180
+ z = z * self.scale
181
+
182
+ if self.quantize_on_decode:
183
+ z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
184
+
185
+ return self.model.decode(z)
186
+
187
+ def tokenize(self, x):
188
+ return self.model.encode(x)[1]
189
+
190
+ def decode_tokens(self, tokens):
191
+ latents = self.model.quantizer.from_codes(tokens)
192
+ return self.model.decode(latents)
193
+
194
+ class AudiocraftCompressionPretransform(Pretransform):
195
+ def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
196
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
197
+
198
+ try:
199
+ from audiocraft.models import CompressionModel
200
+ except ImportError:
201
+ raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
202
+
203
+ self.model = CompressionModel.get_pretrained(model_type)
204
+
205
+ self.quantize_on_decode = quantize_on_decode
206
+
207
+ self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
208
+
209
+ self.sample_rate = self.model.sample_rate
210
+
211
+ self.io_channels = self.model.channels
212
+
213
+ self.scale = scale
214
+
215
+ #self.encoded_channels = self.model.latent_dim
216
+
217
+ self.num_quantizers = self.model.num_codebooks
218
+
219
+ self.codebook_size = self.model.cardinality
220
+
221
+ self.model.to(torch.float16).eval().requires_grad_(False)
222
+
223
+ def encode(self, x):
224
+
225
+ assert False, "Audiocraft compression models do not support continuous encoding"
226
+
227
+ # latents = self.model.encoder(x)
228
+
229
+ # if self.quantize_on_decode:
230
+ # output = latents
231
+ # else:
232
+ # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
233
+ # output = z
234
+
235
+ # if self.scale != 1.0:
236
+ # output = output / self.scale
237
+
238
+ # return output
239
+
240
+ def decode(self, z):
241
+
242
+ assert False, "Audiocraft compression models do not support continuous decoding"
243
+
244
+ # if self.scale != 1.0:
245
+ # z = z * self.scale
246
+
247
+ # if self.quantize_on_decode:
248
+ # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
249
+
250
+ # return self.model.decode(z)
251
+
252
+ def tokenize(self, x):
253
+ with torch.cuda.amp.autocast(enabled=False):
254
+ return self.model.encode(x.to(torch.float16))[0]
255
+
256
+ def decode_tokens(self, tokens):
257
+ with torch.cuda.amp.autocast(enabled=False):
258
+ return self.model.decode(tokens)
vae_modules/stable_vae/models/utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchaudio import transforms as T
4
+
5
+
6
+ class PadCrop(nn.Module):
7
+ def __init__(self, n_samples, randomize=True):
8
+ super().__init__()
9
+ self.n_samples = n_samples
10
+ self.randomize = randomize
11
+
12
+ def __call__(self, signal):
13
+ n, s = signal.shape
14
+ start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
15
+ end = start + self.n_samples
16
+ output = signal.new_zeros([n, self.n_samples])
17
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
18
+ return output
19
+
20
+
21
+ def set_audio_channels(audio, target_channels):
22
+ if target_channels == 1:
23
+ # Convert to mono
24
+ audio = audio.mean(1, keepdim=True)
25
+ elif target_channels == 2:
26
+ # Convert to stereo
27
+ if audio.shape[1] == 1:
28
+ audio = audio.repeat(1, 2, 1)
29
+ elif audio.shape[1] > 2:
30
+ audio = audio[:, :2, :]
31
+ return audio
32
+
33
+ def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
34
+
35
+ audio = audio.to(device)
36
+
37
+ if in_sr != target_sr:
38
+ resample_tf = T.Resample(in_sr, target_sr).to(device)
39
+ audio = resample_tf(audio)
40
+
41
+ audio = PadCrop(target_length, randomize=False)(audio)
42
+
43
+ # Add batch dimension
44
+ if audio.dim() == 1:
45
+ audio = audio.unsqueeze(0).unsqueeze(0)
46
+ elif audio.dim() == 2:
47
+ audio = audio.unsqueeze(0)
48
+
49
+ audio = set_audio_channels(audio, target_channels)
50
+
51
+ return audio