Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf and Alexandre Defossez (adefossez)
import argparse | |
from concurrent.futures import ProcessPoolExecutor | |
import json | |
import logging | |
import sys | |
import numpy as np | |
from pesq import pesq | |
from pystoi import stoi | |
import torch | |
from .models.sisnr_loss import cal_loss | |
from .data.data import Validset | |
from . import distrib | |
from .utils import bold, deserialize_model, LogProgress | |
# Module-level logger; its level and stream are configured in ``main``.
logger = logging.getLogger(__name__)

# Command-line interface of the evaluation script.
# The summary sentence is passed as ``description``: the first positional
# parameter of ``ArgumentParser`` is ``prog``, so passing the sentence
# positionally would replace the program name in the usage line instead of
# describing the tool.
parser = argparse.ArgumentParser(
    description='Evaluate separation performance using MulCat blocks')
parser.add_argument('model_path',
                    help='Path to model file created by training')
parser.add_argument('data_dir',
                    help='directory including mix.json, s1.json, s2.json, ... files')
parser.add_argument('--device', default="cuda")
# NOTE(review): ``--sdr`` is not read anywhere in the visible code; it is
# kept so that existing command lines keep parsing.
parser.add_argument('--sdr', type=int, default=0)
parser.add_argument('--sample_rate', default=16000,
                    type=int, help='Sample rate')
parser.add_argument('--num_workers', type=int, default=5)
# ``-v`` stores a logging level directly, so ``args.verbose`` can be handed
# to ``logging.basicConfig(level=...)`` unchanged.
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More logging")
def evaluate(args, model=None, data_loader=None, sr=None):
    """Run the separation model over a dataset and report averaged metrics.

    Args:
        args: parsed options. This function reads ``model_path``, ``device``,
            ``data_dir``, ``num_workers`` and ``sample_rate`` from it.
        model: model to evaluate. When ``None`` it is loaded from
            ``args.model_path``.
        data_loader: iterable yielding ``(mixture, lengths, sources)``
            batches. When ``None`` a loader over ``Validset(args.data_dir)``
            with batch size 1 is created and ``sr`` is set to
            ``args.sample_rate``.
        sr: sample rate forwarded to the metric workers.

    Returns:
        The ``(sisnr, pesq, stoi)`` triple returned by ``distrib.average``.
        Metric jobs are submitted to ``_run_metrics`` without enabling its
        ``pesq`` flag, so the PESQ and STOI totals only ever accumulate the
        zeros that function returns in that mode.
    """
    total_sisnr = 0
    total_pesq = 0
    total_stoi = 0
    total_cnt = 0
    updates = 5

    # Load model. ``is None`` is used instead of truthiness so that an
    # explicitly supplied object is never replaced just because it happens
    # to evaluate as false (e.g. a container defining ``__len__``).
    if model is None:
        # NOTE: ``torch.load`` unpickles the file; only load checkpoints
        # that come from a trusted source.
        pkg = torch.load(args.model_path, map_location=args.device)
        model = deserialize_model(pkg['model'] if 'model' in pkg else pkg)
        if 'best_state' in pkg:
            model.load_state_dict(pkg['best_state'])
    logger.debug(model)
    model.eval()
    model.to(args.device)

    # Load data
    if data_loader is None:
        dataset = Validset(args.data_dir)
        data_loader = distrib.loader(
            dataset, batch_size=1, num_workers=args.num_workers)
        sr = args.sample_rate

    pendings = []
    with ProcessPoolExecutor(args.num_workers) as pool:
        # A single no_grad context covers the forward pass and the loss.
        with torch.no_grad():
            iterator = LogProgress(logger, data_loader, name="Eval estimates")
            for data in iterator:
                # Get batch data
                mixture, lengths, sources = [x.to(args.device) for x in data]

                # Forward: scale the mixture in place by its maximum value
                # and keep the last element of the model output.
                mixture /= mixture.max()
                estimate = model(mixture)[-1]
                _, _, _, reorder_estimate = cal_loss(sources, estimate, lengths)

                # Move everything to the CPU before handing it to a worker.
                reorder_estimate = reorder_estimate.cpu()
                sources = sources.cpu()
                mixture = mixture.cpu()

                pendings.append(
                    pool.submit(_run_metrics, sources, reorder_estimate, mixture, None,
                                sr=sr))
                # Count the items of every batch so the division below
                # yields per-item averages.
                total_cnt += sources.shape[0]

        for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
            sisnr_i, pesq_i, stoi_i = pending.result()
            total_sisnr += sisnr_i
            total_pesq += pesq_i
            total_stoi += stoi_i

    metrics = [total_sisnr, total_pesq, total_stoi]
    # Local names avoid shadowing the imported ``pesq`` / ``stoi`` functions.
    avg_sisnr, avg_pesq, avg_stoi = distrib.average(
        [m / total_cnt for m in metrics], total_cnt)
    logger.info(
        bold(f'Test set performance: SISNRi={avg_sisnr:.2f} PESQ={avg_pesq}, STOI={avg_stoi}.'))
    return avg_sisnr, avg_pesq, avg_stoi
