"""Streamlit demo for packet loss concealment with an exported ONNX model.

The page takes an uploaded WAV file (or falls back to ``sample.wav``),
zeroes out packets according to a Markov-chain loss generator, runs the ONNX
model frame by frame over the STFT of the lossy signal, and then shows
spectrograms, playable audio and PESQ/STOI scores for the clean, lossy and
enhanced signals.
"""
import numpy as np
import streamlit as st
import librosa
import soundfile as sf
import librosa.display
from config import CONFIG
import torch
from dataset import MaskGenerator
import onnxruntime
import onnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from pystoi import stoi
from pesq import pesq
import pandas as pd
import torchaudio
from torchmetrics.audio import ShortTimeObjectiveIntelligibility as STOI
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality as PESQ

MODEL_PATH = 'lightning_logs/version_0/checkpoints/frn.onnx'
# Rate the uploader asks for; audio is loaded (and resampled if needed) at it.
SAMPLE_RATE = 48000
# Rate at which PESQ is computed.
PESQ_SAMPLE_RATE = 16000

# ``st.cache`` is deprecated and, unless ``allow_output_mutation=True``, hashes
# its return value on every call, which does not suit an InferenceSession.
# Prefer ``st.cache_resource`` when this Streamlit version provides it.
if hasattr(st, 'cache_resource'):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the ONNX model and create an onnxruntime inference session.

    Returns:
        tuple: ``(session, onnx_model, input_names, output_names)``. The
        ``onnx_model`` is the parsed graph, used to read the input shapes.
        The name lists follow the session's input/output order.
    """
    onnx_model = onnx.load(MODEL_PATH)
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    session = onnxruntime.InferenceSession(MODEL_PATH, options)
    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names


def inference(re_im, session, onnx_model, input_names, output_names):
    """Run the model over an STFT frame by frame and return a waveform.

    The session is called once per frame. The three extra outputs it returns
    (``prev_mag``, ``predictor_state``, ``mlp_state``) are fed back as inputs
    2-4 for the next frame.

    Args:
        re_im: float32 array whose first axis indexes STFT frames.
            ``re_im[t]`` is passed as the model's first input. As built
            by this script it has shape ``(frames, 1, freq, 2)``.
        session, onnx_model, input_names, output_names: the values
            returned by :func:`load_model`.

    Returns:
        np.ndarray: the reconstructed time-domain signal.

    Note:
        Reads the module-level ``window``, ``stride`` and ``hann`` STFT
        settings. They must match the ones used to compute ``re_im``.
    """
    # Every graph input (audio frame and recurrent states) starts as zeros of
    # the static shape declared in the ONNX graph.
    inputs = {
        input_names[i]: np.zeros(
            [d.dim_value for d in graph_input.type.tensor_type.shape.dim],
            dtype=np.float32,
        )
        for i, graph_input in enumerate(onnx_model.graph.input)
    }
    frames = []
    for frame in re_im:
        inputs[input_names[0]] = frame
        out, prev_mag, predictor_state, mlp_state = session.run(
            output_names, inputs
        )
        inputs[input_names[1]] = prev_mag
        inputs[input_names[2]] = predictor_state
        inputs[input_names[3]] = mlp_state
        frames.append(out)

    spectrum = torch.tensor(np.concatenate(frames, 0))
    # Back to (freq, frames, 2) before viewing the last axis as complex.
    # Assumes each ``out`` has a leading frame axis of size 1 -- TODO confirm.
    spectrum = spectrum.permute(1, 0, 2).contiguous()
    spectrum = torch.view_as_complex(spectrum)
    audio = torch.istft(spectrum, window, stride, window=hann)
    return audio.numpy()


def visualize(hr, lr, recon, sr):
    """Plot dB spectrograms (log-frequency axis) of the three signals.

    Args:
        hr: clean reference signal.
        lr: signal with simulated packet loss.
        recon: concealed (enhanced) signal.
        sr: sample rate used to label the time/frequency axes.

    Returns:
        matplotlib.figure.Figure: three stacked subplots sharing both axes.
    """
    n_fft = 1024
    hop_length = 512
    analysis_window = np.hanning(n_fft)
    fig, axes = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12))
    FigureCanvas(fig)  # attach an Agg canvas to the figure
    titles = ('Оригинальный сигнал', 'Сигнал с потерями', 'Улучшенный сигнал')
    for ax, signal, title in zip(axes, (hr, lr, recon), titles):
        ax.title.set_text(title)
        spectrum = librosa.stft(
            signal, n_fft=n_fft, hop_length=hop_length, window=analysis_window
        )
        # Normalise by the window sum (x2 for the single-sided spectrum).
        magnitude = 2 * np.abs(spectrum) / np.sum(analysis_window)
        librosa.display.specshow(
            librosa.amplitude_to_db(magnitude),
            ax=ax,
            y_axis='log',
            x_axis='time',
            sr=sr,
            hop_length=hop_length,
        )
    return fig


def _metrics_table(clean, lossy, enhanced, samplerate):
    """Return a DataFrame of STOI and PESQ scores, each measured against ``clean``.

    All signals are trimmed to a common length. STOI is computed at
    ``samplerate``. PESQ (narrow-band) is computed after resampling to
    ``PESQ_SAMPLE_RATE`` when needed.
    """
    min_len = min(len(clean), len(lossy), len(enhanced))
    signals = [
        np.asarray(s[:min_len], dtype=np.float64) for s in (clean, lossy, enhanced)
    ]
    reference = signals[0]
    stoi_scores = [
        round(stoi(reference, s, samplerate, extended=False), 5) for s in signals
    ]

    if samplerate != PESQ_SAMPLE_RATE:
        # Resample from the actual file rate, not an assumed 48 kHz.
        signals = [
            librosa.resample(s, orig_sr=samplerate, target_sr=PESQ_SAMPLE_RATE)
            for s in signals
        ]
        reference = signals[0]
    pesq_scores = [
        pesq(fs=PESQ_SAMPLE_RATE, ref=reference, deg=s, mode='nb') for s in signals
    ]

    df = pd.DataFrame(columns=['Audio', 'PESQ', 'STOI', 'PLCMOS', 'LSD'])
    df['Audio'] = ['Clean', 'Lossy', 'Enhanced']
    df['PESQ'] = pesq_scores
    df['STOI'] = stoi_scores
    return df


packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size
stride = CONFIG.DATA.stride

title = 'Сокрытие потерь пакетов'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

st.subheader('1. Загрузка аудио')
uploaded_file = st.file_uploader("Загрузите аудио формата (.wav) 48 КГц")
is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    uploaded_file = 'sample.wav'

# Pass the rate explicitly: librosa.load would otherwise resample to 22050 Hz.
target, sr = librosa.load(uploaded_file, sr=SAMPLE_RATE)
# Trim to a whole number of packets so the signal can be reshaped per packet.
target = target[:packet_size * (len(target) // packet_size)]
st.text('Ваше аудио')
st.audio(uploaded_file)

st.subheader('2. Выберите желаемый процент потерь')
loss_percent = float(
    st.slider(
        "Ожидаемый процент потерь для генератора потерь цепи Маркова",
        0,
        100,
        step=1,
    )
) / 100
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])

lossy_input = target.copy().reshape(-1, packet_size)
# One mask value per packet, broadcast over the packet's samples; a fixed seed
# keeps the loss pattern reproducible. Presumably 0 marks a lost packet.
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
# return_complex=False is deprecated; view_as_real on the complex STFT gives
# the same (freq, frames, 2) real tensor.
re_im = (
    torch.view_as_real(
        torch.stft(
            lossy_input_tensor, window, stride, window=hann, return_complex=True
        )
    )
    .permute(1, 0, 2)
    .unsqueeze(1)
    .numpy()
    .astype(np.float32)
)

session, onnx_model, input_names, output_names = load_model()

if st.button('Сгенерировать потери'):
    with st.spinner('Ожидайте...'):
        output = inference(re_im, session, onnx_model, input_names, output_names)

    st.subheader('3. Визуализация')
    fig = visualize(target, lossy_input, output, sr)
    st.pyplot(fig)
    st.success('Сделано!')

    sf.write('target.wav', target, sr)
    sf.write('lossy.wav', lossy_input, sr)
    sf.write('enhanced.wav', output, sr)

    st.text('Оригинальное аудио')
    st.audio('target.wav')
    st.text('Аудио с потерями')
    st.audio('lossy.wav')
    st.text('Улучшенное аудио')
    st.audio('enhanced.wav')

    # Score the files just written, i.e. exactly what the user listens to.
    # This runs only after they exist (inside the button branch).
    data_clean, samplerate = sf.read('target.wav')
    data_lossy, _ = sf.read('lossy.wav')
    data_enhanced, _ = sf.read('enhanced.wav')
    st.table(_metrics_table(data_clean, data_lossy, data_enhanced, samplerate))