|
import gradio as gr |
|
import spaces |
|
import numpy as np |
|
import torch |
|
from fastgeco.model import ScoreModel |
|
from geco.util.other import pad_spec |
|
import os |
|
import torchaudio |
|
from speechbrain.lobes.models.dual_path import Encoder, SBTransformerBlock, SBTransformerBlock, Dual_Path_Model, Decoder |
|
|
|
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Rate (Hz) that input audio is resampled to before separation.
sample_rate = 8000
# Number of sources the separator produces.
num_spks = 2

# Directory containing encoder/masknet/decoder/fastgeco checkpoint files.
ckpt_path = "ckpts/"
|
|
|
def _sepformer_transformer_block():
    """Return one SBTransformerBlock with the SepFormer intra/inter settings.

    The intra- and inter-chunk models use identical hyper-parameters, so a
    single factory replaces two copy-pasted constructor calls.
    """
    return SBTransformerBlock(
        num_layers=8,
        d_model=256,
        nhead=8,
        d_ffn=1024,
        dropout=0,
        use_positional_encoding=True,
        norm_before=True,
    )


def _load_state(module, ckpt_path, filename):
    """Load the state dict stored at ``<ckpt_path>/<filename>`` into *module*.

    ``map_location="cpu"`` lets checkpoints that were saved from a GPU be
    deserialized on a CPU-only host (the ``device`` fallback); the caller
    moves the module to ``device`` afterwards.

    Returns:
        The same *module*, for chaining.
    """
    state = torch.load(os.path.join(ckpt_path, filename), map_location="cpu")
    module.load_state_dict(state)
    return module


def load_sepformer(ckpt_path):
    """Build the SepFormer encoder, mask network and decoder and load weights.

    Args:
        ckpt_path: Directory containing ``encoder.ckpt``, ``masknet.ckpt``
            and ``decoder.ckpt`` state dicts.

    Returns:
        ``(encoder, masknet, decoder)``, each in eval mode and on ``device``.
    """
    encoder = Encoder(
        kernel_size=160,
        out_channels=256,
        in_channels=1,
    )
    masknet = Dual_Path_Model(
        num_spks=num_spks,
        in_channels=256,
        out_channels=256,
        num_layers=2,
        K=250,
        intra_model=_sepformer_transformer_block(),
        inter_model=_sepformer_transformer_block(),
        norm='ln',
        linear_layer_after_inter_intra=False,
        skip_around_intra=True,
    )
    # Kernel size matches the encoder's kernel; stride is half of it.
    decoder = Decoder(
        in_channels=256,
        out_channels=1,
        kernel_size=160,
        stride=80,
        bias=False,
    )

    encoder = _load_state(encoder, ckpt_path, 'encoder.ckpt').eval().to(device)
    masknet = _load_state(masknet, ckpt_path, 'masknet.ckpt').eval().to(device)
    decoder = _load_state(decoder, ckpt_path, 'decoder.ckpt').eval().to(device)
    return encoder, masknet, decoder
|
|
|
def load_fastgeco(ckpt_path):
    """Load the Fast-GeCo ScoreModel from ``<ckpt_path>/fastgeco.ckpt``.

    Args:
        ckpt_path: Directory that contains ``fastgeco.ckpt``.

    Returns:
        The ScoreModel after ``eval(no_ema=False)``, moved to ``device``.
    """
    model = ScoreModel.load_from_checkpoint(
        os.path.join(ckpt_path, 'fastgeco.ckpt'),
        batch_size=1,
        num_workers=0,
        kwargs={'gpu': False},
    )
    # NOTE(review): `no_ema` is a project-specific eval() flag; presumably
    # False keeps the EMA weights active — confirm in fastgeco.model.
    model.eval(no_ema=False)
    model.to(device)
    return model
