Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
import matplotlib.pyplot as plt | |
from pesq import pesq | |
from pystoi import stoi | |
import mir_eval | |
# Index into the leading axis of the STFT tensors that get_irms uses to pick the
# reference channel whose masks are returned.
REFERENCE_CHANNEL: int = 0
def plot_spectrogram(stft, title="Spectrogram", xlim=None, show=False):
    """Convert an STFT to a dB-scaled magnitude spectrogram and optionally plot it.

    Args:
        stft: STFT tensor; its magnitude is converted to decibels. ``imshow``
            needs a 2-D array when ``show`` is True.
        title: Figure title (used only when ``show`` is True).
        xlim: Optional ``(left, right)`` x-axis limits (used only when ``show``
            is True); ignored when None.
        show: When True, display the spectrogram with matplotlib. Defaults to
            False, matching the previous behaviour where the plotting code was
            disabled.

    Returns:
        NumPy array ``20 * log10(|stft| + 1e-8)`` with the same shape as ``stft``.
    """
    magnitude = stft.abs()
    # detach()/cpu() so tensors that require grad or live on a GPU can be converted.
    spectrogram = (20 * torch.log10(magnitude + 1e-8)).detach().cpu().numpy()
    if show:
        figure, axis = plt.subplots(1, 1)
        img = axis.imshow(
            spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto"
        )
        if xlim is not None:
            axis.set_xlim(xlim)
        figure.suptitle(title)
        plt.colorbar(img, ax=axis)
        plt.show()
    return spectrogram
def plot_mask(mask, title="Mask", xlim=None):
    """Display a time-frequency mask as an image with a colour bar.

    Args:
        mask: Mask tensor to display; ``imshow`` expects it to be 2-D
            (presumably ``(freq, time)`` given ``origin="lower"`` — confirm with callers).
        title: Figure title.
        xlim: Optional ``(left, right)`` x-axis limits; ignored when None.
    """
    # detach()/cpu() so masks that require grad or live on a GPU can be converted.
    mask = mask.detach().cpu().numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
    if xlim is not None:
        axis.set_xlim(xlim)
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()
def si_snr(estimate, reference, epsilon=1e-8):
    """Return the scale-invariant signal-to-noise ratio (SI-SNR) in dB.

    Both signals are made zero-mean along their last (time) axis. The reference
    is then rescaled by the least-squares gain that best matches the estimate,
    and whatever remains of the estimate is treated as error.

    Args:
        estimate: Estimated signal, shape ``(T,)`` or ``(1, T)``.
        reference: Reference signal with the same shape as ``estimate``.
        epsilon: Small constant that guards both the projection denominator and
            the final power ratio against division by zero and ``log10(0)``.

    Returns:
        The SI-SNR as a Python float. An estimate that is perfect up to a gain
        yields a large finite value rather than ``inf``. Two silent signals
        yield ``0.0`` rather than ``nan``.

    Note:
        ``Tensor.item()`` raises if the inputs contain more than one signal row.
    """
    # Center each signal along time (per row, not over the whole tensor).
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    reference = reference - reference.mean(dim=-1, keepdim=True)

    # Projection of the estimate onto the reference: scale = <s_hat, s> / <s, s>.
    reference_energy = reference.pow(2).mean(dim=-1, keepdim=True)
    cross_energy = (estimate * reference).mean(dim=-1, keepdim=True)
    target = (cross_energy / (reference_energy + epsilon)) * reference
    error = estimate - target

    target_pow = target.pow(2).mean(dim=-1)
    error_pow = error.pow(2).mean(dim=-1)
    # epsilon in both terms keeps the ratio finite for perfect or silent inputs.
    score = 10 * torch.log10((target_pow + epsilon) / (error_pow + epsilon))
    return score.item()
