# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Yossi Adi (adiyoss)

import argparse
from concurrent.futures import ProcessPoolExecutor
import json
import logging
import sys

import numpy as np
from pesq import pesq
from pystoi import stoi
import torch

from .models.sisnr_loss import cal_loss
from .data.data import Validset
from . import distrib
from .utils import bold, deserialize_model, LogProgress
from .evaluate import _run_metrics

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(
    'Evaluate model automatic selection performance')
parser.add_argument('model_path_2spk',
                    help='Path to 2spk model file created by training')
parser.add_argument('model_path_3spk',
                    help='Path to 3spk model file created by training')
parser.add_argument('model_path_4spk',
                    help='Path to 4spk model file created by training')
parser.add_argument('model_path_5spk',
                    help='Path to 5spk model file created by training')
parser.add_argument(
    'data_dir', help='directory including mix.json, s1.json and s2.json files')
parser.add_argument('--device', default="cuda")
parser.add_argument('--sample_rate', default=8000,
                    type=int, help='Sample rate')
parser.add_argument('--thresh', default=0.001,
                    type=float, help='Threshold for model auto selection')
parser.add_argument('--num_workers', type=int, default=5)
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More logging")
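
# Example invocation (package name, module name, and checkpoint paths below are
# placeholders, not taken from this file). Run as a module so the relative
# imports above resolve, e.g.:
#   python -m <package>.evaluate_auto_select 2spk.th 3spk.th 4spk.th 5spk.th \
#       <data_dir> --sample_rate 8000 --thresh 0.001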


# Test pairwise matching between reference and estimated sources.
def pair_wise(padded_source, estimate_source):
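    """Align estimated sources with the reference sources.

    Computes the inner product between every (reference, estimate) channel
    pair; when the model predicts more channels than there are references,
    the best-matching estimate is kept for each reference channel.
    """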
    # (B, n_est, n_ref): inner product of every estimate/reference channel pair.
    pair_wise = torch.sum(padded_source.unsqueeze(1)
                          * estimate_source.unsqueeze(2), dim=3)
    if estimate_source.shape[1] != padded_source.shape[1]:
        # More estimated channels than references: keep, for each reference,
        # the estimate with the largest inner product.
        idxs = pair_wise.argmax(dim=1)
        new_src = torch.zeros_like(padded_source)
        for b, idx in enumerate(idxs):
            new_src[b] = estimate_source[b][idx]
        estimate_source_pad = new_src
    else:
        estimate_source_pad = estimate_source
    return estimate_source_pad


def evaluate_auto_select(args):
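    """Evaluate all speaker-count models and auto-select one per mixture.

    Models are loaded in order 2/3/4/5 speakers, each mixture is separated by
    every model, and the number of non-silent output channels decides which
    model's (reordered) estimates are scored with SI-SNRi, PESQ and STOI.
    """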
    total_sisnr = 0
    total_pesq = 0
    total_stoi = 0
    total_cnt = 0
    updates = 5

    models = list()
    paths = [args.model_path_2spk, args.model_path_3spk,
             args.model_path_4spk, args.model_path_5spk]
    for path in paths:
        # Load model
        pkg = torch.load(path)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        if 'best_state' in pkg:
            model.load_state_dict(pkg['best_state'])
        logger.debug(model)
        model.eval()
        model.to(args.device)
        models.append(model)

    # Load data
    dataset = Validset(args.data_dir)
    data_loader = distrib.loader(
        dataset, batch_size=1, num_workers=args.num_workers)
    sr = args.sample_rate
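    # y_hat counts how often each model is auto-selected:
    # index 0 -> 2 spk, 1 -> 3 spk, 2 -> 4 spk, 3 -> 5 spk.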
    y_hat = torch.zeros(4)
    pendings = []
    with ProcessPoolExecutor(args.num_workers) as pool:
        with torch.no_grad():
            iterator = LogProgress(logger, data_loader, name="Eval estimates")
            for i, data in enumerate(iterator):
                # Get batch data
                mixture, lengths, sources = [x.to(args.device) for x in data]
                estimated_sources = list()
                reorder_estimated_sources = list()
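                # Separate the mixture with every model; insert(0) keeps both
                # lists ordered from the 5-spk model (index 0) down to the
                # 2-spk model (index 3).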
                for model in models:
                    # Forward pass (the whole loop already runs under no_grad)
                    raw_estimate = model(mixture)[-1]
                    estimate = pair_wise(sources, raw_estimate)
                    sisnr_loss, snr, estimate, reorder_estimate = cal_loss(
                        sources, estimate, lengths)
                    estimated_sources.insert(0, raw_estimate)
                    reorder_estimated_sources.insert(0, reorder_estimate)
                # ============== DETECT NUM. NON-ACTIVE CHANNELS ============== #
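                # Normalize the estimates by their peak amplitude, take each
                # channel's mean power, and treat channels below `thresh` as
                # silent; the count of active channels selects the matching
                # model. Iterate until the selected model agrees with its own
                # active-channel count.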
                selected_idx = 0
                thresh = args.thresh
                max_spk = 5
                mix_spk = 2
                ground = max_spk - mix_spk
                while selected_idx <= ground:
                    vals = torch.mean(
                        (estimated_sources[selected_idx]
                         / torch.abs(estimated_sources[selected_idx]).max()) ** 2,
                        dim=2)
                    new_selected_idx = max_spk - len(vals[vals > thresh])
                    if new_selected_idx == selected_idx:
                        break
                    selected_idx = new_selected_idx
                    if selected_idx < 0:
                        selected_idx = 0
                    elif selected_idx > ground:
                        selected_idx = ground
                y_hat[ground - selected_idx] += 1
                reorder_estimate = reorder_estimated_sources[selected_idx].cpu()
                sources = sources.cpu()
                mixture = mixture.cpu()
                pendings.append(
                    pool.submit(_run_metrics, sources, reorder_estimate, mixture,
                                None, sr=sr))
                total_cnt += sources.shape[0]
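
        # Gather the metric futures submitted above and accumulate totals.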
        for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
            sisnr_i, pesq_i, stoi_i = pending.result()
            total_sisnr += sisnr_i
            total_pesq += pesq_i
            total_stoi += stoi_i

    metrics = [total_sisnr, total_pesq, total_stoi]
    sisnr, pesq, stoi = distrib.average(
        [m / total_cnt for m in metrics], total_cnt)
    logger.info(bold(f'Test set performance: SISNRi={sisnr:.2f} '
                     f'PESQ={pesq:.2f}, STOI={stoi:.2f}.'))
    logger.info(f'Two spks prob: {y_hat[0] / total_cnt}')
    logger.info(f'Three spks prob: {y_hat[1] / total_cnt}')
    logger.info(f'Four spks prob: {y_hat[2] / total_cnt}')
    logger.info(f'Five spks prob: {y_hat[3] / total_cnt}')
    return sisnr, pesq, stoi


def main():
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    sisnr, pesq, stoi = evaluate_auto_select(args)
    json.dump({'sisnr': sisnr, 'pesq': pesq, 'stoi': stoi}, sys.stdout)
    sys.stdout.write('\n')


if __name__ == '__main__':
    main()