L0SG committed on
Commit
1de35a2
1 Parent(s): cfe9514

update space demo

Files changed (5)
  1. app.py +14 -44
  2. bigvgan.py +351 -0
  3. inference.py +0 -105
  4. meldataset.py +2 -149
  5. models.py +0 -955
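
In short, app.py now loads every model through the new bigvgan.BigVGAN class (which mixes in PyTorchModelHubMixin) instead of downloading raw g_05000000 / g_03000000 checkpoints and rebuilding the generator from models.py. A minimal sketch of the new loading path, with the repo id pattern and TOKEN handling taken from app.py (everything else is illustrative):

import os
import torch
from bigvgan import BigVGAN

device = "cuda" if torch.cuda.is_available() else "cpu"

# one call replaces the old hf_hub_download + Generator(h) + load_state_dict sequence
model = BigVGAN.from_pretrained(
    "nvidia/bigvgan_v2_24khz_100band_256x",    # any id from LIST_MODEL_ID works
    token=os.environ.get("TOKEN"),             # app.py reads the token from the TOKEN env var
)
model.eval()
model.remove_weight_norm()                     # inference-only, as done in app.py
model.to(device)

print(model.h.sampling_rate, model.h.num_mels) # hyperparameters now travel with the model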
app.py CHANGED
@@ -6,8 +6,8 @@ import json
 import torch
 import os
 from env import AttrDict
-from meldataset import mel_spectrogram, MAX_WAV_VALUE
-from models import BigVGAN as Generator
+from meldataset import get_mel_spectrogram, MAX_WAV_VALUE
+from bigvgan import BigVGAN
 import librosa
 import numpy as np
 from utils import plot_spectrogram
@@ -35,22 +35,21 @@ def inference_gradio(input, model_choice):  # input is audio waveform in [T, channel]
     audio = np.transpose(audio)  # transpose to [channel, T] for librosa
     audio = audio / MAX_WAV_VALUE  # convert int16 to float range used by BigVGAN
 
-    h = dict_config[model_choice]
     model = dict_model[model_choice]
 
-    if sr != h.sampling_rate:  # convert audio to model's sampling rate
-        audio = librosa.resample(audio, orig_sr=sr, target_sr=h.sampling_rate)
+    if sr != model.h.sampling_rate:  # convert audio to model's sampling rate
+        audio = librosa.resample(audio, orig_sr=sr, target_sr=model.h.sampling_rate)
     if len(audio.shape) == 2:  # stereo
         audio = librosa.to_mono(audio)  # convert to mono if stereo
     audio = librosa.util.normalize(audio) * 0.95
 
     output, spec_gen = inference_model(
-        audio, h, model
+        audio, model
     )  # output is generated audio in ndarray, int16
 
     spec_plot_gen = plot_spectrogram(spec_gen)
 
-    output_audio = (h.sampling_rate, output)  # tuple for gr.Audio output
+    output_audio = (model.h.sampling_rate, output)  # tuple for gr.Audio output
 
     buffer = spec_plot_gen.canvas.buffer_rgba()
     output_image = PIL.Image.frombuffer(
@@ -67,22 +66,19 @@ def inference_gradio(input, model_choice):  # input is audio waveform in [T, channel]
 
 
 @spaces.GPU(duration=120)
-def inference_model(audio_input, h, model):
+def inference_model(audio_input, model):
     # load model to device
     model.to(device)
 
-    def get_mel(x):
-        return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
-
     with torch.inference_mode():
        wav = torch.FloatTensor(audio_input)
        # compute mel spectrogram from the ground truth audio
-       spec_gt = get_mel(wav.unsqueeze(0)).to(device)
+       spec_gt = get_mel_spectrogram(wav.unsqueeze(0), model.h).to(device)
 
        y_g_hat = model(spec_gt)
 
        audio_gen = y_g_hat.squeeze().cpu()
-       spec_gen = get_mel(audio_gen.unsqueeze(0))
+       spec_gen = get_mel_spectrogram(audio_gen.unsqueeze(0), model.h)
        audio_gen = audio_gen.numpy()  # [T], float [-1, 1]
        audio_gen = (audio_gen * MAX_WAV_VALUE).astype("int16")  # [T], int16
        spec_gen = spec_gen.squeeze().numpy()  # [C, T_frame]
@@ -234,9 +230,7 @@ css = """
 
 ######################## script for loading the models ########################
 
-MODEL_PATH = "nvidia/BigVGAN"
-
-LIST_MODEL_NAME = [
+LIST_MODEL_ID = [
     "bigvgan_24khz_100band",
     "bigvgan_base_24khz_100band",
     "bigvgan_22khz_80band",
@@ -248,41 +242,17 @@ LIST_MODEL_NAME = [
     "bigvgan_v2_44khz_128band_512x"
 ]
 
-DICT_MODEL_NAME_FILE_PAIRS = {
-    "bigvgan_24khz_100band": "g_05000000",
-    "bigvgan_base_24khz_100band": "g_05000000",
-    "bigvgan_22khz_80band": "g_05000000",
-    "bigvgan_base_22khz_80band": "g_05000000",
-    "bigvgan_v2_22khz_80band_256x": "g_03000000",
-    "bigvgan_v2_22khz_80band_fmax8k_256x": "g_03000000",
-    "bigvgan_v2_24khz_100band_256x": "g_03000000",
-    "bigvgan_v2_44khz_128band_256x": "g_03000000",
-    "bigvgan_v2_44khz_128band_512x": "g_03000000"
-}
-
 dict_model = {}
 dict_config = {}
 
-for model_name in LIST_MODEL_NAME:
-    model_file = hf_hub_download(MODEL_PATH, f"{model_name}/{DICT_MODEL_NAME_FILE_PAIRS[model_name]}", use_auth_token=os.environ['TOKEN'])
-    config_file = hf_hub_download(MODEL_PATH, f"{model_name}/config.json", use_auth_token=os.environ['TOKEN'])
-
-    with open(config_file) as f:
-        data = f.read()
-
-    json_config = json.loads(data)
-    h = AttrDict(json_config)
-
-    torch.manual_seed(h.seed)
-
-    generator = Generator(h)
-    state_dict_g = load_checkpoint(model_file)
-    generator.load_state_dict(state_dict_g['generator'])
+for model_name in LIST_MODEL_ID:
+    generator = BigVGAN.from_pretrained('nvidia/' + model_name, token=os.environ['TOKEN'])
     generator.eval()
     generator.remove_weight_norm()
 
     dict_model[model_name] = generator
-    dict_config[model_name] = h
+    dict_config[model_name] = generator.h
 
 ######################## script for gradio UI ########################
 
@@ -338,7 +308,7 @@ with iface:
     model_choice = gr.Dropdown(
        label="Select the model. Default: bigvgan_v2_24khz_100band_256x",
        value="bigvgan_v2_24khz_100band_256x",
-       choices=[m for m in LIST_MODEL_NAME],
+       choices=[m for m in LIST_MODEL_ID],
        interactive=True,
    )
 
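For reference, the unchanged preprocessing in inference_gradio converts the gr.Audio int16 input to float, resamples to the model's rate, downmixes to mono, and peak-normalizes before synthesis. A small standalone sketch of those steps (names mirror app.py; the function itself is illustrative only):

import librosa
import numpy as np

MAX_WAV_VALUE = 32767.0  # same constant as meldataset.py

def preprocess(sr, audio, target_sr):
    audio = np.transpose(audio)        # [T, channel] -> [channel, T] for librosa
    audio = audio / MAX_WAV_VALUE      # int16 -> float
    if sr != target_sr:                # convert audio to the model's sampling rate
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    if len(audio.shape) == 2:          # stereo -> mono
        audio = librosa.to_mono(audio)
    return librosa.util.normalize(audio) * 0.95  # peak-normalize to 0.95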
bigvgan.py ADDED
@@ -0,0 +1,351 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import os
8
+ import json
9
+ from pathlib import Path
10
+
11
+ from collections import namedtuple
12
+ from typing import Optional, List, Union, Dict
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import torch.nn as nn
17
+ from torch.nn import Conv1d, ConvTranspose1d
18
+ from torch.nn.utils import weight_norm, remove_weight_norm
19
+
20
+ import activations
21
+ from utils import init_weights, get_padding
22
+ from alias_free_torch.act import Activation1d as TorchActivation1d
23
+ from env import AttrDict
24
+
25
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
26
+
27
+ def load_hparams_from_json(path) -> AttrDict:
28
+ with open(path) as f:
29
+ data = f.read()
30
+ h = json.loads(data)
31
+ return AttrDict(h)
32
+
33
+ class AMPBlock1(torch.nn.Module):
34
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
35
+ super(AMPBlock1, self).__init__()
36
+ self.h = h
37
+
38
+ self.convs1 = nn.ModuleList([
39
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
40
+ padding=get_padding(kernel_size, dilation[0]))),
41
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
42
+ padding=get_padding(kernel_size, dilation[1]))),
43
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
44
+ padding=get_padding(kernel_size, dilation[2])))
45
+ ])
46
+ self.convs1.apply(init_weights)
47
+
48
+ self.convs2 = nn.ModuleList([
49
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
50
+ padding=get_padding(kernel_size, 1))),
51
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
52
+ padding=get_padding(kernel_size, 1))),
53
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
54
+ padding=get_padding(kernel_size, 1)))
55
+ ])
56
+ self.convs2.apply(init_weights)
57
+
58
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
59
+
60
+ # select which Activation1d, lazy-load cuda version to ensure backward compatibility
61
+ if self.h.get("use_cuda_kernel", False):
62
+ # faster CUDA kernel implementation of Activation1d
63
+ from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
64
+ Activation1d = CudaActivation1d
65
+ else:
66
+ Activation1d = TorchActivation1d
67
+
68
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
69
+ self.activations = nn.ModuleList([
70
+ Activation1d(
71
+ activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
72
+ for _ in range(self.num_layers)
73
+ ])
74
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
75
+ self.activations = nn.ModuleList([
76
+ Activation1d(
77
+ activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
78
+ for _ in range(self.num_layers)
79
+ ])
80
+ else:
81
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
82
+
83
+ def forward(self, x):
84
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
85
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
86
+ xt = a1(x)
87
+ xt = c1(xt)
88
+ xt = a2(xt)
89
+ xt = c2(xt)
90
+ x = xt + x
91
+
92
+ return x
93
+
94
+ def remove_weight_norm(self):
95
+ for l in self.convs1:
96
+ remove_weight_norm(l)
97
+ for l in self.convs2:
98
+ remove_weight_norm(l)
99
+
100
+
101
+ class AMPBlock2(torch.nn.Module):
102
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
103
+ super(AMPBlock2, self).__init__()
104
+ self.h = h
105
+
106
+ self.convs = nn.ModuleList([
107
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
108
+ padding=get_padding(kernel_size, dilation[0]))),
109
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
110
+ padding=get_padding(kernel_size, dilation[1])))
111
+ ])
112
+ self.convs.apply(init_weights)
113
+
114
+ self.num_layers = len(self.convs) # total number of conv layers
115
+
116
+ # select which Activation1d, lazy-load cuda version to ensure backward compatibility
117
+ if self.h.get("use_cuda_kernel", False):
118
+ # faster CUDA kernel implementation of Activation1d
119
+ from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
120
+ Activation1d = CudaActivation1d
121
+ else:
122
+ Activation1d = TorchActivation1d
123
+
124
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
125
+ self.activations = nn.ModuleList([
126
+ Activation1d(
127
+ activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
128
+ for _ in range(self.num_layers)
129
+ ])
130
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
131
+ self.activations = nn.ModuleList([
132
+ Activation1d(
133
+ activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
134
+ for _ in range(self.num_layers)
135
+ ])
136
+ else:
137
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
138
+
139
+ def forward(self, x):
140
+ for c, a in zip (self.convs, self.activations):
141
+ xt = a(x)
142
+ xt = c(xt)
143
+ x = xt + x
144
+
145
+ return x
146
+
147
+ def remove_weight_norm(self):
148
+ for l in self.convs:
149
+ remove_weight_norm(l)
150
+
151
+
152
+ class BigVGAN(
153
+ torch.nn.Module,
154
+ PyTorchModelHubMixin,
155
+ library_name="bigvgan",
156
+ repo_url="https://github.com/NVIDIA/BigVGAN",
157
+ docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
158
+ pipeline_tag="audio-to-audio",
159
+ license="mit",
160
+ tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"]
161
+ ):
162
+ # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
163
+ # New in v2: if use_cuda_kernel is set to True, it loads optimized CUDA kernels for AMP.
164
+ # NOTE: use_cuda_kernel=True should be used for inference only (training is not supported).
165
+ def __init__(
166
+ self,
167
+ h,
168
+ use_cuda_kernel: bool=False
169
+ ):
170
+ super(BigVGAN, self).__init__()
171
+ self.h = h
172
+ self.h["use_cuda_kernel"] = use_cuda_kernel # add it to global hyperparameters (h)
173
+
174
+ self.num_kernels = len(h.resblock_kernel_sizes)
175
+ self.num_upsamples = len(h.upsample_rates)
176
+
177
+ # pre conv
178
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
179
+
180
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
181
+ resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
182
+
183
+ # transposed conv-based upsamplers. does not apply anti-aliasing
184
+ self.ups = nn.ModuleList()
185
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
186
+ self.ups.append(nn.ModuleList([
187
+ weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
188
+ h.upsample_initial_channel // (2 ** (i + 1)),
189
+ k, u, padding=(k - u) // 2))
190
+ ]))
191
+
192
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
193
+ self.resblocks = nn.ModuleList()
194
+ for i in range(len(self.ups)):
195
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
196
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
197
+ self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
198
+
199
+ # select which Activation1d, lazy-load cuda version to ensure backward compatibility
200
+ if self.h.get("use_cuda_kernel", False):
201
+ # faster CUDA kernel implementation of Activation1d
202
+ from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
203
+ Activation1d = CudaActivation1d
204
+ else:
205
+ Activation1d = TorchActivation1d
206
+
207
+ # post conv
208
+ if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
209
+ activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
210
+ self.activation_post = Activation1d(activation=activation_post)
211
+ elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
212
+ activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
213
+ self.activation_post = Activation1d(activation=activation_post)
214
+ else:
215
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
216
+
217
+ # whether to use bias for the final conv_post. Defaults to True for backward compatibility
218
+ self.use_bias_at_final = h.get("use_bias_at_final", True)
219
+ self.conv_post = weight_norm(Conv1d(
220
+ ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final
221
+ ))
222
+
223
+ # weight initialization
224
+ for i in range(len(self.ups)):
225
+ self.ups[i].apply(init_weights)
226
+ self.conv_post.apply(init_weights)
227
+
228
+ # final tanh activation. Defaults to True for backward compatibility
229
+ self.use_tanh_at_final = h.get("use_tanh_at_final", True)
230
+
231
+ def forward(self, x):
232
+ # pre conv
233
+ x = self.conv_pre(x)
234
+
235
+ for i in range(self.num_upsamples):
236
+ # upsampling
237
+ for i_up in range(len(self.ups[i])):
238
+ x = self.ups[i][i_up](x)
239
+ # AMP blocks
240
+ xs = None
241
+ for j in range(self.num_kernels):
242
+ if xs is None:
243
+ xs = self.resblocks[i * self.num_kernels + j](x)
244
+ else:
245
+ xs += self.resblocks[i * self.num_kernels + j](x)
246
+ x = xs / self.num_kernels
247
+
248
+ # post conv
249
+ x = self.activation_post(x)
250
+ x = self.conv_post(x)
251
+ # final tanh activation
252
+ if self.use_tanh_at_final:
253
+ x = torch.tanh(x)
254
+ else:
255
+ x = torch.clamp(x, min=-1., max=1.) # bound the output to [-1, 1]
256
+
257
+ return x
258
+
259
+ def remove_weight_norm(self):
260
+ print('Removing weight norm...')
261
+ for l in self.ups:
262
+ for l_i in l:
263
+ remove_weight_norm(l_i)
264
+ for l in self.resblocks:
265
+ l.remove_weight_norm()
266
+ remove_weight_norm(self.conv_pre)
267
+ remove_weight_norm(self.conv_post)
268
+
269
+ ##################################################################
270
+ # additional methods for huggingface_hub support
271
+ ##################################################################
272
+ def _save_pretrained(self, save_directory: Path) -> None:
273
+ """Save weights and config.json from a Pytorch model to a local directory."""
274
+
275
+ model_path = save_directory / 'bigvgan_generator.pt'
276
+ torch.save(
277
+ {'generator': self.state_dict()},
278
+ model_path
279
+ )
280
+
281
+ config_path = save_directory / 'config.json'
282
+ with open(config_path, 'w') as config_file:
283
+ json.dump(self.h, config_file, indent=4)
284
+
285
+ @classmethod
286
+ def _from_pretrained(
287
+ cls,
288
+ *,
289
+ model_id: str,
290
+ revision: str,
291
+ cache_dir: str,
292
+ force_download: bool,
293
+ proxies: Optional[Dict],
294
+ resume_download: bool,
295
+ local_files_only: bool,
296
+ token: Union[str, bool, None],
297
+ map_location: str = "cpu", # additional argument
298
+ strict: bool = False, # additional argument
299
+ use_cuda_kernel: bool = False,
300
+ **model_kwargs,
301
+ ):
302
+ """Load Pytorch pretrained weights and return the loaded model."""
303
+
304
+ ##################################################################
305
+ # download and load hyperparameters (h) used by BigVGAN
306
+ ##################################################################
307
+ config_file = hf_hub_download(
308
+ repo_id=model_id,
309
+ filename='config.json',
310
+ revision=revision,
311
+ cache_dir=cache_dir,
312
+ force_download=force_download,
313
+ proxies=proxies,
314
+ resume_download=resume_download,
315
+ token=token,
316
+ local_files_only=local_files_only,
317
+ )
318
+ h = load_hparams_from_json(config_file)
319
+
320
+ ##################################################################
321
+ # instantiate BigVGAN using h
322
+ ##################################################################
323
+ if use_cuda_kernel:
324
+ print(f"[INFO] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!")
325
+ print(f"[INFO] You need nvcc and ninja installed in your system to build the kernel. For detail, see: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis")
326
+ model = cls(h, use_cuda_kernel=use_cuda_kernel)
327
+
328
+ ##################################################################
329
+ # download and load pretrained generator weight
330
+ ##################################################################
331
+ if os.path.isdir(model_id):
332
+ print("Loading weights from local directory")
333
+ model_file = os.path.join(model_id, 'bigvgan_generator.pt')
334
+ else:
335
+ print(f"Downloading weights from {model_id}")
336
+ model_file = hf_hub_download(
337
+ repo_id=model_id,
338
+ filename='bigvgan_generator.pt',
339
+ revision=revision,
340
+ cache_dir=cache_dir,
341
+ force_download=force_download,
342
+ proxies=proxies,
343
+ resume_download=resume_download,
344
+ token=token,
345
+ local_files_only=local_files_only,
346
+ )
347
+
348
+ checkpoint_dict = torch.load(model_file, map_location=map_location)
349
+ model.load_state_dict(checkpoint_dict['generator'])
350
+
351
+ return model
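
For completeness, the generator defined above can also be built directly from a local config.json (the same hyperparameter file that _from_pretrained downloads), without the Hub mixin. The sketch below assumes such a file exists locally and feeds a random mel tensor only to illustrate the expected shapes:

import torch
from bigvgan import BigVGAN, load_hparams_from_json

h = load_hparams_from_json("config.json")   # hypothetical local path to a BigVGAN config
model = BigVGAN(h, use_cuda_kernel=False)   # randomly initialized here; load weights for real use
model.eval()
model.remove_weight_norm()

with torch.inference_mode():
    mel = torch.randn(1, h.num_mels, 100)   # [B, num_mels, T_frames] dummy mel
    wav = model(mel)                        # [B, 1, T_frames * prod(upsample_rates)]
    print(wav.shape)                        # output is tanh-bounded to [-1, 1] by default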
inference.py DELETED
@@ -1,105 +0,0 @@
1
- # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
- # LICENSE is in incl_licenses directory.
3
-
4
- from __future__ import absolute_import, division, print_function, unicode_literals
5
-
6
- import glob
7
- import os
8
- import argparse
9
- import json
10
- import torch
11
- from scipy.io.wavfile import write
12
- from env import AttrDict
13
- from meldataset import mel_spectrogram, MAX_WAV_VALUE
14
- from models import BigVGAN as Generator
15
- import librosa
16
-
17
- h = None
18
- device = None
19
- torch.backends.cudnn.benchmark = False
20
-
21
-
22
- def load_checkpoint(filepath, device):
23
- assert os.path.isfile(filepath)
24
- print("Loading '{}'".format(filepath))
25
- checkpoint_dict = torch.load(filepath, map_location=device)
26
- print("Complete.")
27
- return checkpoint_dict
28
-
29
-
30
- def get_mel(x):
31
- return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
32
-
33
-
34
- def scan_checkpoint(cp_dir, prefix):
35
- pattern = os.path.join(cp_dir, prefix + '*')
36
- cp_list = glob.glob(pattern)
37
- if len(cp_list) == 0:
38
- return ''
39
- return sorted(cp_list)[-1]
40
-
41
-
42
- def inference(a, h):
43
- generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
44
-
45
- state_dict_g = load_checkpoint(a.checkpoint_file, device)
46
- generator.load_state_dict(state_dict_g['generator'])
47
-
48
- filelist = os.listdir(a.input_wavs_dir)
49
-
50
- os.makedirs(a.output_dir, exist_ok=True)
51
-
52
- generator.eval()
53
- generator.remove_weight_norm()
54
- with torch.no_grad():
55
- for i, filname in enumerate(filelist):
56
- # load the ground truth audio and resample if necessary
57
- wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
58
- wav = torch.FloatTensor(wav).to(device)
59
- # compute mel spectrogram from the ground truth audio
60
- x = get_mel(wav.unsqueeze(0))
61
-
62
- y_g_hat = generator(x)
63
-
64
- audio = y_g_hat.squeeze()
65
- audio = audio * MAX_WAV_VALUE
66
- audio = audio.cpu().numpy().astype('int16')
67
-
68
- output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated.wav')
69
- write(output_file, h.sampling_rate, audio)
70
- print(output_file)
71
-
72
-
73
- def main():
74
- print('Initializing Inference Process..')
75
-
76
- parser = argparse.ArgumentParser()
77
- parser.add_argument('--input_wavs_dir', default='test_files')
78
- parser.add_argument('--output_dir', default='generated_files')
79
- parser.add_argument('--checkpoint_file', required=True)
80
- parser.add_argument('--use_cuda_kernel', action='store_true', default=False)
81
-
82
- a = parser.parse_args()
83
-
84
- config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
85
- with open(config_file) as f:
86
- data = f.read()
87
-
88
- global h
89
- json_config = json.loads(data)
90
- h = AttrDict(json_config)
91
-
92
- torch.manual_seed(h.seed)
93
- global device
94
- if torch.cuda.is_available():
95
- torch.cuda.manual_seed(h.seed)
96
- device = torch.device('cuda')
97
- else:
98
- device = torch.device('cpu')
99
-
100
- inference(a, h)
101
-
102
-
103
- if __name__ == '__main__':
104
- main()
105
-
meldataset.py CHANGED
@@ -4,59 +4,37 @@
4
  # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
  # LICENSE is in incl_licenses directory.
6
 
7
- import math
8
- import os
9
- import random
10
  import torch
11
  import torch.utils.data
12
  import numpy as np
13
- from librosa.util import normalize
14
  from scipy.io.wavfile import read
15
  from librosa.filters import mel as librosa_mel_fn
16
- import pathlib
17
- from tqdm import tqdm
18
 
19
  MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
20
 
21
-
22
- def load_wav(full_path, sr_target):
23
- sampling_rate, data = read(full_path)
24
- if sampling_rate != sr_target:
25
- raise RuntimeError("Sampling rate of the file {} is {} Hz, but the model requires {} Hz".
26
- format(full_path, sampling_rate, sr_target))
27
- return data, sampling_rate
28
-
29
-
30
  def dynamic_range_compression(x, C=1, clip_val=1e-5):
31
  return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
32
 
33
-
34
  def dynamic_range_decompression(x, C=1):
35
  return np.exp(x) / C
36
 
37
-
38
  def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
39
  return torch.log(torch.clamp(x, min=clip_val) * C)
40
 
41
-
42
  def dynamic_range_decompression_torch(x, C=1):
43
  return torch.exp(x) / C
44
 
45
-
46
  def spectral_normalize_torch(magnitudes):
47
  output = dynamic_range_compression_torch(magnitudes)
48
  return output
49
 
50
-
51
  def spectral_de_normalize_torch(magnitudes):
52
  output = dynamic_range_decompression_torch(magnitudes)
53
  return output
54
 
55
-
56
  mel_basis = {}
57
  hann_window = {}
58
 
59
-
60
  def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
61
  if torch.min(y) < -1.:
62
  print('min value is ', torch.min(y))
@@ -84,130 +62,5 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
84
 
85
  return spec
86
 
87
-
88
- def get_dataset_filelist(a):
89
- with open(a.input_training_file, 'r', encoding='utf-8') as fi:
90
- training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
91
- for x in fi.read().split('\n') if len(x) > 0]
92
- print("first training file: {}".format(training_files[0]))
93
-
94
- with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
95
- validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
96
- for x in fi.read().split('\n') if len(x) > 0]
97
- print("first validation file: {}".format(validation_files[0]))
98
-
99
- list_unseen_validation_files = []
100
- for i in range(len(a.list_input_unseen_validation_file)):
101
- with open(a.list_input_unseen_validation_file[i], 'r', encoding='utf-8') as fi:
102
- unseen_validation_files = [os.path.join(a.list_input_unseen_wavs_dir[i], x.split('|')[0] + '.wav')
103
- for x in fi.read().split('\n') if len(x) > 0]
104
- print("first unseen {}th validation fileset: {}".format(i, unseen_validation_files[0]))
105
- list_unseen_validation_files.append(unseen_validation_files)
106
-
107
- return training_files, validation_files, list_unseen_validation_files
108
-
109
-
110
- class MelDataset(torch.utils.data.Dataset):
111
- def __init__(self, training_files, hparams, segment_size, n_fft, num_mels,
112
- hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
113
- device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None, is_seen=True):
114
- self.audio_files = training_files
115
- random.seed(1234)
116
- if shuffle:
117
- random.shuffle(self.audio_files)
118
- self.hparams = hparams
119
- self.is_seen = is_seen
120
- if self.is_seen:
121
- self.name = pathlib.Path(self.audio_files[0]).parts[0]
122
- else:
123
- self.name = '-'.join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
124
-
125
- self.segment_size = segment_size
126
- self.sampling_rate = sampling_rate
127
- self.split = split
128
- self.n_fft = n_fft
129
- self.num_mels = num_mels
130
- self.hop_size = hop_size
131
- self.win_size = win_size
132
- self.fmin = fmin
133
- self.fmax = fmax
134
- self.fmax_loss = fmax_loss
135
- self.cached_wav = None
136
- self.n_cache_reuse = n_cache_reuse
137
- self._cache_ref_count = 0
138
- self.device = device
139
- self.fine_tuning = fine_tuning
140
- self.base_mels_path = base_mels_path
141
-
142
- print("INFO: checking dataset integrity...")
143
- for i in tqdm(range(len(self.audio_files))):
144
- assert os.path.exists(self.audio_files[i]), "{} not found".format(self.audio_files[i])
145
-
146
- def __getitem__(self, index):
147
-
148
- filename = self.audio_files[index]
149
- if self._cache_ref_count == 0:
150
- audio, sampling_rate = load_wav(filename, self.sampling_rate)
151
- audio = audio / MAX_WAV_VALUE
152
- if not self.fine_tuning:
153
- audio = normalize(audio) * 0.95
154
- self.cached_wav = audio
155
- if sampling_rate != self.sampling_rate:
156
- raise ValueError("{} SR doesn't match target {} SR".format(
157
- sampling_rate, self.sampling_rate))
158
- self._cache_ref_count = self.n_cache_reuse
159
- else:
160
- audio = self.cached_wav
161
- self._cache_ref_count -= 1
162
-
163
- audio = torch.FloatTensor(audio)
164
- audio = audio.unsqueeze(0)
165
-
166
- if not self.fine_tuning:
167
- if self.split:
168
- if audio.size(1) >= self.segment_size:
169
- max_audio_start = audio.size(1) - self.segment_size
170
- audio_start = random.randint(0, max_audio_start)
171
- audio = audio[:, audio_start:audio_start+self.segment_size]
172
- else:
173
- audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
174
-
175
- mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
176
- self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
177
- center=False)
178
- else: # validation step
179
- # match audio length to self.hop_size * n for evaluation
180
- if (audio.size(1) % self.hop_size) != 0:
181
- audio = audio[:, :-(audio.size(1) % self.hop_size)]
182
- mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
183
- self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
184
- center=False)
185
- assert audio.shape[1] == mel.shape[2] * self.hop_size, "audio shape {} mel shape {}".format(audio.shape, mel.shape)
186
-
187
- else:
188
- mel = np.load(
189
- os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
190
- mel = torch.from_numpy(mel)
191
-
192
- if len(mel.shape) < 3:
193
- mel = mel.unsqueeze(0)
194
-
195
- if self.split:
196
- frames_per_seg = math.ceil(self.segment_size / self.hop_size)
197
-
198
- if audio.size(1) >= self.segment_size:
199
- mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
200
- mel = mel[:, :, mel_start:mel_start + frames_per_seg]
201
- audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
202
- else:
203
- mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
204
- audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
205
-
206
- mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
207
- self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
208
- center=False)
209
-
210
- return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
211
-
212
- def __len__(self):
213
- return len(self.audio_files)
 
4
  # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
  # LICENSE is in incl_licenses directory.
6
 
 
 
 
7
  import torch
8
  import torch.utils.data
9
  import numpy as np
 
10
  from scipy.io.wavfile import read
11
  from librosa.filters import mel as librosa_mel_fn
 
 
12
 
13
  MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
14
 
 
 
 
 
 
 
 
 
 
15
  def dynamic_range_compression(x, C=1, clip_val=1e-5):
16
  return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
17
 
 
18
  def dynamic_range_decompression(x, C=1):
19
  return np.exp(x) / C
20
 
 
21
  def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
22
  return torch.log(torch.clamp(x, min=clip_val) * C)
23
 
 
24
  def dynamic_range_decompression_torch(x, C=1):
25
  return torch.exp(x) / C
26
 
 
27
  def spectral_normalize_torch(magnitudes):
28
  output = dynamic_range_compression_torch(magnitudes)
29
  return output
30
 
 
31
  def spectral_de_normalize_torch(magnitudes):
32
  output = dynamic_range_decompression_torch(magnitudes)
33
  return output
34
 
 
35
  mel_basis = {}
36
  hann_window = {}
37
 
 
38
  def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
39
  if torch.min(y) < -1.:
40
  print('min value is ', torch.min(y))
 
62
 
63
  return spec
64
 
65
+ def get_mel_spectrogram(wav, h):
66
+ return mel_spectrogram(wav, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
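
The new helper simply forwards the hyperparameters from an AttrDict h, which is how app.py now calls it via model.h. A short illustrative usage (the model id and wav path are placeholders):

import librosa
import torch
from bigvgan import BigVGAN
from meldataset import get_mel_spectrogram

model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x")      # placeholder id
wav, sr = librosa.load("example.wav", sr=model.h.sampling_rate, mono=True)   # placeholder path
wav = torch.FloatTensor(wav).unsqueeze(0)   # [1, T], float in [-1, 1]
mel = get_mel_spectrogram(wav, model.h)     # [1, num_mels, T_frames]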
 
models.py DELETED
@@ -1,955 +0,0 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION.
2
- # Licensed under the MIT license.
3
-
4
- # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
- # LICENSE is in incl_licenses directory.
6
-
7
-
8
- import torch
9
- import torch.nn.functional as F
10
- import torch.nn as nn
11
- from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
- from torchaudio.transforms import Spectrogram, Resample
14
- from librosa.filters import mel as librosa_mel_fn
15
- from scipy import signal
16
-
17
- import activations
18
- from utils import init_weights, get_padding
19
- from alias_free_torch.act import Activation1d as TorchActivation1d
20
- import typing
21
- from typing import List, Optional, Tuple
22
- from collections import namedtuple
23
- import math
24
- import functools
25
-
26
-
27
- class AMPBlock1(torch.nn.Module):
28
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
29
- super(AMPBlock1, self).__init__()
30
- self.h = h
31
-
32
- self.convs1 = nn.ModuleList([
33
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
34
- padding=get_padding(kernel_size, dilation[0]))),
35
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
36
- padding=get_padding(kernel_size, dilation[1]))),
37
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
38
- padding=get_padding(kernel_size, dilation[2])))
39
- ])
40
- self.convs1.apply(init_weights)
41
-
42
- self.convs2 = nn.ModuleList([
43
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
44
- padding=get_padding(kernel_size, 1))),
45
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
46
- padding=get_padding(kernel_size, 1))),
47
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
48
- padding=get_padding(kernel_size, 1)))
49
- ])
50
- self.convs2.apply(init_weights)
51
-
52
- self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
53
-
54
- # select which Activation1d, lazy-load cuda version to ensure backward compatibility
55
- if self.h.get("use_cuda_kernel", False):
56
- # faster CUDA kernel implementation of Activation1d
57
- from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
58
- Activation1d = CudaActivation1d
59
- else:
60
- Activation1d = TorchActivation1d
61
-
62
- if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
63
- self.activations = nn.ModuleList([
64
- Activation1d(
65
- activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
66
- for _ in range(self.num_layers)
67
- ])
68
- elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
69
- self.activations = nn.ModuleList([
70
- Activation1d(
71
- activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
72
- for _ in range(self.num_layers)
73
- ])
74
- else:
75
- raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
76
-
77
- def forward(self, x):
78
- acts1, acts2 = self.activations[::2], self.activations[1::2]
79
- for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
80
- xt = a1(x)
81
- xt = c1(xt)
82
- xt = a2(xt)
83
- xt = c2(xt)
84
- x = xt + x
85
-
86
- return x
87
-
88
- def remove_weight_norm(self):
89
- for l in self.convs1:
90
- remove_weight_norm(l)
91
- for l in self.convs2:
92
- remove_weight_norm(l)
93
-
94
-
95
- class AMPBlock2(torch.nn.Module):
96
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
97
- super(AMPBlock2, self).__init__()
98
- self.h = h
99
-
100
- self.convs = nn.ModuleList([
101
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
102
- padding=get_padding(kernel_size, dilation[0]))),
103
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
104
- padding=get_padding(kernel_size, dilation[1])))
105
- ])
106
- self.convs.apply(init_weights)
107
-
108
- self.num_layers = len(self.convs) # total number of conv layers
109
-
110
- # select which Activation1d, lazy-load cuda version to ensure backward compatibility
111
- if self.h.get("use_cuda_kernel", False):
112
- # faster CUDA kernel implementation of Activation1d
113
- from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
114
- Activation1d = CudaActivation1d
115
- else:
116
- Activation1d = TorchActivation1d
117
-
118
- if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
119
- self.activations = nn.ModuleList([
120
- Activation1d(
121
- activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
122
- for _ in range(self.num_layers)
123
- ])
124
- elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
125
- self.activations = nn.ModuleList([
126
- Activation1d(
127
- activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
128
- for _ in range(self.num_layers)
129
- ])
130
- else:
131
- raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
132
-
133
- def forward(self, x):
134
- for c, a in zip (self.convs, self.activations):
135
- xt = a(x)
136
- xt = c(xt)
137
- x = xt + x
138
-
139
- return x
140
-
141
- def remove_weight_norm(self):
142
- for l in self.convs:
143
- remove_weight_norm(l)
144
-
145
-
146
- class BigVGAN(torch.nn.Module):
147
- # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
148
- # New in v2: if use_cuda_kernel is set to True, it loads optimized CUDA kernels for AMP.
149
- # NOTE: use_cuda_kernel=True should be used for inference only (training is not supported).
150
- def __init__(
151
- self,
152
- h,
153
- use_cuda_kernel: bool=False
154
- ):
155
- super(BigVGAN, self).__init__()
156
- self.h = h
157
- self.h["use_cuda_kernel"] = use_cuda_kernel # add it to global hyperparameters (h)
158
-
159
- self.num_kernels = len(h.resblock_kernel_sizes)
160
- self.num_upsamples = len(h.upsample_rates)
161
-
162
- # pre conv
163
- self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
164
-
165
- # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
166
- resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
167
-
168
- # transposed conv-based upsamplers. does not apply anti-aliasing
169
- self.ups = nn.ModuleList()
170
- for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
171
- self.ups.append(nn.ModuleList([
172
- weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
173
- h.upsample_initial_channel // (2 ** (i + 1)),
174
- k, u, padding=(k - u) // 2))
175
- ]))
176
-
177
- # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
178
- self.resblocks = nn.ModuleList()
179
- for i in range(len(self.ups)):
180
- ch = h.upsample_initial_channel // (2 ** (i + 1))
181
- for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
182
- self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
183
-
184
- # select which Activation1d, lazy-load cuda version to ensure backward compatibility
185
- if self.h.get("use_cuda_kernel", False):
186
- # faster CUDA kernel implementation of Activation1d
187
- from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
188
- Activation1d = CudaActivation1d
189
- else:
190
- Activation1d = TorchActivation1d
191
-
192
- # post conv
193
- if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
194
- activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
195
- self.activation_post = Activation1d(activation=activation_post)
196
- elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
197
- activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
198
- self.activation_post = Activation1d(activation=activation_post)
199
- else:
200
- raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
201
-
202
- # whether to use bias for the final conv_post. Defaults to True for backward compatibility
203
- self.use_bias_at_final = h.get("use_bias_at_final", True)
204
- self.conv_post = weight_norm(Conv1d(
205
- ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final
206
- ))
207
-
208
- # weight initialization
209
- for i in range(len(self.ups)):
210
- self.ups[i].apply(init_weights)
211
- self.conv_post.apply(init_weights)
212
-
213
- # final tanh activation. Defaults to True for backward compatibility
214
- self.use_tanh_at_final = h.get("use_tanh_at_final", True)
215
-
216
- def forward(self, x):
217
- # pre conv
218
- x = self.conv_pre(x)
219
-
220
- for i in range(self.num_upsamples):
221
- # upsampling
222
- for i_up in range(len(self.ups[i])):
223
- x = self.ups[i][i_up](x)
224
- # AMP blocks
225
- xs = None
226
- for j in range(self.num_kernels):
227
- if xs is None:
228
- xs = self.resblocks[i * self.num_kernels + j](x)
229
- else:
230
- xs += self.resblocks[i * self.num_kernels + j](x)
231
- x = xs / self.num_kernels
232
-
233
- # post conv
234
- x = self.activation_post(x)
235
- x = self.conv_post(x)
236
- # final tanh activation
237
- if self.use_tanh_at_final:
238
- x = torch.tanh(x)
239
- else:
240
- x = torch.clamp(x, min=-1., max=1.) # bound the output to [-1, 1]
241
-
242
- return x
243
-
244
- def remove_weight_norm(self):
245
- print('Removing weight norm...')
246
- for l in self.ups:
247
- for l_i in l:
248
- remove_weight_norm(l_i)
249
- for l in self.resblocks:
250
- l.remove_weight_norm()
251
- remove_weight_norm(self.conv_pre)
252
- remove_weight_norm(self.conv_post)
253
-
254
-
255
- class DiscriminatorP(torch.nn.Module):
256
- def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False):
257
- super(DiscriminatorP, self).__init__()
258
- self.period = period
259
- self.d_mult = h.discriminator_channel_mult
260
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
261
- self.convs = nn.ModuleList([
262
- norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
263
- norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
264
- norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
265
- norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
266
- norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
267
- ])
268
- self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
269
-
270
- def forward(self, x):
271
- fmap = []
272
-
273
- # 1d to 2d
274
- b, c, t = x.shape
275
- if t % self.period != 0: # pad first
276
- n_pad = self.period - (t % self.period)
277
- x = F.pad(x, (0, n_pad), "reflect")
278
- t = t + n_pad
279
- x = x.view(b, c, t // self.period, self.period)
280
-
281
- for l in self.convs:
282
- x = l(x)
283
- x = F.leaky_relu(x, 0.1)
284
- fmap.append(x)
285
- x = self.conv_post(x)
286
- fmap.append(x)
287
- x = torch.flatten(x, 1, -1)
288
-
289
- return x, fmap
290
-
291
-
292
- class MultiPeriodDiscriminator(torch.nn.Module):
293
- def __init__(self, h):
294
- super(MultiPeriodDiscriminator, self).__init__()
295
- self.mpd_reshapes = h.mpd_reshapes
296
- print("mpd_reshapes: {}".format(self.mpd_reshapes))
297
- discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
298
- self.discriminators = nn.ModuleList(discriminators)
299
-
300
- def forward(self, y, y_hat):
301
- y_d_rs = []
302
- y_d_gs = []
303
- fmap_rs = []
304
- fmap_gs = []
305
- for i, d in enumerate(self.discriminators):
306
- y_d_r, fmap_r = d(y)
307
- y_d_g, fmap_g = d(y_hat)
308
- y_d_rs.append(y_d_r)
309
- fmap_rs.append(fmap_r)
310
- y_d_gs.append(y_d_g)
311
- fmap_gs.append(fmap_g)
312
-
313
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
314
-
315
-
316
- class DiscriminatorR(nn.Module):
317
- def __init__(self, cfg, resolution):
318
- super().__init__()
319
-
320
- self.resolution = resolution
321
- assert len(self.resolution) == 3, \
322
- "MRD layer requires list with len=3, got {}".format(self.resolution)
323
- self.lrelu_slope = 0.1
324
-
325
- norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
326
- if hasattr(cfg, "mrd_use_spectral_norm"):
327
- print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm))
328
- norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
329
- self.d_mult = cfg.discriminator_channel_mult
330
- if hasattr(cfg, "mrd_channel_mult"):
331
- print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult))
332
- self.d_mult = cfg.mrd_channel_mult
333
-
334
- self.convs = nn.ModuleList([
335
- norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))),
336
- norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
337
- norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
338
- norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
339
- norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))),
340
- ])
341
- self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
342
-
343
- def forward(self, x):
344
- fmap = []
345
-
346
- x = self.spectrogram(x)
347
- x = x.unsqueeze(1)
348
- for l in self.convs:
349
- x = l(x)
350
- x = F.leaky_relu(x, self.lrelu_slope)
351
- fmap.append(x)
352
- x = self.conv_post(x)
353
- fmap.append(x)
354
- x = torch.flatten(x, 1, -1)
355
-
356
- return x, fmap
357
-
358
- def spectrogram(self, x):
359
- n_fft, hop_length, win_length = self.resolution
360
- x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
361
- x = x.squeeze(1)
362
- x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
363
- x = torch.view_as_real(x) # [B, F, TT, 2]
364
- mag = torch.norm(x, p=2, dim =-1) #[B, F, TT]
365
-
366
- return mag
367
-
368
-
369
- class MultiResolutionDiscriminator(nn.Module):
370
- def __init__(self, cfg, debug=False):
371
- super().__init__()
372
- self.resolutions = cfg.resolutions
373
- assert len(self.resolutions) == 3,\
374
- "MRD requires list of list with len=3, each element having a list with len=3. got {}".\
375
- format(self.resolutions)
376
- self.discriminators = nn.ModuleList(
377
- [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
378
- )
379
-
380
- def forward(self, y, y_hat):
381
- y_d_rs = []
382
- y_d_gs = []
383
- fmap_rs = []
384
- fmap_gs = []
385
-
386
- for i, d in enumerate(self.discriminators):
387
- y_d_r, fmap_r = d(x=y)
388
- y_d_g, fmap_g = d(x=y_hat)
389
- y_d_rs.append(y_d_r)
390
- fmap_rs.append(fmap_r)
391
- y_d_gs.append(y_d_g)
392
- fmap_gs.append(fmap_g)
393
-
394
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
395
-
396
- # Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
397
- # Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
398
- # LICENSE is in incl_licenses directory.
399
- class DiscriminatorB(nn.Module):
400
- def __init__(
401
- self,
402
- window_length: int,
403
- channels: int = 32,
404
- hop_factor: float = 0.25,
405
- bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
406
- ):
407
- super().__init__()
408
- self.window_length = window_length
409
- self.hop_factor = hop_factor
410
- self.spec_fn = Spectrogram(
411
- n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
412
- )
413
- n_fft = window_length // 2 + 1
414
- bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
415
- self.bands = bands
416
- convs = lambda: nn.ModuleList(
417
- [
418
- weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
419
- weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
420
- weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
421
- weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
422
- weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
423
- ]
424
- )
425
- self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
426
-
427
- self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
428
-
429
- def spectrogram(self, x):
430
- # Remove DC offset
431
- x = x - x.mean(dim=-1, keepdims=True)
432
- # Peak normalize the volume of input audio
433
- x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
434
- x = self.spec_fn(x)
435
- x = torch.view_as_real(x)
436
- x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F]
437
- # Split into bands
438
- x_bands = [x[..., b[0] : b[1]] for b in self.bands]
439
- return x_bands
440
-
441
- def forward(self, x: torch.Tensor):
442
- x_bands = self.spectrogram(x.squeeze(1))
443
- fmap = []
444
- x = []
445
-
446
- for band, stack in zip(x_bands, self.band_convs):
447
- for i, layer in enumerate(stack):
448
- band = layer(band)
449
- band = torch.nn.functional.leaky_relu(band, 0.1)
450
- if i > 0:
451
- fmap.append(band)
452
- x.append(band)
453
-
454
- x = torch.cat(x, dim=-1)
455
- x = self.conv_post(x)
456
- fmap.append(x)
457
-
458
- return x, fmap
459
-
460
- # Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
461
- # Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
462
- # LICENSE is in incl_licenses directory.
463
- class MultiBandDiscriminator(nn.Module):
464
- def __init__(
465
- self,
466
- h,
467
- ):
468
- """
469
- Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
470
- and the modified code adapted from https://github.com/gemelo-ai/vocos.
471
- """
472
- super().__init__()
473
- # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
474
- self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
475
- self.discriminators = nn.ModuleList(
476
- [DiscriminatorB(window_length=w) for w in self.fft_sizes]
477
- )
478
-
479
- def forward(
480
- self,
481
- y: torch.Tensor,
482
- y_hat: torch.Tensor
483
- ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
484
-
485
- y_d_rs = []
486
- y_d_gs = []
487
- fmap_rs = []
488
- fmap_gs = []
489
-
490
- for d in self.discriminators:
491
- y_d_r, fmap_r = d(x=y)
492
- y_d_g, fmap_g = d(x=y_hat)
493
- y_d_rs.append(y_d_r)
494
- fmap_rs.append(fmap_r)
495
- y_d_gs.append(y_d_g)
496
- fmap_gs.append(fmap_g)
497
-
498
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
499
-
500
-
501
- # Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
502
- # LICENSE is in incl_licenses directory.
503
- class DiscriminatorCQT(nn.Module):
504
- def __init__(self, cfg, hop_length, n_octaves, bins_per_octave):
505
- super().__init__()
506
- self.cfg = cfg
507
-
508
- self.filters = cfg["cqtd_filters"]
509
- self.max_filters = cfg["cqtd_max_filters"]
510
- self.filters_scale = cfg["cqtd_filters_scale"]
511
- self.kernel_size = (3, 9)
512
- self.dilations = cfg["cqtd_dilations"]
513
- self.stride = (1, 2)
514
-
515
- self.in_channels = cfg["cqtd_in_channels"]
516
- self.out_channels = cfg["cqtd_out_channels"]
517
- self.fs = cfg["sampling_rate"]
518
- self.hop_length = hop_length
519
- self.n_octaves = n_octaves
520
- self.bins_per_octave = bins_per_octave
521
-
522
- # lazy-load
523
- from nnAudio import features
524
- self.cqt_transform = features.cqt.CQT2010v2(
525
- sr=self.fs * 2,
526
- hop_length=self.hop_length,
527
- n_bins=self.bins_per_octave * self.n_octaves,
528
- bins_per_octave=self.bins_per_octave,
529
- output_format="Complex",
530
- pad_mode="constant",
531
- )
532
-
533
- self.conv_pres = nn.ModuleList()
534
- for i in range(self.n_octaves):
535
- self.conv_pres.append(
536
- nn.Conv2d(
537
- self.in_channels * 2,
538
- self.in_channels * 2,
539
- kernel_size=self.kernel_size,
540
- padding=self.get_2d_padding(self.kernel_size),
541
- )
542
- )
543
-
544
- self.convs = nn.ModuleList()
545
-
546
- self.convs.append(
547
- nn.Conv2d(
548
- self.in_channels * 2,
549
- self.filters,
550
- kernel_size=self.kernel_size,
551
- padding=self.get_2d_padding(self.kernel_size),
552
- )
553
- )
554
-
555
- in_chs = min(self.filters_scale * self.filters, self.max_filters)
556
- for i, dilation in enumerate(self.dilations):
557
- out_chs = min(
558
- (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
559
- )
560
- self.convs.append(
561
- weight_norm(nn.Conv2d(
562
- in_chs,
563
- out_chs,
564
- kernel_size=self.kernel_size,
565
- stride=self.stride,
566
- dilation=(dilation, 1),
567
- padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
568
- ))
569
- )
570
- in_chs = out_chs
571
- out_chs = min(
572
- (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
573
- self.max_filters,
574
- )
575
- self.convs.append(
576
- weight_norm(nn.Conv2d(
577
- in_chs,
578
- out_chs,
579
- kernel_size=(self.kernel_size[0], self.kernel_size[0]),
580
- padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
581
- ))
582
- )
583
-
584
- self.conv_post = weight_norm(nn.Conv2d(
585
- out_chs,
586
- self.out_channels,
587
- kernel_size=(self.kernel_size[0], self.kernel_size[0]),
588
- padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
589
- ))
590
-
591
- self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
592
- self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2)
593
-
594
- self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
595
- if self.cqtd_normalize_volume:
596
- print(f"INFO: cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!")
597
-
598
- def get_2d_padding(
599
- self, kernel_size: typing.Tuple[int, int], dilation: typing.Tuple[int, int] = (1, 1)
600
- ):
601
- return (
602
- ((kernel_size[0] - 1) * dilation[0]) // 2,
603
- ((kernel_size[1] - 1) * dilation[1]) // 2,
604
- )
605
-
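For instance, with the (3, 9) kernel used throughout this class and a dilation of (2, 1), a quick worked check of the padding formula gives a "same"-style pad over the time and CQT-bin axes:

pad_t, pad_f = ((3 - 1) * 2) // 2, ((9 - 1) * 1) // 2  # -> (2, 4)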
606
- def forward(self, x):
607
- fmap = []
608
-
609
- if self.cqtd_normalize_volume:
610
- # Remove DC offset
611
- x = x - x.mean(dim=-1, keepdims=True)
612
- # Peak normalize the volume of input audio
613
- x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
614
-
615
- x = self.resample(x)
616
-
617
- z = self.cqt_transform(x)
618
-
619
- z_amplitude = z[:, :, :, 0].unsqueeze(1)
620
- z_phase = z[:, :, :, 1].unsqueeze(1)
621
-
622
- z = torch.cat([z_amplitude, z_phase], dim=1)
623
- z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W]
624
-
625
- latent_z = []
626
- for i in range(self.n_octaves):
627
- latent_z.append(
628
- self.conv_pres[i](
629
- z[
630
- :,
631
- :,
632
- :,
633
- i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
634
- ]
635
- )
636
- )
637
- latent_z = torch.cat(latent_z, dim=-1)
638
-
639
- for i, l in enumerate(self.convs):
640
- latent_z = l(latent_z)
641
-
642
- latent_z = self.activation(latent_z)
643
- fmap.append(latent_z)
644
-
645
- latent_z = self.conv_post(latent_z)
646
-
647
- return latent_z, fmap
648
-
649
-
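As a rough sketch of what the CQT front-end above feeds into the per-octave conv_pres layers (assuming nnAudio is installed; with output_format="Complex" the last dimension stacks the real and imaginary parts, which the class treats as two input channels):

import torch
from nnAudio import features

fs = 22050                                  # assumed sampling rate
cqt = features.cqt.CQT2010v2(
    sr=fs * 2, hop_length=512, n_bins=9 * 24, bins_per_octave=24,
    output_format="Complex", pad_mode="constant",
)
wav = torch.randn(1, fs)                    # dummy mono batch, [B, T]
z = cqt(wav)                                # [B, n_bins, T_frames, 2]
z = torch.cat([z[..., 0].unsqueeze(1), z[..., 1].unsqueeze(1)], dim=1)  # [B, 2, n_bins, T_frames]
z = z.permute(0, 1, 3, 2)                   # [B, 2, T_frames, n_bins], then sliced per octave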
650
- class MultiScaleSubbandCQTDiscriminator(nn.Module):
651
- def __init__(self, cfg):
652
- super().__init__()
653
-
654
- self.cfg = cfg
655
- # Using get with defaults
656
- self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32)
657
- self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024)
658
- self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1)
659
- self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4])
660
- self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1)
661
- self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1)
662
- # multi-scale params to loop over
663
- self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
664
- self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
665
- self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
666
-
667
- self.discriminators = nn.ModuleList(
668
- [
669
- DiscriminatorCQT(
670
- self.cfg,
671
- hop_length=self.cfg["cqtd_hop_lengths"][i],
672
- n_octaves=self.cfg["cqtd_n_octaves"][i],
673
- bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i],
674
- )
675
- for i in range(len(self.cfg["cqtd_hop_lengths"]))
676
- ]
677
- )
678
-
679
- def forward(
680
- self,
681
- y: torch.Tensor,
682
- y_hat: torch.Tensor
683
- ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
684
-
685
- y_d_rs = []
686
- y_d_gs = []
687
- fmap_rs = []
688
- fmap_gs = []
689
-
690
- for disc in self.discriminators:
691
- y_d_r, fmap_r = disc(y)
692
- y_d_g, fmap_g = disc(y_hat)
693
- y_d_rs.append(y_d_r)
694
- fmap_rs.append(fmap_r)
695
- y_d_gs.append(y_d_g)
696
- fmap_gs.append(fmap_g)
697
-
698
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
699
-
700
-
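A minimal usage sketch for the multi-scale CQT discriminator; only the sampling rate is required in the config, since the remaining cqtd_* keys fall back to the defaults filled in above (the dummy shapes are assumptions):

import torch

cfg = {"sampling_rate": 22050}
cqtd = MultiScaleSubbandCQTDiscriminator(cfg)
y = torch.randn(2, 1, 22050)        # real audio, [B, 1, T]
y_hat = torch.randn(2, 1, 22050)    # generated audio, [B, 1, T]
y_d_rs, y_d_gs, fmap_rs, fmap_gs = cqtd(y, y_hat)
# one logit map and one list of intermediate feature maps per CQT scale (three by default)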
701
- class CombinedDiscriminator(nn.Module):
702
-     # wrapper that chains multiple discriminator architectures
703
-     # e.g. combine MBD and CQTD as a single class
704
- def __init__(
705
- self,
706
- list_discriminator: List[nn.Module]
707
- ):
708
- super().__init__()
709
-         self.discriminators = nn.ModuleList(list_discriminator)
710
-
711
- def forward(
712
- self,
713
- y: torch.Tensor,
714
- y_hat: torch.Tensor
715
- ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
716
-
717
- y_d_rs = []
718
- y_d_gs = []
719
- fmap_rs = []
720
- fmap_gs = []
721
-
722
-         for disc in self.discriminators:
723
- y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat)
724
- y_d_rs.extend(y_d_r)
725
- fmap_rs.extend(fmap_r)
726
- y_d_gs.extend(y_d_g)
727
- fmap_gs.extend(fmap_g)
728
-
729
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
730
-
731
-
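A hedged sketch of how the wrapper might be used, reusing the cfg and dummy audio from the sketch above; OtherDiscriminator is a hypothetical stand-in for whichever additional architecture (e.g. a multi-band discriminator defined earlier in this file) is chained with CQTD:

combined = CombinedDiscriminator([
    MultiScaleSubbandCQTDiscriminator(cfg),
    OtherDiscriminator(cfg),            # hypothetical second architecture
])
y_d_rs, y_d_gs, fmap_rs, fmap_gs = combined(y, y_hat)  # flat lists across all sub-discriminators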
732
- # Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
733
- # LICENSE is in incl_licenses directory.
734
- class MultiScaleMelSpectrogramLoss(nn.Module):
735
- """Compute distance between mel spectrograms. Can be used
736
- in a multi-scale way.
737
-
738
- Parameters
739
- ----------
740
- n_mels : List[int]
741
- Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320],
742
- window_lengths : List[int], optional
743
- Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048]
744
- loss_fn : typing.Callable, optional
745
- How to compare each loss, by default nn.L1Loss()
746
- clamp_eps : float, optional
747
- Clamp on the log magnitude, below, by default 1e-5
748
- mag_weight : float, optional
749
-         Weight of raw magnitude portion of loss, by default 0.0 (the raw-magnitude term is disabled)
750
- log_weight : float, optional
751
- Weight of log magnitude portion of loss, by default 1.0
752
- pow : float, optional
753
- Power to raise magnitude to before taking log, by default 1.0
754
- weight : float, optional
755
- Weight of this loss, by default 1.0
756
- match_stride : bool, optional
757
- Whether to match the stride of convolutional layers, by default False
758
-
759
- Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
760
- Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
761
- """
762
-
763
- def __init__(
764
- self,
765
- sampling_rate: int,
766
- n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
767
- window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
768
- loss_fn: typing.Callable = nn.L1Loss(),
769
- clamp_eps: float = 1e-5,
770
- mag_weight: float = 0.0,
771
- log_weight: float = 1.0,
772
- pow: float = 1.0,
773
- weight: float = 1.0,
774
- match_stride: bool = False,
775
- mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0],
776
- mel_fmax: List[float] = [None, None, None, None, None, None, None],
777
- window_type: str = 'hann',
778
- ):
779
- super().__init__()
780
- self.sampling_rate = sampling_rate
781
-
782
- STFTParams = namedtuple(
783
- "STFTParams",
784
- ["window_length", "hop_length", "window_type", "match_stride"],
785
- )
786
-
787
- self.stft_params = [
788
- STFTParams(
789
- window_length=w,
790
- hop_length=w // 4,
791
- match_stride=match_stride,
792
- window_type=window_type,
793
- )
794
- for w in window_lengths
795
- ]
796
- self.n_mels = n_mels
797
- self.loss_fn = loss_fn
798
- self.clamp_eps = clamp_eps
799
- self.log_weight = log_weight
800
- self.mag_weight = mag_weight
801
- self.weight = weight
802
- self.mel_fmin = mel_fmin
803
- self.mel_fmax = mel_fmax
804
- self.pow = pow
805
-
806
- @staticmethod
807
- @functools.lru_cache(None)
808
- def get_window(
809
-         window_type, window_length,
810
- ):
811
- return signal.get_window(window_type, window_length)
812
-
813
- @staticmethod
814
- @functools.lru_cache(None)
815
- def get_mel_filters(
816
- sr, n_fft, n_mels, fmin, fmax
817
- ):
818
- return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
819
-
820
- def mel_spectrogram(
821
- self, wav, n_mels, fmin, fmax, window_length, hop_length, match_stride, window_type
822
- ):
823
- # mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
824
- # https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
825
- B, C, T = wav.shape
826
-
827
- if match_stride:
828
- assert (
829
- hop_length == window_length // 4
830
- ), "For match_stride, hop must equal n_fft // 4"
831
- right_pad = math.ceil(T / hop_length) * hop_length - T
832
- pad = (window_length - hop_length) // 2
833
- else:
834
- right_pad = 0
835
- pad = 0
836
-
837
- wav = torch.nn.functional.pad(
838
- wav, (pad, pad + right_pad), mode='reflect'
839
- )
840
-
841
- window = self.get_window(window_type, window_length)
842
- window = torch.from_numpy(window).to(wav.device).float()
843
-
844
- stft = torch.stft(
845
- wav.reshape(-1, T),
846
- n_fft=window_length,
847
- hop_length=hop_length,
848
- window=window,
849
- return_complex=True,
850
- center=True,
851
- )
852
- _, nf, nt = stft.shape
853
- stft = stft.reshape(B, C, nf, nt)
854
- if match_stride:
855
- # Drop first two and last two frames, which are added
856
- # because of padding. Now num_frames * hop_length = num_samples.
857
- stft = stft[..., 2:-2]
858
- magnitude = torch.abs(stft)
859
-
860
- nf = magnitude.shape[2]
861
- mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
862
- mel_basis = torch.from_numpy(mel_basis).to(wav.device)
863
- mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
864
- mel_spectrogram = mel_spectrogram.transpose(-1, 2)
865
-
866
- return mel_spectrogram
867
-
868
- def forward(
869
- self,
870
- x: torch.Tensor,
871
- y: torch.Tensor
872
- ) -> torch.Tensor:
873
- """Computes mel loss between an estimate and a reference
874
- signal.
875
-
876
- Parameters
877
- ----------
878
- x : torch.Tensor
879
- Estimate signal
880
- y : torch.Tensor
881
- Reference signal
882
-
883
- Returns
884
- -------
885
- torch.Tensor
886
- Mel loss.
887
- """
888
-
889
- loss = 0.0
890
- for n_mels, fmin, fmax, s in zip(
891
- self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
892
- ):
893
- kwargs = {
894
- "n_mels": n_mels,
895
- "fmin": fmin,
896
- "fmax": fmax,
897
- "window_length": s.window_length,
898
- "hop_length": s.hop_length,
899
- "match_stride": s.match_stride,
900
- "window_type": s.window_type,
901
- }
902
-
903
- x_mels = self.mel_spectrogram(x, **kwargs)
904
- y_mels = self.mel_spectrogram(y, **kwargs)
905
- x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
906
- y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
907
-
908
- loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
909
-             loss += self.mag_weight * self.loss_fn(x_mels, y_mels)  # raw-magnitude term, disabled by the default mag_weight=0.0
910
-
911
- return loss
912
-
913
-
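For reference, a minimal calling sketch for the loss above; mel_spectrogram expects [B, C, T] waveforms, and any overall weighting is left to the training configuration:

import torch

fn_mel = MultiScaleMelSpectrogramLoss(sampling_rate=22050)
y = torch.randn(2, 1, 22050)       # reference waveform, [B, C, T]
y_hat = torch.randn(2, 1, 22050)   # generated waveform, [B, C, T]
loss_mel = fn_mel(y_hat, y)        # scalar tensor summed over all mel scales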
914
- # loss functions
915
- def feature_loss(
916
- fmap_r: List[List[torch.Tensor]],
917
- fmap_g: List[List[torch.Tensor]]
918
- ) -> torch.Tensor:
919
-
920
- loss = 0
921
- for dr, dg in zip(fmap_r, fmap_g):
922
- for rl, gl in zip(dr, dg):
923
- loss += torch.mean(torch.abs(rl - gl))
924
-
925
-     return loss * 2  # this equates to lambda=2.0 for the feature matching loss
926
-
927
- def discriminator_loss(
928
- disc_real_outputs: List[torch.Tensor],
929
- disc_generated_outputs: List[torch.Tensor]
930
- ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
931
-
932
- loss = 0
933
- r_losses = []
934
- g_losses = []
935
- for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
936
- r_loss = torch.mean((1-dr)**2)
937
- g_loss = torch.mean(dg**2)
938
- loss += (r_loss + g_loss)
939
- r_losses.append(r_loss.item())
940
- g_losses.append(g_loss.item())
941
-
942
- return loss, r_losses, g_losses
943
-
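A hedged sketch of a discriminator-side update built from the helpers above; generator, combined_disc, optim_d, y, and mel are assumed to come from the surrounding training loop:

y_hat = generator(mel).detach()                      # block gradients into the generator
y_d_rs, y_d_gs, _, _ = combined_disc(y, y_hat)
loss_disc, r_losses, g_losses = discriminator_loss(y_d_rs, y_d_gs)
optim_d.zero_grad()
loss_disc.backward()
optim_d.step()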
944
- def generator_loss(
945
- disc_outputs: List[torch.Tensor]
946
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
947
-
948
- loss = 0
949
- gen_losses = []
950
- for dg in disc_outputs:
951
- l = torch.mean((1-dg)**2)
952
- gen_losses.append(l)
953
- loss += l
954
-
955
- return loss, gen_losses
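And the matching generator-side update, combining the adversarial, feature-matching, and multi-scale mel terms; lambda_mel, fn_mel, and optim_g are assumptions standing in for the actual training configuration:

y_hat = generator(mel)                               # [B, 1, T]
y_d_rs, y_d_gs, fmap_rs, fmap_gs = combined_disc(y, y_hat)
loss_gen, gen_losses = generator_loss(y_d_gs)
loss_fm = feature_loss(fmap_rs, fmap_gs)             # already carries its x2 weighting
loss_mel = lambda_mel * fn_mel(y_hat, y)             # hypothetical mel-loss weight
loss_total = loss_gen + loss_fm + loss_mel
optim_g.zero_grad()
loss_total.backward()
optim_g.step()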