def generate_mixture(waveform_clean, waveform_noise, target_snr):
    """Mix a clean signal with noise rescaled to reach a target SNR.

    The shorter waveform is zero-padded at the end of its last (time) axis so
    both have the same length. Signal powers are measured after padding.

    Args:
        waveform_clean: Clean signal tensor with time on the last axis,
            e.g. ``(T,)`` or ``(C, T)``.
        waveform_noise: Noise tensor whose leading shape broadcasts with
            ``waveform_clean``.
        target_snr: Desired signal-to-noise ratio of the mixture in dB.

    Returns:
        ``waveform_clean + gain * waveform_noise`` (after padding). The input
        tensors are not modified.

    Raises:
        ValueError: if the noise has zero power, since no finite gain can reach
            the requested SNR.
    """
    clean_len = waveform_clean.size(-1)
    noise_len = waveform_noise.size(-1)
    if clean_len > noise_len:
        waveform_noise = F.pad(waveform_noise, (0, clean_len - noise_len))
    elif noise_len > clean_len:
        waveform_clean = F.pad(waveform_clean, (0, noise_len - clean_len))

    power_clean_signal = waveform_clean.pow(2).mean()
    power_noise_signal = waveform_noise.pow(2).mean()
    if power_noise_signal == 0:
        raise ValueError("waveform_noise has zero power; cannot scale it to a target SNR")
    current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
    gain = 10 ** ((current_snr - target_snr) / 20)
    # Out-of-place multiply: an in-place `*=` would mutate the caller's noise
    # tensor whenever no padding copy was made.
    return waveform_clean + waveform_noise * gain
def evaluate(estimate, reference, sample_rate=None):
    """Compute, print and return SDR, SI-SNR, PESQ and STOI of an estimate.

    Args:
        estimate: Estimated waveform tensor. The whole tensor is passed to
            ``si_snr`` and ``mir_eval``'s ``bss_eval_sources``; row 0 is scored
            by PESQ and STOI.
        reference: Reference (clean) waveform tensor laid out like ``estimate``.
        sample_rate: Sampling rate in Hz for PESQ and STOI. Defaults to the
            module-level ``SAMPLE_RATE``, which must then be defined.

    Returns:
        dict with keys ``"sdr"`` (SDR of source 0), ``"si_snr"``, ``"pesq"``
        and ``"stoi"``.
    """
    if sample_rate is None:
        sample_rate = SAMPLE_RATE
    # detach()/cpu() so GPU tensors or tensors that require grad can be converted.
    estimate_np = estimate.detach().cpu().numpy()
    reference_np = reference.detach().cpu().numpy()

    si_snr_score = si_snr(estimate, reference)
    sdr, _, _, _ = mir_eval.separation.bss_eval_sources(
        reference_np, estimate_np, compute_permutation=False
    )
    # python-pesq's signature is pesq(fs, ref, deg, mode): the clean reference
    # must come before the degraded estimate.
    pesq_score = pesq(sample_rate, reference_np[0], estimate_np[0], "wb")
    # pystoi's signature is stoi(clean, processed, fs_sig, extended).
    stoi_score = stoi(reference_np[0], estimate_np[0], sample_rate, extended=False)

    print(f"SDR score: {sdr[0]}")
    print(f"Si-SNR score: {si_snr_score}")
    print(f"PESQ score: {pesq_score}")
    print(f"STOI score: {stoi_score}")
    return {"sdr": sdr[0], "si_snr": si_snr_score, "pesq": pesq_score, "stoi": stoi_score}
def get_irms(stft_clean, stft_noise, reference_channel=None):
    """Compute ideal ratio masks (IRMs) for speech and noise on one channel.

    Each mask is its source's share of the total power in every bin:
    ``|S|^2 / (|S|^2 + |N|^2)`` for speech and ``|N|^2 / (|S|^2 + |N|^2)`` for noise.

    Args:
        stft_clean: STFT tensor of the clean signal; the leading axis is indexed
            by channel.
        stft_noise: STFT tensor of the noise, same shape as ``stft_clean``.
        reference_channel: Index into the leading axis to use. Defaults to the
            module-level ``REFERENCE_CHANNEL``.

    Returns:
        ``(irm_speech, irm_noise)`` for the selected channel. Bins where both
        sources have zero power get 0 in both masks instead of NaN.
    """
    if reference_channel is None:
        reference_channel = REFERENCE_CHANNEL
    # Only one channel is returned, so select it before doing any arithmetic.
    power_clean = stft_clean[reference_channel].abs() ** 2
    power_noise = stft_noise[reference_channel].abs() ** 2
    total = power_clean + power_noise
    # Avoid 0/0 = NaN in silent bins; non-zero bins are divided exactly as before.
    total = torch.where(total > 0, total, torch.ones_like(total))
    return power_clean / total, power_noise / total