File size: 4,735 Bytes
bdb2571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import onnx
import onnxruntime
import streamlit as st
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

from config import CONFIG
from dataset import MaskGenerator
@st.cache_resource
def load_model(path='lightning_logs/version_0/checkpoints/frn.onnx'):
    """Load the exported FRN ONNX model and build an inference session for it.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the loaded
    session instead of re-reading the model file.

    Args:
        path: Location of the ``.onnx`` file. Defaults to the checkpoint path
            this app has always used.

    Returns:
        Tuple ``(session, onnx_model, input_names, output_names)``. The name
        lists are in the order reported by the session.
    """
    # The parsed graph is kept because inference() reads the static input
    # shapes from it to zero-initialise the model inputs.
    onnx_model = onnx.load(path)
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Since ONNX Runtime 1.9, builds that ship more than one execution provider
    # (e.g. GPU builds) raise unless `providers` is passed explicitly. The thread
    # settings above target the CPU, so request the CPU provider.
    session = onnxruntime.InferenceSession(path, options, providers=['CPUExecutionProvider'])
    input_names = [node.name for node in session.get_inputs()]
    output_names = [node.name for node in session.get_outputs()]
    return session, onnx_model, input_names, output_names
def inference(re_im, session, onnx_model, input_names, output_names,
              n_fft=None, hop_length=None, istft_window=None):
    """Run the stateful ONNX model frame by frame and resynthesise the audio.

    Each call to the model takes the current frame (``input_names[0]``) plus
    three state tensors (``input_names[1:4]``). It returns the processed frame
    together with updated states, which are fed back in for the next frame.

    Args:
        re_im: float32 array of real/imag STFT data indexed by frame along
            axis 0. The script builds it with shape ``(frames, 1, freq, 2)``.
        session, onnx_model, input_names, output_names: as returned by
            ``load_model``.
        n_fft: iSTFT size. Defaults to the module-level ``window``.
        hop_length: iSTFT hop. Defaults to the module-level ``stride``.
        istft_window: iSTFT window tensor. Defaults to the module-level ``hann``.

    Returns:
        1-D numpy array containing the reconstructed waveform.

    Raises:
        ValueError: if ``re_im`` contains no frames.
    """
    # Fall back to the globals the script uses to build re_im, preserving the
    # original call signature and behaviour.
    n_fft = window if n_fft is None else n_fft
    hop_length = stride if hop_length is None else hop_length
    istft_window = hann if istft_window is None else istft_window

    if re_im.shape[0] == 0:
        raise ValueError('re_im contains no STFT frames to process')

    # Zero-initialise every session input from the static shapes recorded in
    # the graph. Key by each graph input's own name rather than by position:
    # graphs exported with initializers listed as inputs have extra entries
    # in graph.input that the session does not expose.
    session_inputs = set(input_names)
    inputs = {
        graph_input.name: np.zeros(
            [d.dim_value for d in graph_input.type.tensor_type.shape.dim],
            dtype=np.float32)
        for graph_input in onnx_model.graph.input
        if graph_input.name in session_inputs
    }

    output_frames = []
    for frame in re_im:
        inputs[input_names[0]] = frame
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
        # Thread the updated recurrent state into the next step.
        inputs[input_names[1]] = prev_mag
        inputs[input_names[2]] = predictor_state
        inputs[input_names[3]] = mlp_state
        output_frames.append(out)

    spec = torch.tensor(np.concatenate(output_frames, 0))
    # Swap the first two axes. Presumably (frames, freq, 2) -> (freq, frames, 2),
    # assuming each model output is (1, freq, 2) -- TODO confirm against the export.
    # view_as_complex needs a contiguous tensor whose last dim is (real, imag).
    spec = torch.view_as_complex(spec.permute(1, 0, 2).contiguous())
    audio = torch.istft(spec, n_fft, hop_length, window=istft_window)
    return audio.numpy()
