Improve the Gradio UI demo (thanks @blaisewf on GitHub)
Browse files
app.py
CHANGED
@@ -1,11 +1,12 @@
|
|
|
|
|
|
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
-
|
4 |
-
|
5 |
-
import json
|
6 |
import torch
|
7 |
import os
|
8 |
-
|
9 |
from meldataset import get_mel_spectrogram, MAX_WAV_VALUE
|
10 |
from bigvgan import BigVGAN
|
11 |
import librosa
|
@@ -14,22 +15,14 @@ from utils import plot_spectrogram
|
|
14 |
import PIL
|
15 |
|
16 |
if torch.cuda.is_available():
|
17 |
-
device = torch.device(
|
18 |
torch.backends.cudnn.benchmark = False
|
19 |
print(f"using GPU")
|
20 |
else:
|
21 |
-
device = torch.device(
|
22 |
print(f"using CPU")
|
23 |
|
24 |
|
25 |
-
def load_checkpoint(filepath):
|
26 |
-
assert os.path.isfile(filepath)
|
27 |
-
print("Loading '{}'".format(filepath))
|
28 |
-
checkpoint_dict = torch.load(filepath, map_location='cpu')
|
29 |
-
print("Complete.")
|
30 |
-
return checkpoint_dict
|
31 |
-
|
32 |
-
|
33 |
def inference_gradio(input, model_choice): # input is audio waveform in [T, channel]
|
34 |
sr, audio = input # unpack input to sampling rate and audio itself
|
35 |
audio = np.transpose(audio) # transpose to [channel, T] for librosa
|
@@ -49,17 +42,11 @@ def inference_gradio(input, model_choice): # input is audio waveform in [T, cha
|
|
49 |
|
50 |
spec_plot_gen = plot_spectrogram(spec_gen)
|
51 |
|
52 |
-
output_audio = (model.h.sampling_rate, output)
|
53 |
|
54 |
buffer = spec_plot_gen.canvas.buffer_rgba()
|
55 |
output_image = PIL.Image.frombuffer(
|
56 |
-
"RGBA",
|
57 |
-
spec_plot_gen.canvas.get_width_height(),
|
58 |
-
buffer,
|
59 |
-
"raw",
|
60 |
-
"RGBA",
|
61 |
-
0,
|
62 |
-
1
|
63 |
)
|
64 |
|
65 |
return output_audio, output_image
|
@@ -228,7 +215,7 @@ css = """
|
|
228 |
}
|
229 |
"""
|
230 |
|
231 |
-
|
232 |
|
233 |
LIST_MODEL_ID = [
|
234 |
"bigvgan_24khz_100band",
|
@@ -239,7 +226,7 @@ LIST_MODEL_ID = [
|
|
239 |
"bigvgan_v2_22khz_80band_fmax8k_256x",
|
240 |
"bigvgan_v2_24khz_100band_256x",
|
241 |
"bigvgan_v2_44khz_128band_256x",
|
242 |
-
"bigvgan_v2_44khz_128band_512x"
|
243 |
]
|
244 |
|
245 |
dict_model = {}
|
@@ -247,16 +234,16 @@ dict_config = {}
|
|
247 |
|
248 |
for model_name in LIST_MODEL_ID:
|
249 |
|
250 |
-
generator = BigVGAN.from_pretrained(
|
251 |
-
generator.eval()
|
252 |
generator.remove_weight_norm()
|
|
|
253 |
|
254 |
dict_model[model_name] = generator
|
255 |
dict_config[model_name] = generator.h
|
256 |
|
257 |
-
|
258 |
|
259 |
-
iface = gr.Blocks(css=css)
|
260 |
|
261 |
with iface:
|
262 |
gr.HTML(
|
@@ -267,10 +254,10 @@ with iface:
|
|
267 |
display: inline-flex;
|
268 |
align-items: center;
|
269 |
gap: 0.8rem;
|
270 |
-
font-size: 1.
|
271 |
"
|
272 |
>
|
273 |
-
<h1 style="font-weight:
|
274 |
BigVGAN: A Universal Neural Vocoder with Large-Scale Training
|
275 |
</h1>
|
276 |
</div>
|
@@ -299,14 +286,15 @@ with iface:
|
|
299 |
<div>
|
300 |
<h3>Model Overview</h3>
|
301 |
BigVGAN is a universal neural vocoder model that generates audio waveforms using mel spectrogram as inputs.
|
302 |
-
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800" style="margin-top: 20px;"></center>
|
303 |
</div>
|
304 |
"""
|
305 |
)
|
|
|
306 |
|
307 |
-
with gr.Group():
|
308 |
model_choice = gr.Dropdown(
|
309 |
-
label="Select the model
|
|
|
310 |
value="bigvgan_v2_24khz_100band_256x",
|
311 |
choices=[m for m in LIST_MODEL_ID],
|
312 |
interactive=True,
|
@@ -316,143 +304,129 @@ with iface:
|
|
316 |
label="Input Audio", elem_id="input-audio", interactive=True
|
317 |
)
|
318 |
|
319 |
-
|
320 |
|
321 |
-
|
322 |
-
|
|
|
|
|
|
|
|
|
323 |
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
|
331 |
-
|
|
|
332 |
[
|
333 |
-
|
334 |
-
|
335 |
-
[os.path.join(os.path.dirname(__file__), "examples/queen_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
|
336 |
-
[os.path.join(os.path.dirname(__file__), "examples/dance_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
|
337 |
-
[os.path.join(os.path.dirname(__file__), "examples/megalovania_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
|
338 |
-
[os.path.join(os.path.dirname(__file__), "examples/hifitts_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
|
339 |
-
[os.path.join(os.path.dirname(__file__), "examples/musdbhq_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
|
340 |
-
[os.path.join(os.path.dirname(__file__), "examples/musiccaps1_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
|
341 |
-
[os.path.join(os.path.dirname(__file__), "examples/musiccaps2_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
|
342 |
],
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
<td>Large-scale Compilation</td>
|
412 |
-
<td>No</td>
|
413 |
-
</tr>
|
414 |
-
<tr>
|
415 |
-
<td><a href="https://huggingface.co/nvidia/bigvgan_24khz_100band">bigvgan_24khz_100band</a></td>
|
416 |
-
<td>24 kHz</td>
|
417 |
-
<td>100</td>
|
418 |
-
<td>12000</td>
|
419 |
-
<td>256</td>
|
420 |
-
<td>112M</td>
|
421 |
-
<td>LibriTTS</td>
|
422 |
-
<td>No</td>
|
423 |
-
</tr>
|
424 |
-
<tr>
|
425 |
-
<td><a href="https://huggingface.co/nvidia/bigvgan_base_24khz_100band">bigvgan_base_24khz_100band</a></td>
|
426 |
-
<td>24 kHz</td>
|
427 |
-
<td>100</td>
|
428 |
-
<td>12000</td>
|
429 |
-
<td>256</td>
|
430 |
-
<td>14M</td>
|
431 |
-
<td>LibriTTS</td>
|
432 |
-
<td>No</td>
|
433 |
-
</tr>
|
434 |
-
<tr>
|
435 |
-
<td><a href="https://huggingface.co/nvidia/bigvgan_22khz_80band">bigvgan_22khz_80band</a></td>
|
436 |
-
<td>22 kHz</td>
|
437 |
-
<td>80</td>
|
438 |
-
<td>8000</td>
|
439 |
-
<td>256</td>
|
440 |
-
<td>112M</td>
|
441 |
-
<td>LibriTTS + VCTK + LJSpeech</td>
|
442 |
-
<td>No</td>
|
443 |
-
</tr>
|
444 |
-
<tr>
|
445 |
-
<td><a href="https://huggingface.co/nvidia/bigvgan_base_22khz_80band">bigvgan_base_22khz_80band</a></td>
|
446 |
-
<td>22 kHz</td>
|
447 |
-
<td>80</td>
|
448 |
-
<td>8000</td>
|
449 |
-
<td>256</td>
|
450 |
-
<td>14M</td>
|
451 |
-
<td>LibriTTS + VCTK + LJSpeech</td>
|
452 |
-
<td>No</td>
|
453 |
-
</tr>
|
454 |
-
</tbody>
|
455 |
-
</table>
|
456 |
<p><b>NOTE: The v1 models are trained using speech audio datasets ONLY! (24kHz models: LibriTTS, 22kHz models: LibriTTS + VCTK + LJSpeech).</b></p>
|
457 |
</div>
|
458 |
"""
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
import spaces
|
5 |
import gradio as gr
|
6 |
+
import pandas as pd
|
|
|
|
|
7 |
import torch
|
8 |
import os
|
9 |
+
|
10 |
from meldataset import get_mel_spectrogram, MAX_WAV_VALUE
|
11 |
from bigvgan import BigVGAN
|
12 |
import librosa
|
|
|
15 |
import PIL
|
16 |
|
17 |
if torch.cuda.is_available():
|
18 |
+
device = torch.device("cuda")
|
19 |
torch.backends.cudnn.benchmark = False
|
20 |
print(f"using GPU")
|
21 |
else:
|
22 |
+
device = torch.device("cpu")
|
23 |
print(f"using CPU")
|
24 |
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
def inference_gradio(input, model_choice): # input is audio waveform in [T, channel]
|
27 |
sr, audio = input # unpack input to sampling rate and audio itself
|
28 |
audio = np.transpose(audio) # transpose to [channel, T] for librosa
|
|
|
42 |
|
43 |
spec_plot_gen = plot_spectrogram(spec_gen)
|
44 |
|
45 |
+
output_audio = (model.h.sampling_rate, output) # tuple for gr.Audio output
|
46 |
|
47 |
buffer = spec_plot_gen.canvas.buffer_rgba()
|
48 |
output_image = PIL.Image.frombuffer(
|
49 |
+
"RGBA", spec_plot_gen.canvas.get_width_height(), buffer, "raw", "RGBA", 0, 1
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
)
|
51 |
|
52 |
return output_audio, output_image
|
|
|
215 |
}
|
216 |
"""
|
217 |
|
218 |
+
# Script for loading the models
|
219 |
|
220 |
LIST_MODEL_ID = [
|
221 |
"bigvgan_24khz_100band",
|
|
|
226 |
"bigvgan_v2_22khz_80band_fmax8k_256x",
|
227 |
"bigvgan_v2_24khz_100band_256x",
|
228 |
"bigvgan_v2_44khz_128band_256x",
|
229 |
+
"bigvgan_v2_44khz_128band_512x",
|
230 |
]
|
231 |
|
232 |
dict_model = {}
|
|
|
234 |
|
235 |
for model_name in LIST_MODEL_ID:
|
236 |
|
237 |
+
generator = BigVGAN.from_pretrained("nvidia/" + model_name)
|
|
|
238 |
generator.remove_weight_norm()
|
239 |
+
generator.eval()
|
240 |
|
241 |
dict_model[model_name] = generator
|
242 |
dict_config[model_name] = generator.h
|
243 |
|
244 |
+
# Script for Gradio UI
|
245 |
|
246 |
+
iface = gr.Blocks(css=css, title="BigVGAN - Demo")
|
247 |
|
248 |
with iface:
|
249 |
gr.HTML(
|
|
|
254 |
display: inline-flex;
|
255 |
align-items: center;
|
256 |
gap: 0.8rem;
|
257 |
+
font-size: 1.5rem;
|
258 |
"
|
259 |
>
|
260 |
+
<h1 style="font-weight: 700; margin-bottom: 7px; line-height: normal;">
|
261 |
BigVGAN: A Universal Neural Vocoder with Large-Scale Training
|
262 |
</h1>
|
263 |
</div>
|
|
|
286 |
<div>
|
287 |
<h3>Model Overview</h3>
|
288 |
BigVGAN is a universal neural vocoder model that generates audio waveforms using mel spectrogram as inputs.
|
289 |
+
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800" style="margin-top: 20px; border-radius: 15px;"></center>
|
290 |
</div>
|
291 |
"""
|
292 |
)
|
293 |
+
with gr.Accordion("Input"):
|
294 |
|
|
|
295 |
model_choice = gr.Dropdown(
|
296 |
+
label="Select the model to use",
|
297 |
+
info="The default model is bigvgan_v2_24khz_100band_256x",
|
298 |
value="bigvgan_v2_24khz_100band_256x",
|
299 |
choices=[m for m in LIST_MODEL_ID],
|
300 |
interactive=True,
|
|
|
304 |
label="Input Audio", elem_id="input-audio", interactive=True
|
305 |
)
|
306 |
|
307 |
+
button = gr.Button("Submit")
|
308 |
|
309 |
+
with gr.Accordion("Output"):
|
310 |
+
with gr.Column():
|
311 |
+
output_audio = gr.Audio(label="Output Audio", elem_id="output-audio")
|
312 |
+
output_image = gr.Image(
|
313 |
+
label="Output Mel Spectrogram", elem_id="output-image-gen"
|
314 |
+
)
|
315 |
|
316 |
+
button.click(
|
317 |
+
inference_gradio,
|
318 |
+
inputs=[audio_input, model_choice],
|
319 |
+
outputs=[output_audio, output_image],
|
320 |
+
concurrency_limit=10,
|
321 |
+
)
|
322 |
|
323 |
+
gr.Examples(
|
324 |
+
[
|
325 |
[
|
326 |
+
os.path.join(os.path.dirname(__file__), "examples/jensen_24k.wav"),
|
327 |
+
"bigvgan_v2_24khz_100band_256x",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
],
|
329 |
+
[
|
330 |
+
os.path.join(os.path.dirname(__file__), "examples/libritts_24k.wav"),
|
331 |
+
"bigvgan_v2_24khz_100band_256x",
|
332 |
+
],
|
333 |
+
[
|
334 |
+
os.path.join(os.path.dirname(__file__), "examples/queen_24k.wav"),
|
335 |
+
"bigvgan_v2_24khz_100band_256x",
|
336 |
+
],
|
337 |
+
[
|
338 |
+
os.path.join(os.path.dirname(__file__), "examples/dance_24k.wav"),
|
339 |
+
"bigvgan_v2_24khz_100band_256x",
|
340 |
+
],
|
341 |
+
[
|
342 |
+
os.path.join(os.path.dirname(__file__), "examples/megalovania_24k.wav"),
|
343 |
+
"bigvgan_v2_24khz_100band_256x",
|
344 |
+
],
|
345 |
+
[
|
346 |
+
os.path.join(os.path.dirname(__file__), "examples/hifitts_44k.wav"),
|
347 |
+
"bigvgan_v2_44khz_128band_256x",
|
348 |
+
],
|
349 |
+
[
|
350 |
+
os.path.join(os.path.dirname(__file__), "examples/musdbhq_44k.wav"),
|
351 |
+
"bigvgan_v2_44khz_128band_256x",
|
352 |
+
],
|
353 |
+
[
|
354 |
+
os.path.join(os.path.dirname(__file__), "examples/musiccaps1_44k.wav"),
|
355 |
+
"bigvgan_v2_44khz_128band_256x",
|
356 |
+
],
|
357 |
+
[
|
358 |
+
os.path.join(os.path.dirname(__file__), "examples/musiccaps2_44k.wav"),
|
359 |
+
"bigvgan_v2_44khz_128band_256x",
|
360 |
+
],
|
361 |
+
],
|
362 |
+
fn=inference_gradio,
|
363 |
+
inputs=[audio_input, model_choice],
|
364 |
+
outputs=[output_audio, output_image],
|
365 |
+
)
|
366 |
|
367 |
+
# Define the data for the table
|
368 |
+
data = {
|
369 |
+
"Model Name": [
|
370 |
+
"bigvgan_v2_44khz_128band_512x",
|
371 |
+
"bigvgan_v2_44khz_128band_256x",
|
372 |
+
"bigvgan_v2_24khz_100band_256x",
|
373 |
+
"bigvgan_v2_22khz_80band_256x",
|
374 |
+
"bigvgan_v2_22khz_80band_fmax8k_256x",
|
375 |
+
"bigvgan_24khz_100band",
|
376 |
+
"bigvgan_base_24khz_100band",
|
377 |
+
"bigvgan_22khz_80band",
|
378 |
+
"bigvgan_base_22khz_80band",
|
379 |
+
],
|
380 |
+
"Sampling Rate": [
|
381 |
+
"44 kHz",
|
382 |
+
"44 kHz",
|
383 |
+
"24 kHz",
|
384 |
+
"22 kHz",
|
385 |
+
"22 kHz",
|
386 |
+
"24 kHz",
|
387 |
+
"24 kHz",
|
388 |
+
"22 kHz",
|
389 |
+
"22 kHz",
|
390 |
+
],
|
391 |
+
"Mel band": [128, 128, 100, 80, 80, 100, 100, 80, 80],
|
392 |
+
"fmax": [22050, 22050, 12000, 11025, 8000, 12000, 12000, 8000, 8000],
|
393 |
+
"Upsampling Ratio": [512, 256, 256, 256, 256, 256, 256, 256, 256],
|
394 |
+
"Parameters": [
|
395 |
+
"122M",
|
396 |
+
"112M",
|
397 |
+
"112M",
|
398 |
+
"112M",
|
399 |
+
"112M",
|
400 |
+
"112M",
|
401 |
+
"14M",
|
402 |
+
"112M",
|
403 |
+
"14M",
|
404 |
+
],
|
405 |
+
"Dataset": [
|
406 |
+
"Large-scale Compilation",
|
407 |
+
"Large-scale Compilation",
|
408 |
+
"Large-scale Compilation",
|
409 |
+
"Large-scale Compilation",
|
410 |
+
"Large-scale Compilation",
|
411 |
+
"LibriTTS",
|
412 |
+
"LibriTTS",
|
413 |
+
"LibriTTS + VCTK + LJSpeech",
|
414 |
+
"LibriTTS + VCTK + LJSpeech",
|
415 |
+
],
|
416 |
+
"Fine-Tuned": ["No", "No", "No", "No", "No", "No", "No", "No", "No"],
|
417 |
+
}
|
418 |
+
|
419 |
+
base_url = "https://huggingface.co/nvidia/"
|
420 |
+
|
421 |
+
df = pd.DataFrame(data)
|
422 |
+
df["Model Name"] = df["Model Name"].apply(
|
423 |
+
lambda x: f'<a href="{base_url}{x}">{x}</a>'
|
424 |
+
)
|
425 |
+
|
426 |
+
html_table = gr.HTML(
|
427 |
+
f"""
|
428 |
+
<div style="text-align: center;">
|
429 |
+
{df.to_html(index=False, escape=False, classes='border="1" cellspacing="0" cellpadding="5" style="margin-left: auto; margin-right: auto;')}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
<p><b>NOTE: The v1 models are trained using speech audio datasets ONLY! (24kHz models: LibriTTS, 22kHz models: LibriTTS + VCTK + LJSpeech).</b></p>
|
431 |
</div>
|
432 |
"""
|