|
import numpy as np |
|
|
|
import streamlit as st |
|
import librosa |
|
import soundfile as sf |
|
import librosa.display |
|
from config import CONFIG |
|
import torch |
|
from dataset import MaskGenerator |
|
import onnxruntime, onnx |
|
import matplotlib.pyplot as plt |
|
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas |
|
from pystoi import stoi |
|
from pesq import pesq |
|
import pandas as pd |
|
import torchaudio |
|
|
|
|
|
from torchmetrics.audio import ShortTimeObjectiveIntelligibility as STOI |
|
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality as PESQ |
|
|
|
|
|
from PLCMOS.plc_mos import PLCMOSEstimator |
|
from speechmos import dnsmos |
|
from speechmos import plcmos |
|
|
|
import speech_recognition as speech_r |
|
from jiwer import wer |
|
import time |
|
|
|
@st.cache |
|
def load_model(model): |
|
path = 'lightning_logs/version_0/checkpoints/' + str(model) |
|
onnx_model = onnx.load(path) |
|
options = onnxruntime.SessionOptions() |
|
options.intra_op_num_threads = 2 |
|
options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL |
|
session = onnxruntime.InferenceSession(path, options) |
|
input_names = [x.name for x in session.get_inputs()] |
|
output_names = [x.name for x in session.get_outputs()] |
|
return session, onnx_model, input_names, output_names |
|
|
|
def inference(re_im, session, onnx_model, input_names, output_names): |
|
inputs = {input_names[i]: np.zeros([d.dim_value for d in _input.type.tensor_type.shape.dim], |
|
dtype=np.float32) |
|
for i, _input in enumerate(onnx_model.graph.input) |
|
} |
|
|
|
output_audio = [] |
|
for t in range(re_im.shape[0]): |
|
inputs[input_names[0]] = re_im[t] |
|
out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs) |
|
inputs[input_names[1]] = prev_mag |
|
inputs[input_names[2]] = predictor_state |
|
inputs[input_names[3]] = mlp_state |
|
output_audio.append(out) |
|
|
|
output_audio = torch.tensor(np.concatenate(output_audio, 0)) |
|
output_audio = output_audio.permute(1, 0, 2).contiguous() |
|
output_audio = torch.view_as_complex(output_audio) |
|
output_audio = torch.istft(output_audio, window, stride, window=hann) |
|
return output_audio.numpy() |
|
|
|
def visualize(hr, lr, recon, sr): |
|
sr = sr |
|
window_size = 1024 |
|
window = np.hanning(window_size) |
|
|
|
stft_hr = librosa.core.spectrum.stft(hr, n_fft=window_size, hop_length=512, window=window) |
|
stft_hr = 2 * np.abs(stft_hr) / np.sum(window) |
|
|
|
stft_lr = librosa.core.spectrum.stft(lr, n_fft=window_size, hop_length=512, window=window) |
|
stft_lr = 2 * np.abs(stft_lr) / np.sum(window) |
|
|
|
stft_recon = librosa.core.spectrum.stft(recon, n_fft=window_size, hop_length=512, window=window) |
|
stft_recon = 2 * np.abs(stft_recon) / np.sum(window) |
|
|
|
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12)) |
|
ax1.title.set_text('Оригинальный сигнал') |
|
ax2.title.set_text('Сигнал с потерями') |
|
ax3.title.set_text('Улучшенный сигнал') |
|
|
|
canvas = FigureCanvas(fig) |
|
p = librosa.display.specshow(librosa.amplitude_to_db(stft_hr), ax=ax1, y_axis='log', x_axis='time', sr=sr) |
|
p = librosa.display.specshow(librosa.amplitude_to_db(stft_lr), ax=ax2, y_axis='log', x_axis='time', sr=sr) |
|
p = librosa.display.specshow(librosa.amplitude_to_db(stft_recon), ax=ax3, y_axis='log', x_axis='time', sr=sr) |
|
|
|
ax1.set_xlabel('Время, с') |
|
ax1.set_ylabel('Частота, Гц') |
|
ax2.set_xlabel('Время, с') |
|
ax2.set_ylabel('Частота, Гц') |
|
ax3.set_xlabel('Время, с') |
|
ax3.set_ylabel('Частота, Гц') |
|
return fig |
|
|
|
|
|
|
|
def waveplot(hr, lr, recon, sr): |
|
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12)) |
|
ax1.title.set_text('Оригинальный сигнал') |
|
ax2.title.set_text('Сигнал с потерями') |
|
ax3.title.set_text('Улучшенный сигнал') |
|
|
|
canvas = FigureCanvas(fig) |
|
p = librosa.display.waveshow(hr, ax=ax1, sr=sr) |
|
p = librosa.display.waveshow(lr, ax=ax2, sr=sr) |
|
p = librosa.display.waveshow(recon, ax=ax3, sr=sr) |
|
|
|
ax1.set_xlabel('Время, с') |
|
|
|
ax2.set_xlabel('Время, с') |
|
|
|
ax3.set_xlabel('Время, с') |
|
|
|
return fig |
|
|
|
def sign_x_y(x,y): |
|
if x>y: |
|
return '-' |
|
else: |
|
return '' |
|
|
|
packet_size = CONFIG.DATA.EVAL.packet_size |
|
window = CONFIG.DATA.window_size |
|
stride = CONFIG.DATA.stride |
|
|
|
title = 'Маскировка потерь пакетов' |
|
st.set_page_config(page_title=title, page_icon=":sound:") |
|
st.title(title) |
|
|
|
st.subheader('1. Загрузка аудио') |
|
uploaded_file = st.file_uploader("Загрузите аудио формата (.wav) 48 КГц") |
|
|
|
is_file_uploaded = uploaded_file is not None |
|
if not is_file_uploaded: |
|
uploaded_file = 'sample.wav' |
|
|
|
target, sr = librosa.load(uploaded_file, sr=48000) |
|
target = target[:packet_size * (len(target) // packet_size)] |
|
|
|
st.text('Ваше аудио') |
|
st.audio(uploaded_file) |
|
|
|
model_ver = st.selectbox( |
|
'Веса оригинальной модели выбраны по умолчанию. Выберите модель', |
|
('frn.onnx', 'frn_out_QInt16.onnx', 'frn_out_QInt8.onnx', 'frn_out_QUInt8.onnx', 'frn_out_QUInt16.onnx', 'frn_fp16.onnx')) |
|
|
|
st.write('Вы выбрали:', model_ver) |
|
|
|
lang = st.selectbox( |
|
'Выберите язык вашего аудио для корректной работы распознавания речи', |
|
('ru-RU', 'en-EN')) |
|
|
|
st.write('Вы выбрали:', lang) |
|
|
|
|
|
st.subheader('2. Выберите желаемый процент потерь') |
|
slider = [st.slider("Ожидаемый процент потерь для генератора потерь цепи Маркова", 0, 100, step=1)] |
|
loss_percent = float(slider[0])/100 |
|
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)]) |
|
lossy_input = target.copy().reshape(-1, packet_size) |
|
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis] |
|
lossy_input *= mask |
|
lossy_input = lossy_input.reshape(-1) |
|
hann = torch.sqrt(torch.hann_window(window)) |
|
lossy_input_tensor = torch.tensor(lossy_input) |
|
re_im = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(1).numpy().astype(np.float32) |
|
|
|
session, onnx_model, input_names, output_names = load_model(model_ver) |
|
|
|
with st.sidebar: |
|
st.title('Full-band Reccurent Network', help = 'https://arxiv.org/abs/2211.04071') |
|
authors_c = st.container() |
|
authors_c.write('Авторы модели: Viet-Anh Nguyen and Anh H. T. Nguyen and Andy W. H. Khong') |
|
st.link_button("Github авторов", "https://github.com/Crystalsound/FRN", help = 'Кликни на меня') |
|
description_c = st.container() |
|
description_c.write("Это дополненный space оригинальной FRN модели. К исходной сети были применены методы квантования onnxruntime для уменьшения размера .onnx файла и повышения скорости обработки аудио при некотором ухудшении результата. В этом space вы можете сгенерировать потери пакетов и оценить работу модели визуально, на слух, и по нескольким метрикам.") |
|
st.header("Packet Loss Concealment", help = 'https://arxiv.org/abs/2204.05222') |
|
PLC_c = st.container() |
|
PLC_c.write("PLC (Packet Loss Concealment) - это технологии, созданные для борьбы с потерей пакетов при передаче речи в IP сети. Для ознакомления с данной темой рекомендуется статья INTERSPEECH 2022 Audio Deep Packet Loss Concealment Challenge с результатами одноимённого конкурса.") |
|
st.header("Метрики") |
|
Metrcs_c = st.container() |
|
Metrcs_c.write("Для оценивания речи были выбраны следующие метрики: PESQ, STOI, PLCMOS разных версий и WER. С каждой из них вы можете ознакомиться, перейдя по ссылке рядом с заголовком.") |
|
st.subheader("PESQ", help = 'https://ieeexplore.ieee.org/document/941023') |
|
st.write('Перцептивная оценка качества речи') |
|
st.subheader("STOI", help = 'https://ieeexplore.ieee.org/document/5495701') |
|
st.write('Индекс объективной кратковременной разборчивости') |
|
st.subheader("PLCMOS", help = 'https://arxiv.org/abs/2305.15127') |
|
PLCMOS_c=st.container() |
|
PLCMOS_c.write("Использованы две версии данной метрики (v1, v2). v1 - это первая версия, разработанная для INTERSPEECH 2022 Audio Deep Packet Loss Concealment Challenge. v2 - улучшенная версия метрики, вышедшая в 2023 году. Особенность - неэталонная метрика, которая выдаёт оценку, опираясь только на аудио с потерями без использования информации о исходном (оригинальном). Поставляется как часть пакета speechmos.") |
|
st.subheader("WAcc", help = 'https://docs.speechmatics.com/tutorials/accuracy-benchmarking') |
|
WAcc_c=st.container() |
|
WAcc_c.write('Первоначально использовалась метрика WER (Word Error Rate). Она выражает долю ошибочно распознанных слов. Я считаю, что для восприятия будет проще обратная ей - WAcc (Word Accuracy), то есть доля слов, которые распознаны верно. Для распознавания используется пакет jiwer') |
|
|
|
if st.button('Сгенерировать потери'): |
|
with st.spinner('Ожидайте...'): |
|
start_time = time.time() |
|
output = inference(re_im, session, onnx_model, input_names, output_names) |
|
st.text(str(time.time() - start_time)) |
|
st.subheader('3. Визуализация аудио') |
|
fig_1 = visualize(target, lossy_input, output, sr) |
|
fig_2 = waveplot(target, lossy_input, output, sr) |
|
tab1, tab2 = st.tabs(["Частотная область", "Временная область"]) |
|
|
|
with tab1: |
|
st.header("Частотная область") |
|
st.pyplot(fig_1) |
|
|
|
with tab2: |
|
st.header("Временная область") |
|
st.pyplot(fig_2) |
|
|
|
|
|
sf.write('target.wav', target, sr) |
|
sf.write('lossy.wav', lossy_input, sr) |
|
sf.write('enhanced.wav', output, sr) |
|
st.text('Оригинальное аудио') |
|
st.audio('target.wav') |
|
st.text('Аудио с потерями') |
|
st.audio('lossy.wav') |
|
st.text('Улучшенное аудио') |
|
st.audio('enhanced.wav') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_clean, samplerate = sf.read('target.wav') |
|
data_lossy, samplerate = sf.read('lossy.wav') |
|
data_enhanced, samplerate = sf.read('enhanced.wav') |
|
min_len = min(data_clean.shape[0], data_lossy.shape[0], data_enhanced.shape[0]) |
|
data_clean = data_clean[:min_len] |
|
data_lossy = data_lossy[:min_len] |
|
data_enhanced = data_enhanced[:min_len] |
|
|
|
|
|
stoi_orig = round(stoi(data_clean, data_clean, samplerate, extended=False),5) |
|
stoi_lossy = round(stoi(data_clean, data_lossy , samplerate, extended=False),5) |
|
stoi_enhanced = round(stoi(data_clean, data_enhanced, samplerate, extended=False),5) |
|
|
|
stoi_mass=[stoi_orig, stoi_lossy, stoi_enhanced] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if samplerate != 16000: |
|
data_lossy = librosa.resample(data_lossy, orig_sr=48000, target_sr=16000) |
|
data_clean = librosa.resample(data_clean, orig_sr=48000, target_sr=16000) |
|
data_enhanced = librosa.resample(data_enhanced, orig_sr=48000, target_sr=16000) |
|
|
|
|
|
|
|
pesq_orig = pesq(fs = 16000, ref = data_clean, deg = data_clean, mode='wb') |
|
pesq_lossy = pesq(fs = 16000, ref = data_clean, deg = data_lossy, mode='wb') |
|
pesq_enhanced = pesq(fs = 16000, ref = data_clean, deg = data_enhanced, mode='wb') |
|
|
|
psq_mas=[pesq_orig, pesq_lossy, pesq_enhanced] |
|
|
|
|
|
|
|
data_clean, fs = sf.read('target.wav') |
|
data_lossy, fs = sf.read('lossy.wav') |
|
data_enhanced, fs = sf.read('enhanced.wav') |
|
|
|
if fs!= 16000: |
|
data_lossy = librosa.resample(data_lossy, orig_sr=48000, target_sr=16000) |
|
data_clean = librosa.resample(data_clean, orig_sr=48000, target_sr=16000) |
|
data_enhanced = librosa.resample(data_enhanced, orig_sr=48000, target_sr=16000) |
|
|
|
PLC_example=PLCMOSEstimator() |
|
PLC_org = PLC_example.run(audio_degraded=data_clean, audio_clean=data_clean)[0] |
|
PLC_lossy = PLC_example.run(audio_degraded=data_lossy, audio_clean=data_clean)[0] |
|
PLC_enhanced = PLC_example.run(audio_degraded=data_enhanced, audio_clean=data_clean)[0] |
|
|
|
PLC_massv1 = [PLC_org, PLC_lossy, PLC_enhanced] |
|
|
|
|
|
|
|
df_1 = pd.DataFrame(columns=['Audio', 'PESQ', 'STOI', 'PLCMOSv1']) |
|
|
|
df_1['Аудио'] = ['Оригинал', 'С потерями', 'Улучшенное'] |
|
|
|
df_1['PESQ'] = psq_mas |
|
|
|
df_1['STOI'] = stoi_mass |
|
|
|
|
|
df_1['PLCMOSv1'] = PLC_massv1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
PLC_massv2 = [plcmos.run("target.wav", sr=16000)['plcmos'], plcmos.run("lossy.wav", sr=16000)['plcmos'], plcmos.run("enhanced.wav", sr=16000)['plcmos']] |
|
|
|
|
|
|
|
df_1['PLCMOSv2'] = PLC_massv2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r = speech_r.Recognizer() |
|
|
|
|
|
|
|
|
|
harvard = speech_r.AudioFile('target.wav') |
|
with harvard as source: |
|
audio = r.record(source) |
|
|
|
orig = r.recognize_google(audio, language = str(lang)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
harvard = speech_r.AudioFile('lossy.wav') |
|
|
|
|
|
|
|
|
|
try: |
|
with harvard as source: |
|
audio = r.record(source) |
|
lossy = r.recognize_google(audio, language = str(lang)) |
|
|
|
except speech_r.UnknownValueError: |
|
|
|
lossy = '' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
harvard = speech_r.AudioFile('enhanced.wav') |
|
|
|
|
|
|
|
|
|
try: |
|
with harvard as source: |
|
audio = r.record(source) |
|
enhanced = r.recognize_google(audio, language = str(lang)) |
|
|
|
except speech_r.UnknownValueError: |
|
|
|
enhanced = '' |
|
|
|
|
|
|
|
error1 = wer(orig, orig) |
|
error2 = wer(orig, lossy) |
|
error3 = wer(orig, enhanced) |
|
WAcc_mass=[(1-error1)*100, (1-error2)*100, (1-error3)*100] |
|
|
|
df_1['WAcc'] = WAcc_mass |
|
|
|
st.subheader('4. Метрики аудио') |
|
|
|
st.write("#### "+"Оригинал") |
|
col1, col2, col3, col4, col5 = st.columns(5) |
|
col1.metric("PESQ", value = round(psq_mas[0],3)) |
|
col2.metric("STOI", value = round(stoi_mass[0],3)) |
|
col3.metric("PLCMOSv1", value = round(PLC_massv1[0],3)) |
|
col4.metric("PLCMOSv2", value = round(PLC_massv2[0],3)) |
|
col5.metric("WAcc", value = round(WAcc_mass[0],3)) |
|
|
|
|
|
st.write("#### "+"С потерями") |
|
col1, col2, col3, col4, col5 = st.columns(5) |
|
col1.metric("PESQ", value = round(psq_mas[1],3), delta = str(round(-(abs(psq_mas[1] - psq_mas[0]) / psq_mas[0]) * 100.0,3))+'%') |
|
col2.metric("STOI", value = round(stoi_mass[1],3), delta = str(round(-(abs(stoi_mass[1] - stoi_mass[0]) / stoi_mass[0]) * 100.0,3))+'%') |
|
col3.metric("PLCMOSv1", value = round(PLC_massv1[1],3), delta = str(round(-(abs(PLC_massv1[1] - PLC_massv1[0]) / PLC_massv1[0]) * 100.0,3))+'%') |
|
col4.metric("PLCMOSv2", value = round(PLC_massv2[1],3), delta = str(round(-(abs(PLC_massv2[1] - PLC_massv2[0]) / PLC_massv2[0]) * 100.0,3))+'%') |
|
col5.metric("WAcc", value = round(WAcc_mass[1],3), delta = str(round(-(abs(WAcc_mass[1] - WAcc_mass[0]) / WAcc_mass[0]) * 100.0,3))+'%') |
|
|
|
|
|
st.write("#### "+"Улучшенное") |
|
col1, col2, col3, col4, col5 = st.columns(5) |
|
PESQ_s = sign_x_y(psq_mas[1], psq_mas[2]) |
|
col1.metric("PESQ", value = round(psq_mas[2],3), delta = PESQ_s + str(round((abs(psq_mas[2] - psq_mas[1]) / psq_mas[1]) * 100.0,3))+'%') |
|
STOI_s = sign_x_y(stoi_mass[1], stoi_mass[2]) |
|
col2.metric("STOI", value = round(stoi_mass[2],3), delta = STOI_s + str(round((abs(stoi_mass[2] - stoi_mass[1]) / stoi_mass[1]) * 100.0,3))+'%') |
|
PLCv1_s = sign_x_y(PLC_massv1[1], PLC_massv1[2]) |
|
col3.metric("PLCMOSv1", value = round(PLC_massv1[2],3), delta = PLCv1_s + str(round((abs(PLC_massv1[2] - PLC_massv1[1]) / PLC_massv1[1]) * 100.0,3))+'%') |
|
PLCv2_s = sign_x_y(PLC_massv2[1], PLC_massv2[2]) |
|
col4.metric("PLCMOSv2", value = round(PLC_massv2[2],3), delta = PLCv2_s + str(round((abs(PLC_massv2[2] - PLC_massv2[1]) / PLC_massv2[1]) * 100.0,3))+'%') |
|
WER_s = sign_x_y(WAcc_mass[1], WAcc_mass[2]) |
|
if WAcc_mass[1]==0: |
|
if WAcc_mass[2]!=0: |
|
col5.metric("WAcc", value = round(WAcc_mass[2],3), delta = WER_s + str(round((abs(WAcc_mass[2] - 0.001) / 0.001) * 100.0,3))+'%') |
|
else: |
|
col5.metric("WAcc", value = round(WAcc_mass[2],3)) |
|
else: |
|
col5.metric("WAcc", value = round(WAcc_mass[2],3), delta = WER_s + str(round((abs(WAcc_mass[2] - WAcc_mass[1]) / WAcc_mass[1]) * 100.0,3))+'%') |
|
|
|
tab1, tab2, tab3, tab4, tab5 = st.tabs(["PESQ", "STOI", "PLCMOSv1", "PLCMOSv2", "WAcc"]) |
|
|
|
with tab1: |
|
st.header("PESQ") |
|
st.bar_chart(df_1, x="Аудио", y="PESQ") |
|
with tab2: |
|
st.header("STOI") |
|
st.bar_chart(df_1, x="Аудио", y="STOI") |
|
with tab3: |
|
st.header("PLCMOSv1") |
|
st.bar_chart(df_1, x="Аудио", y="PLCMOSv1") |
|
with tab4: |
|
st.header("PLCMOSv2") |
|
st.bar_chart(df_1, x="Аудио", y="PLCMOSv2") |
|
with tab5: |
|
st.header("WAcc") |
|
st.bar_chart(df_1, x="Аудио", y="WAcc") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|