|
|
|
# Load every model once at import time so each request reuses them.
encoder, masknet, decoder = load_sepformer(ckpt_path=ckpt_path)
fastgeco_model = load_fastgeco(ckpt_path=ckpt_path)
|
|
|
|
|
@spaces.GPU
def separate(test_file, encoder, masknet, decoder):
    """Separate an audio file into ``num_spks`` sources with SepFormer.

    The file is downmixed to mono and, if needed, resampled to
    ``sample_rate`` before being passed through encoder -> masknet -> decoder.

    Args:
        test_file: Path to an audio file readable by ``torchaudio.load``.
        encoder, masknet, decoder: SepFormer modules from ``load_sepformer``.

    Returns:
        ``(est_sources, mix)``. ``est_sources`` stacks the decoded sources on
        the last axis, each peak-normalized over time; ``correct`` indexes it
        as ``est_sources[:, idx]`` (presumably shape (time, num_spks) — TODO
        confirm against speechbrain's Decoder). ``mix`` is the mono mixture
        at ``sample_rate`` on ``device``.
    """
    with torch.no_grad():
        print('Process SepFormer...')
        mix, fs_file = torchaudio.load(test_file)
        mix = mix.to(device)
        fs_model = sample_rate

        # Always downmix to mono: the encoder was built with in_channels=1,
        # and leftover channels would be treated as extra batch items even
        # when no resampling is needed.
        mix = mix.mean(dim=0, keepdim=True)

        if fs_file != fs_model:
            print(
                "Resampling the audio from {} Hz to {} Hz".format(
                    fs_file, fs_model
                )
            )
            tf = torchaudio.transforms.Resample(
                orig_freq=fs_file, new_freq=fs_model
            ).to(device)
            mix = tf(mix)

        mix_w = encoder(mix)
        est_mask = masknet(mix_w)
        # One copy of the encoded mixture per speaker, masked individually.
        mix_w = torch.stack([mix_w] * num_spks)
        sep_h = mix_w * est_mask

        est_sources = torch.cat(
            [decoder(sep_h[i]).unsqueeze(-1) for i in range(num_spks)],
            dim=-1,
        )
        # Peak-normalize each source over the time axis (dim=1), then drop
        # the size-1 batch axis only (a bare squeeze() could also drop time).
        est_sources = (
            est_sources / est_sources.abs().max(dim=1, keepdim=True)[0]
        ).squeeze(0)

    return est_sources, mix
|
|
|
|
|
def _fastgeco_one_step(model, y, m):
    """Refine one separated source with a single Fast-GeCo reverse step.

    Args:
        model: The Fast-GeCo ScoreModel.
        y: One separated source; ``correct`` passes ``est_sources[:, idx]``
            with a leading axis added (presumably shape (1, T) — confirm).
        m: The mixture returned by ``separate``.

    Returns:
        The corrected waveform as a NumPy array, peak-normalized.
    """
    steps = 1
    start_t = 0.5

    # Trim source and mixture to a common length.
    length = min(y.shape[-1], m.shape[-1])
    y = y[..., :length]
    m = m[..., :length]
    orig_len = y.size(1)

    # Scale both signals by the source's peak.
    scale = y.abs().max()
    y = y / scale
    m = m / scale

    Y = pad_spec(torch.unsqueeze(model._forward_transform(model._stft(y.to(device))), 0))
    M = pad_spec(torch.unsqueeze(model._forward_transform(model._stft(m.to(device))), 0))

    timesteps = torch.linspace(start_t, 0.03, steps, device=Y.device)
    std = model.sde._std(start_t * torch.ones((Y.shape[0],), device=Y.device))
    X_t = Y + torch.randn_like(Y) * std[:, None, None, None]

    t = timesteps[0]
    dt = timesteps[-1]
    f, g = model.sde.sde(X_t, t, Y)
    vec_t = torch.ones(Y.shape[0], device=Y.device) * t
    score = model.forward(X_t, vec_t, Y, M, vec_t[:, None, None, None])
    sample = (X_t - (f - g**2 * score) * dt).squeeze()

    # Undo the input scaling, then re-normalize the output to unit peak.
    x_hat = model.to_audio(sample, orig_len) * scale
    x_hat = x_hat / x_hat.abs().max()
    return x_hat.squeeze().cpu().numpy()