def visualize(hr, lr, recon, window_size=1024, hop_length=512):
    """Plot dB magnitude spectrograms of the target, lossy and enhanced signals.

    Args:
        hr: Target (loss-free) waveform.
        lr: Lossy input waveform.
        recon: Waveform produced by the model.
        window_size: STFT length; a Hann window of this size is used.
        hop_length: STFT hop. Also passed to ``specshow`` so the time axis
            matches the analysis.

    Returns:
        A ``matplotlib.figure.Figure`` with three vertically stacked axes that
        share both x and y.
    """
    sr = CONFIG.DATA.sr
    window = np.hanning(window_size)
    # Normalisation: 2 for the one-sided spectrum, sum(window) for the window's
    # coherent gain, so a sinusoid's peak bin approximates its amplitude.
    window_gain = np.sum(window)

    # Build the figure without pyplot. Figures created with plt.subplots are
    # registered in pyplot's global state and were never closed. st.pyplot does
    # not close a figure passed to it, so every rerun leaked one. A bare Figure
    # is freed once nothing references it.
    fig = Figure(figsize=(16, 10))
    FigureCanvas(fig)
    axes = fig.subplots(3, 1, sharey=True, sharex=True)
    panels = (
        ('Target signal', hr),
        ('Lossy signal', lr),
        ('Enhanced signal', recon),
    )
    for ax, (label, signal) in zip(axes, panels):
        ax.title.set_text(label)
        spec = librosa.stft(signal, n_fft=window_size, hop_length=hop_length, window=window)
        magnitude = 2 * np.abs(spec) / window_gain
        librosa.display.specshow(librosa.amplitude_to_db(magnitude), ax=ax, y_axis='linear',
                                 x_axis='time', sr=sr, hop_length=hop_length)
    return fig
# --- Streamlit page: load audio, simulate packet loss, run concealment. ---
packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size  # STFT length (n_fft); also used by inference()
stride = CONFIG.DATA.stride  # STFT hop length; also used by inference()
title = 'Packet Loss Concealment'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

uploaded_file = st.file_uploader("Upload your audio file (.wav)")
is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    uploaded_file = 'sample.wav'
# Load at the configured rate. visualize() labels its spectrogram axes with
# CONFIG.DATA.sr, so a hard-coded rate here could disagree with the plots.
target, sr = librosa.load(uploaded_file, sr=CONFIG.DATA.sr)
# Drop the trailing partial packet so the signal splits into whole packets.
target = target[:packet_size * (len(target) // packet_size)]
if len(target) == 0:
    # Shorter than one packet: there is nothing to mask or transform.
    st.error(f'The audio must contain at least {packet_size} samples.')
    st.stop()
st.subheader('Original audio')
st.audio(uploaded_file)

st.subheader('Choose loss packet percentage')
loss_percent = st.radio('Loss percentage', ['10%', '20%', '30%', '40%'])
loss_percent = float(loss_percent[:-1]) / 100  # e.g. '20%' -> 0.2
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
lossy_input = target.copy().reshape(-1, packet_size)  # one row per packet
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
# Presumably the mask is 0 for lost packets and 1 otherwise -- depends on
# MaskGenerator; verify there.
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

hann = torch.sqrt(torch.hann_window(window))  # also used by inference() for the iSTFT
lossy_input_tensor = torch.tensor(lossy_input)
# return_complex=False is deprecated in torch.stft. view_as_real on the complex
# result yields the same (freq, frames, 2) real/imag layout.
re_im = torch.view_as_real(
    torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=True)
).permute(1, 0, 2).unsqueeze(1).numpy().astype(np.float32)

session, onnx_model, input_names, output_names = load_model()
if st.button('Conceal lossy audio!'):
    with st.spinner('Please wait for completion'):
        output = inference(re_im, session, onnx_model, input_names, output_names)
        st.subheader('Visualization')
        fig = visualize(target, lossy_input, output)
        st.pyplot(fig)
        # st.pyplot does not close a figure passed to it. Release it from
        # pyplot's registry (a no-op if it was never registered) so reruns
        # do not accumulate open figures.
        plt.close(fig)
    st.success('Done!')
    st.text('Original audio')
    st.audio(target, sample_rate=sr)
    st.text('Lossy audio')
    st.audio(lossy_input, sample_rate=sr)
    st.text('Enhanced audio')
    st.audio(output, sample_rate=sr)