"""Streamlit demo for packet loss concealment with an exported ONNX model.

The script loads an audio file, zeroes out packets according to a generated
loss mask, runs the lossy signal frame by frame through the ONNX model and
shows/plays the original, lossy and enhanced signals.
"""
import streamlit as st
import librosa
import soundfile as sf
import librosa.display
from config import CONFIG
import torch
from dataset import MaskGenerator
import onnxruntime
import onnx
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# STFT / packet parameters shared by the preprocessing code and `inference`.
# They are defined before the functions so that `inference` never depends on
# names that only appear further down the script.
packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size
stride = CONFIG.DATA.stride
hann = torch.sqrt(torch.hann_window(window))


def _resource_cache(func):
    """Cache ``func`` so the model is loaded once instead of on every rerun.

    Uses ``st.cache_resource`` when the installed Streamlit provides it.
    On older releases it falls back to ``st.cache`` with
    ``allow_output_mutation=True`` so Streamlit does not try to hash the
    returned session/model objects on each cache hit.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_resource_cache
def load_model():
    """Load the ONNX model and create an inference session for it.

    Returns:
        A tuple ``(session, onnx_model, input_names, output_names)`` where
        ``session`` is the ``onnxruntime.InferenceSession``, ``onnx_model``
        is the model loaded with ``onnx.load`` and the two name lists are
        taken from the session's inputs and outputs.
    """
    path = 'lightning_logs/version_0/checkpoints/frn.onnx'
    onnx_model = onnx.load(path)

    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)

    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names


def inference(re_im, session, onnx_model, input_names, output_names):
    """Run the model over ``re_im`` one slice at a time and return audio.

    Every model input starts as a float32 zero array whose shape is read
    from the ONNX graph. For each slice along the first axis of ``re_im``
    the slice is fed as the first input; the second to fourth outputs of
    the call are fed back as the second to fourth inputs of the next call.
    The first outputs are concatenated, converted to a complex spectrogram
    and inverted with ``torch.istft`` using the module-level ``window``,
    ``stride`` and ``hann``.

    Args:
        re_im: Array of model input slices (presumably the real/imaginary
            STFT frames built by the script below; confirm against caller).
        session: ``onnxruntime.InferenceSession`` used for the calls.
        onnx_model: Loaded ONNX model; only ``graph.input`` is read.
        input_names: Names of the session inputs, in order.
        output_names: Names of the session outputs, in order.

    Returns:
        The reconstructed signal as a NumPy array.
    """
    inputs = {
        name: np.zeros(
            [d.dim_value for d in graph_input.type.tensor_type.shape.dim],
            dtype=np.float32,
        )
        for name, graph_input in zip(input_names, onnx_model.graph.input)
    }

    frames = []
    for frame in re_im:
        inputs[input_names[0]] = frame
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
        # Carry the recurrent state over to the next call.
        inputs[input_names[1]] = prev_mag
        inputs[input_names[2]] = predictor_state
        inputs[input_names[3]] = mlp_state
        frames.append(out)

    spectrum = torch.tensor(np.concatenate(frames, 0))
    spectrum = spectrum.permute(1, 0, 2).contiguous()
    spectrum = torch.view_as_complex(spectrum)
    audio = torch.istft(spectrum, n_fft=window, hop_length=stride, window=hann)
    return audio.numpy()


def _magnitude_spectrogram(signal, analysis_window):
    """Return the window-normalised STFT magnitude of ``signal``."""
    spectrum = librosa.stft(
        signal,
        n_fft=len(analysis_window),
        hop_length=512,
        window=analysis_window,
    )
    return 2 * np.abs(spectrum) / np.sum(analysis_window)


def visualize(hr, lr, recon):
    """Plot dB spectrograms of the original, lossy and enhanced signals.

    Args:
        hr: Original signal.
        lr: Signal with lost packets.
        recon: Signal produced by the model.

    Returns:
        The Matplotlib figure with three stacked spectrograms. The figure is
        created through ``pyplot``, so the caller should close it once it
        has been rendered.
    """
    sr = CONFIG.DATA.sr
    window_size = 1024
    analysis_window = np.hanning(window_size)

    panels = (
        ('Оригинальный сигнал', hr),
        ('Сигнал с потерями', lr),
        ('Улучшенный сигнал', recon),
    )

    fig, axes = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 10))
    # Kept from the original script: constructing the canvas binds an Agg
    # canvas to the figure.
    FigureCanvas(fig)

    for ax, (panel_title, signal) in zip(axes, panels):
        ax.title.set_text(panel_title)
        magnitude = _magnitude_spectrogram(signal, analysis_window)
        librosa.display.specshow(
            librosa.amplitude_to_db(magnitude),
            ax=ax,
            y_axis='linear',
            x_axis='time',
            sr=sr,
        )
    return fig


title = 'Сокрытие потерь пакетов'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

st.subheader('Загрузить аудио')
uploaded_file = st.file_uploader("Upload your audio file (.wav) at 48 kHz sampling rate")
# Fall back to the bundled sample when nothing has been uploaded.
audio_source = uploaded_file if uploaded_file is not None else 'sample.wav'

target, sr = librosa.load(audio_source, sr=48000)
# Keep a whole number of packets so the signal can be reshaped below.
target = target[:packet_size * (len(target) // packet_size)]
if len(target) == 0:
    # Audio shorter than one packet is truncated to nothing, and the STFT
    # below cannot process an empty signal.
    st.error('The audio is shorter than one packet; please upload a longer recording.')
    st.stop()

st.text('Audio sample')
st.audio(audio_source)

st.subheader('Выберите желаемый процент потерь')
loss_rate = st.slider("Expected loss rate for Markov Chain loss generator", 0, 100, step=1)
loss_percent = float(loss_rate) / 100
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])

# Apply the loss mask packet by packet on a copy of the original signal.
lossy_input = target.copy().reshape(-1, packet_size)
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

lossy_input_tensor = torch.tensor(lossy_input)
# return_complex=True + view_as_real yields the same trailing (real, imag)
# axis as the deprecated return_complex=False form.
lossy_spectrum = torch.stft(
    lossy_input_tensor,
    n_fft=window,
    hop_length=stride,
    window=hann,
    return_complex=True,
)
re_im = (
    torch.view_as_real(lossy_spectrum)
    .permute(1, 0, 2)
    .unsqueeze(1)
    .numpy()
    .astype(np.float32)
)

session, onnx_model, input_names, output_names = load_model()

if st.button('Улучшить'):
    with st.spinner('Ожидайте...'):
        output = inference(re_im, session, onnx_model, input_names, output_names)

    st.subheader('Визуализация')
    fig = visualize(target, lossy_input, output)
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this each
    # rerun of the script would leave another figure in memory.
    plt.close(fig)
    st.success('Done!')

    sf.write('target.wav', target, sr)
    sf.write('lossy.wav', lossy_input, sr)
    sf.write('enhanced.wav', output, sr)

    st.text('Original audio')
    st.audio('target.wav')
    st.text('Lossy audio')
    st.audio('lossy.wav')
    st.text('Enhanced audio')
    st.audio('enhanced.wav')