@spaces.GPU
def correct(model, est_sources, mix):
    """Apply Fast-GeCo generative correction to every separated source.

    Args:
        model: The Fast-GeCo ScoreModel from ``load_fastgeco``.
        est_sources: Separated sources from ``separate``; column ``idx``
            holds speaker ``idx``.
        mix: The mixture from ``separate``.

    Returns:
        Two ``(sample_rate, waveform)`` tuples, one per speaker, in the
        format accepted by ``gr.Audio(type="numpy")``.
    """
    with torch.no_grad():
        print('Process Fast-Geco...')
        output = [
            _fastgeco_one_step(model, est_sources[:, idx].unsqueeze(0), mix)
            for idx in range(num_spks)
        ]
    return (sample_rate, output[0]), (sample_rate, output[1])
|
|
|
@spaces.GPU
def process_audio(test_file):
    """Gradio handler: separate *test_file* and correct both sources.

    Args:
        test_file: Filepath from the ``gr.Audio(type="filepath")`` input, or
            ``None`` when the component is empty.

    Returns:
        Two ``(sample_rate, waveform)`` tuples for the two output players.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if test_file is None:
        # Show a readable message in the UI instead of a torchaudio traceback.
        raise gr.Error("Please upload or select an audio file first.")
    result, mix = separate(test_file, encoder, masknet, decoder)
    audio1, audio2 = correct(fastgeco_model, result, mix)
    return audio1, audio2
|
|
|
|
|
|
|
# (dropdown label, mixture file path) pairs offered as demo inputs.
demo_audio_files = [
    ("Demo Audio 1", "demo/item0_mix.wav"), ("Demo Audio 2", "demo/item1_mix.wav"),
    ("Demo Audio 3", "demo/item2_mix.wav"), ("Demo Audio 4", "demo/item3_mix.wav"),
    ("Demo Audio 5", "demo/item4_mix.wav"),
]
|
|
|
def update_audio_input(choice):
    """Identity pass-through: hand *choice* back unchanged."""
    selected = choice
    return selected
|
|
|
|
|
# Center the main column and cap its width on wide screens.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
"""
|
|
|
|
|
def _demo_path(choice):
    """Return the file path of the demo entry labelled *choice*.

    Returns ``None`` (which clears the audio input) for an unknown or empty
    choice instead of leaking ``StopIteration`` out of the event handler.
    """
    return next((path for name, path in demo_audio_files if name == choice), None)


with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # Fast-GeCo: Noise-robust Speech Separation with Fast Generative Correction
        Separate the noisy mixture speech with a generative correction method, only support 2 speakers now.

        Learn more about 🟣**Fast-GeCo** on the [Fast-GeCo Repo](https://github.com/WangHelin1997/Fast-GeCo/).
        """)

        with gr.Tab("Speech Separation"):
            with gr.Row():
                # Defaults come from the first demo entry so the audio input
                # and the dropdown cannot drift apart.
                gt_file_input = gr.Audio(
                    label="Upload Audio to Separate",
                    type="filepath",
                    value=demo_audio_files[0][1],
                )

                demo_selector = gr.Dropdown(
                    label="Select Demo Audio",
                    choices=[name for name, _ in demo_audio_files],
                    value=demo_audio_files[0][0],
                )
                button = gr.Button("Generate", scale=1)

            with gr.Row():
                result1 = gr.Audio(label="Separated Audio 1", type="numpy")
                result2 = gr.Audio(label="Separated Audio 2", type="numpy")

            # Picking a demo loads its file into the audio input.
            demo_selector.change(
                fn=_demo_path,
                inputs=demo_selector,
                outputs=gt_file_input,
            )

            button.click(
                fn=process_audio,
                inputs=[gt_file_input],
                outputs=[result1, result2],
            )

demo.launch()