def _run_metrics(clean, estimate, mix, model, sr, pesq=False):
    """Compute the metrics for one batch; intended to run in a worker process.

    Args:
        clean: reference sources; converted with ``.numpy()``.
        estimate: separated sources, or the model input when ``model`` is
            given; converted with ``.numpy()`` after the optional forward pass.
        mix: input mixture; converted with ``.numpy()``.
        model: optional model. When not ``None`` it is applied to
            ``estimate`` and the last element of its output is used.
        sr: sample rate forwarded to ``cal_PESQ`` and ``cal_STOI``.
        pesq: when true, PESQ and STOI are computed as well; otherwise both
            are reported as ``0``.

    Returns:
        Tuple ``(sisnr_improvement, pesq_value, stoi_value)``.
    """
    if model is not None:
        # Keep each worker on a single torch thread while it runs the model.
        torch.set_num_threads(1)
        with torch.no_grad():
            estimate = model(estimate)[-1]

    estimate_np, clean_np, mix_np = estimate.numpy(), clean.numpy(), mix.numpy()
    improvement = cal_SISNRi(clean_np, estimate_np, mix_np)

    # Zero placeholders unless the perceptual metrics were requested.
    pesq_i = stoi_i = 0
    if pesq:
        pesq_i = cal_PESQ(clean_np, estimate_np, sr=sr)
        stoi_i = cal_STOI(clean_np, estimate_np, sr=sr)
    return improvement.mean(), pesq_i, stoi_i
def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    """Compute the Scale-Invariant Source-to-Noise Ratio (SI-SNR) in dB.

    Args:
        ref_sig: numpy.ndarray, [B, T], reference signals
        out_sig: numpy.ndarray, [B, T], signals to score
        eps: small constant added to the energies and to the ratio to avoid
            division by zero and ``log(0)``
    Returns:
        SI-SNR averaged over the batch dimension (a scalar)
    """
    assert len(ref_sig) == len(out_sig)
    batch, _ = ref_sig.shape
    column = (batch, 1)

    # Remove the per-row mean from both signals.
    target = ref_sig - np.mean(ref_sig, axis=1).reshape(column)
    estimate = out_sig - np.mean(out_sig, axis=1).reshape(column)

    # Project the estimate onto the reference; the remainder is the noise.
    target_energy = (np.sum(target ** 2, axis=1) + eps).reshape(column)
    correlation = np.sum(target * estimate, axis=1).reshape(column)
    projection = correlation * target / target_energy
    residual = estimate - projection

    signal_power = np.sum(projection ** 2, axis=1)
    noise_power = np.sum(residual ** 2, axis=1) + eps
    per_item_db = 10 * np.log(signal_power / noise_power + eps) / np.log(10.0)
    return per_item_db.mean()
def cal_PESQ(ref_sig, out_sig, sr):
    """Average the PESQ score over every (batch, source) pair.

    Args:
        ref_sig: numpy.ndarray, [B, C, T], reference signals
        out_sig: numpy.ndarray, [B, C, T], signals to score
        sr: sample rate passed to ``pesq``
    Returns:
        Mean PESQ over the B*C signal pairs; ``pesq`` is always called with
        the mode string 'nb'
    """
    B, C, T = ref_sig.shape
    pair_count = B * C
    # Flatten batch and source axes so each row is one signal.
    flat_ref = ref_sig.reshape(pair_count, T)
    flat_out = out_sig.reshape(pair_count, T)
    total = sum(
        pesq(sr, ref_row, out_row, 'nb')
        for ref_row, out_row in zip(flat_ref, flat_out)
    )
    return total / pair_count
def cal_STOI(ref_sig, out_sig, sr):
    """Average the STOI score over every (batch, source) pair.

    Args:
        ref_sig: numpy.ndarray, [B, C, T], reference signals
        out_sig: numpy.ndarray, [B, C, T], signals to score
        sr: sample rate passed to ``stoi``
    Returns:
        Mean STOI over the B*C signal pairs, or 0 when the computation
        fails (best-effort metric; the failure is logged)
    """
    B, C, T = ref_sig.shape
    ref_sig = ref_sig.reshape(B*C, T)
    out_sig = out_sig.reshape(B*C, T)
    try:
        stoi_val = 0
        for ref_row, out_row in zip(ref_sig, out_sig):
            stoi_val += stoi(ref_row, out_row, sr, extended=False)
        return stoi_val / (B*C)
    except Exception:
        # Keep the best-effort fallback, but only for ordinary errors:
        # a bare ``except`` would also swallow KeyboardInterrupt and
        # SystemExit. Log the failure so a zero score is not silent.
        logging.getLogger(__name__).warning(
            "STOI computation failed; reporting 0", exc_info=True)
        return 0
def cal_SISNRi(src_ref, src_est, mix):
    """Compute the SI-SNR improvement (SI-SNRi) of the estimates over the mixture.

    Args:
        src_ref: numpy.ndarray, [B, C, T], reference sources
        src_est: numpy.ndarray, [B, C, T], estimates in the same source
            order as ``src_ref``
        mix: numpy.ndarray, the input mixture. It is scored by ``cal_SISNR``
            against each ``src_ref[:, c]``, so it needs the same 2-D
            [B, T] layout as one source.
    Returns:
        SI-SNRi averaged over the C sources
    """
    _, num_sources, _ = src_ref.shape
    gains = []
    for idx in range(num_sources):
        reference = src_ref[:, idx]
        separated = cal_SISNR(reference, src_est[:, idx])
        # Baseline: how well the unprocessed mixture matches this source.
        baseline = cal_SISNR(reference, mix)
        gains.append(separated - baseline)
    return sum(gains, 0.0) / num_sources
def main():
    """Command-line entry point.

    Parses the arguments, sends logging to stderr at the requested level,
    runs ``evaluate`` and writes the three metrics to stdout as a single
    JSON object followed by a newline.
    """
    options = parser.parse_args()
    logging.basicConfig(level=options.verbose, stream=sys.stderr)
    logger.debug(options)

    sisnr_value, pesq_value, stoi_value = evaluate(options)
    report = {
        'sisnr': sisnr_value,
        'pesq': pesq_value,
        'stoi': stoi_value,
    }
    json.dump(report, sys.stdout)
    sys.stdout.write('\n')


# Run the evaluation only when the module is executed as a script.
if __name__ == '__main__':
    main()