# Captured from a Hugging Face Spaces page whose status read: "Runtime error".
import math

import torch
import torchaudio
import torchaudio.functional as F
from torchaudio.utils import download_asset
from pesq import pesq
from pystoi import stoi
import mir_eval
from pydub import AudioSegment
import matplotlib.pyplot as plt
import streamlit as st

from helper import plot_spectrogram, plot_mask, si_snr, generate_mixture, evaluate, get_irms
# Desired signal-to-noise ratio handed to generate_mixture
# (presumably in dB -- confirm against helper.generate_mixture).
target_snr = 3

# Analysis window size and hop, shared by the forward and inverse STFT so
# that resynthesis matches the analysis framing.
N_FFT = 1024
N_HOP = 256

# power=None makes Spectrogram return the complex STFT (phase preserved),
# which the beamformer and the inverse transform both require.
stft = torchaudio.transforms.Spectrogram(n_fft=N_FFT, hop_length=N_HOP, power=None)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

# Mask-weighted PSD estimator and the Souden MVDR beamformer that consumes
# the speech/noise PSD matrices it produces.
psd_transform = torchaudio.transforms.PSD()
mvdr_transform = torchaudio.transforms.SoudenMVDR()

# Microphone channel the beamformer output is referenced to.
REFERENCE_CHANNEL = 0
# Fixed background-noise recording from torchaudio's MVDR tutorial assets
# (fetched via download_asset); it is mixed into every uploaded clip.
# sr2 is the noise's native sample rate.
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
waveform_noise = waveform_noise.float()
stft_noise = stft(waveform_noise)
def _fit_to_length(waveform, num_samples):
    """Return a copy of ``waveform`` tiled or trimmed to ``num_samples`` along time.

    Args:
        waveform: ``(channels, samples)`` tensor with at least one sample.
        num_samples: Desired number of samples in the result.

    Returns:
        A new ``(channels, num_samples)`` tensor. It never aliases ``waveform``,
        so in-place edits on the result cannot leak back into the caller's data.
    """
    repeats = math.ceil(num_samples / waveform.shape[-1])
    return waveform.repeat(1, repeats)[:, :num_samples].clone()


def ui():
    """Render the Streamlit speech-enhancement demo.

    The user uploads a WAV clip. It is mixed with the bundled noise recording
    at ``target_snr``, and the noisy mixture is written to
    ``./waveform_mix.wav`` and played. The mixture is then enhanced with a
    Souden MVDR beamformer. The beamformer is driven by ideal-ratio masks
    computed from the known clean and noise components. The result is written
    to ``./waveform_souden.wav`` and played.
    """
    st.title("Speech Enhancer")
    st.markdown("Made by Vageesh")
    # Upload widget restricted to WAV input.
    audio_file = st.file_uploader("Upload an audio file in wav format", type=["wav"])
    if audio_file is None:
        return

    waveform_clean, sr = torchaudio.load(audio_file)
    waveform_clean = waveform_clean.to(torch.float32)
    st.text("Your uploaded audio")
    st.audio(audio_file)

    # Mixing only makes sense at a common sample rate. Bring the upload to
    # the noise's rate; all outputs are then written at that rate.
    if sr != sr2:
        waveform_clean = F.resample(waveform_clean, sr, sr2)
    sample_rate = sr2

    # The mixture combines clean and noise channel-wise. A mono upload
    # broadcasts across the noise's channels, but any other mismatched
    # channel count cannot, so downmix such uploads to mono.
    if waveform_clean.shape[0] not in (1, waveform_noise.shape[0]):
        waveform_clean = waveform_clean.mean(dim=0, keepdim=True)

    # The centred STFT reflect-pads by N_FFT // 2 samples, which requires the
    # signal to be longer than that padding.
    num_samples = waveform_clean.shape[-1]
    if num_samples <= N_FFT // 2:
        st.error("The uploaded audio is too short to process.")
        return

    stft_clean = stft(waveform_clean)

    # Tile/trim the noise to the upload's length so the two can be added.
    # The copy also keeps the module-level noise untouched in case
    # generate_mixture rescales its noise argument in place.
    noise = _fit_to_length(waveform_noise, num_samples)
    waveform_mix = generate_mixture(waveform_clean, noise, target_snr).to(torch.float32)
    stft_mix = stft(waveform_mix)

    # Recover the noise exactly as it appears in the mixture (i.e. after SNR
    # scaling), so the noise mask describes what was actually added.
    # NOTE(review): assumes generate_mixture returns clean + scaled noise --
    # confirm in helper.
    stft_noise_mix = stft(waveform_mix - waveform_clean)

    # Let the user hear the noisy mixture.
    torchaudio.save("./waveform_mix.wav", waveform_mix, sample_rate)
    st.audio("./waveform_mix.wav")

    # Masks -> mask-weighted PSD matrices -> Souden MVDR beamforming.
    irm_speech, irm_noise = get_irms(stft_clean, stft_noise_mix)
    psd_speech = psd_transform(stft_mix, irm_speech)
    psd_noise = psd_transform(stft_mix, irm_noise)
    stft_souden = mvdr_transform(
        stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL
    )

    # Back to the time domain at the mixture's length. The reshape adds the
    # leading channel axis that torchaudio.save expects.
    waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1]).reshape(1, -1)
    torchaudio.save("./waveform_souden.wav", waveform_souden, sample_rate)
    st.audio("./waveform_souden.wav")
# Build the UI when this file runs as the main script (e.g. via `streamlit run`).
if __name__ == "__main__":
    ui()