"""Gradio demo for denoisers."""
import gradio as gr
import numpy as np
import torch
import torchaudio
from denoisers import UNet1DModel, WaveUNetModel
from tqdm import tqdm

MODELS = [
    "wrice/unet1d-vctk-48khz",
    "wrice/waveunet-vctk-48khz",
    "wrice/waveunet-vctk-24khz",
]
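# Each entry above is a Hugging Face Hub repo id. The 48khz/24khz suffix
# presumably matches the checkpoint's model.config.sample_rate, which the
# function below uses for resampling and for writing the output file.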


def denoise(model_name: str, audio_path: str):
    """Denoise the audio file at `audio_path` with the selected model and return the output path."""
    # Choose the architecture from the repo name; both classes expose from_pretrained.
    if "unet1d" in model_name:
        model = UNet1DModel.from_pretrained(model_name)
    else:
        model = WaveUNetModel.from_pretrained(model_name)

    if torch.cuda.is_available():
        model = model.cuda()

    # Read the input as mono, fixed-size chunks resampled to the model's rate.
    stream_reader = torchaudio.io.StreamReader(audio_path)
    stream_reader.add_basic_audio_stream(
        frames_per_chunk=model.config.max_length,
        sample_rate=model.config.sample_rate,
        num_channels=1,
    )
    stream_writer = torchaudio.io.StreamWriter("denoised.wav")
    stream_writer.add_audio_stream(sample_rate=model.config.sample_rate, num_channels=1)

    chunk_size = model.config.max_length
    with stream_writer.open():
        for (audio_chunk,) in tqdm(stream_reader.stream()):
            if audio_chunk is None:
                break

            # (frames, channels) -> (channels, frames) for the model.
            audio_chunk = audio_chunk.permute(1, 0)
            original_chunk_size = audio_chunk.size(-1)

            # Zero-pad the final (short) chunk up to the model's expected length.
            if audio_chunk.size(-1) < chunk_size:
                padding = chunk_size - audio_chunk.size(-1)
                audio_chunk = torch.nn.functional.pad(audio_chunk, (0, padding))

            if torch.cuda.is_available():
                audio_chunk = audio_chunk.cuda()

            with torch.no_grad():
                denoised_chunk = model(audio_chunk[None]).audio

            # Trim any padding, then write back as (frames, channels) on CPU.
            denoised_chunk = denoised_chunk[:, :, :original_chunk_size]
            stream_writer.write_audio_chunk(
                0, denoised_chunk.squeeze(0).permute(1, 0).cpu()
            )

    return "denoised.wav"

iface = gr.Interface(
    fn=denoise,
    inputs=[gr.Dropdown(choices=MODELS, value=MODELS[0]), gr.Audio(type="filepath")],
    outputs=gr.Audio(type="filepath"),
)
iface.launch()
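
# launch() starts a local Gradio server and blocks the script; on Hugging Face
# Spaces the same call serves the hosted demo.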