|
import os |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from .. import AudioSignal |
|
|
|
|
|
def stoi(
    estimates: AudioSignal,
    references: AudioSignal,
    extended: bool = False,
):
    """Short term objective intelligibility

    Computes the STOI (See [1][2]) of a denoised signal compared to a clean
    signal, The output is expected to have a monotonic relation with the
    subjective speech-intelligibility, where a higher score denotes better
    speech intelligibility. Uses pystoi under the hood.

    Parameters
    ----------
    estimates : AudioSignal
        Denoised speech
    references : AudioSignal
        Clean original speech
    extended : bool, optional
        Whether to use the extended STOI described in [3], by default False

    Returns
    -------
    Tensor[float]
        Short time objective intelligibility measure between clean and
        denoised speech

    References
    ----------
    1.  C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
        Objective Intelligibility Measure for Time-Frequency Weighted Noisy
        Speech', ICASSP 2010, Texas, Dallas.
    2.  C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
        Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
        IEEE Transactions on Audio, Speech, and Language Processing, 2011.
    3.  Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
        Intelligibility of Speech Masked by Modulated Noise Maskers',
        IEEE Transactions on Audio, Speech and Language Processing, 2016.
    """
    # Imported lazily so the package works without pystoi installed.
    import pystoi

    # Clone before to_mono() so the callers' signals are not mutated.
    estimates = estimates.clone().to_mono()
    references = references.clone().to_mono()

    stois = []
    for i in range(estimates.batch_size):
        # pystoi operates on 1-D numpy arrays, one batch item at a time;
        # argument order is (clean, denoised, fs).
        _stoi = pystoi.stoi(
            references.audio_data[i, 0].detach().cpu().numpy(),
            estimates.audio_data[i, 0].detach().cpu().numpy(),
            references.sample_rate,
            extended=extended,
        )
        stois.append(_stoi)
    return torch.from_numpy(np.array(stois))
|
|
|
|
|
def pesq(
    estimates: AudioSignal,
    references: AudioSignal,
    mode: str = "wb",
    target_sr: float = 16000,
):
    """Perceptual Evaluation of Speech Quality (PESQ, ITU-T P.862).

    Computes the PESQ score of degraded speech against a clean reference,
    one score per batch item. Both signals are converted to mono and
    resampled to ``target_sr`` first. Uses the ``pesq`` package under
    the hood.

    Parameters
    ----------
    estimates : AudioSignal
        Degraded AudioSignal
    references : AudioSignal
        Reference AudioSignal
    mode : str, optional
        'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
    target_sr : float, optional
        Target sample rate, by default 16000

    Returns
    -------
    Tensor[float]
        PESQ score: P.862.2 Prediction (MOS-LQO)
    """
    # Imported lazily so the package works without pesq installed.
    from pesq import pesq as pesq_fn

    # Clone before in-place-style ops so the callers' signals are not mutated.
    estimates = estimates.clone().to_mono().resample(target_sr)
    references = references.clone().to_mono().resample(target_sr)

    pesqs = []
    for i in range(estimates.batch_size):
        # pesq_fn's positional argument order is (fs, ref, deg, mode).
        _pesq = pesq_fn(
            estimates.sample_rate,
            references.audio_data[i, 0].detach().cpu().numpy(),
            estimates.audio_data[i, 0].detach().cpu().numpy(),
            mode,
        )
        pesqs.append(_pesq)
    return torch.from_numpy(np.array(pesqs))
|
|
|
|
|
def visqol(
    estimates: AudioSignal,
    references: AudioSignal,
    mode: str = "audio",
):
    """ViSQOL score.

    Scores a degraded signal against a reference with Google's ViSQOL,
    one MOS-LQO value per batch item. Signals are converted to mono and
    resampled to the rate the chosen mode requires (48 kHz for 'audio',
    16 kHz for 'speech').

    Parameters
    ----------
    estimates : AudioSignal
        Degraded AudioSignal
    references : AudioSignal
        Reference AudioSignal
    mode : str, optional
        'audio' or 'speech', by default 'audio'

    Returns
    -------
    Tensor[float]
        ViSQOL score (MOS-LQO)

    Raises
    ------
    ValueError
        If ``mode`` is neither 'audio' nor 'speech'.
    """
    # Imported lazily so the package works without visqol installed.
    from visqol import visqol_lib_py
    from visqol.pb2 import visqol_config_pb2
    # NOTE(review): similarity_result_pb2 is never referenced below —
    # presumably imported for its protobuf-registration side effect; verify
    # before removing.
    from visqol.pb2 import similarity_result_pb2

    # Resolve the per-mode settings up front: sample rate, whether to use
    # the speech-specific scoring path, and which bundled model file to load.
    if mode == "audio":
        target_sr = 48000
        use_speech_scoring = False
        svr_model_path = "libsvm_nu_svr_model.txt"
    elif mode == "speech":
        target_sr = 16000
        use_speech_scoring = True
        svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
    else:
        raise ValueError(f"Unrecognized mode: {mode}")

    config = visqol_config_pb2.VisqolConfig()
    config.audio.sample_rate = target_sr
    config.options.use_speech_scoring = use_speech_scoring
    # Model files ship inside the installed visqol package.
    config.options.svr_model_path = os.path.join(
        os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path
    )

    api = visqol_lib_py.VisqolApi()
    api.Create(config)

    # Clone before in-place-style ops so the callers' signals are not mutated.
    estimates = estimates.clone().to_mono().resample(target_sr)
    references = references.clone().to_mono().resample(target_sr)

    scores = []
    for idx in range(estimates.batch_size):
        # ViSQOL expects float64 1-D arrays: (reference, degraded).
        ref_np = references.audio_data[idx, 0].detach().cpu().numpy().astype(float)
        est_np = estimates.audio_data[idx, 0].detach().cpu().numpy().astype(float)
        scores.append(api.Measure(ref_np, est_np).moslqo)
    return torch.from_numpy(np.array(scores))
|
|