# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adiyoss

import argparse
from concurrent.futures import ProcessPoolExecutor
import json
import logging
import sys

from pesq import pesq
from pystoi import stoi
import torch

from .data import NoisyCleanSet
from .enhance import add_flags, get_estimate
from . import distrib, pretrained
from .utils import bold, LogProgress

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(
        'denoiser.evaluate',
        description='Speech enhancement using Demucs - Evaluate model performance')
add_flags(parser)
parser.add_argument('--data_dir', help='directory including noisy.json and clean.json files')
parser.add_argument('--matching', default="sort", help='set this to dns for the dns dataset.')
parser.add_argument('--no_pesq', action="store_false", dest="pesq", default=True,
                    help="Don't compute PESQ.")
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More logging")
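
# Example invocation (hypothetical paths; the data directory must contain the
# noisy.json and clean.json manifests referenced by --data_dir):
#
#   python -m denoiser.evaluate --data_dir <path/to/manifests> --matching dns -v
#
# PESQ can be skipped with --no_pesq when only STOI is needed.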


def evaluate(args, model=None, data_loader=None):
    total_pesq = 0
    total_stoi = 0
    total_cnt = 0
    updates = 5

    # Load model
    if not model:
        model = pretrained.get_model(args).to(args.device)
    model.eval()

    # Load data
    if data_loader is None:
        dataset = NoisyCleanSet(args.data_dir, matching=args.matching, sample_rate=args.sample_rate)
        data_loader = distrib.loader(dataset, batch_size=1, num_workers=2)
    pendings = []
    with ProcessPoolExecutor(args.num_workers) as pool:
        with torch.no_grad():
            iterator = LogProgress(logger, data_loader, name="Eval estimates")
            for i, data in enumerate(iterator):
                # Get batch data
                noisy, clean = [x.to(args.device) for x in data]
                # If device is CPU, we do parallel evaluation in each CPU worker.
                if args.device == 'cpu':
                    pendings.append(
                        pool.submit(_estimate_and_run_metrics, clean, model, noisy, args))
                else:
                    # On GPU, run enhancement in the main process and only offload
                    # the metric computation to the worker pool.
                    estimate = get_estimate(model, noisy, args)
                    estimate = estimate.cpu()
                    clean = clean.cpu()
                    pendings.append(
                        pool.submit(_run_metrics, clean, estimate, args))
                total_cnt += clean.shape[0]

        # Collect the per-batch metric sums from the worker pool.
        for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
            pesq_i, stoi_i = pending.result()
            total_pesq += pesq_i
            total_stoi += stoi_i

    # Normalize by the number of utterances and average across distributed workers.
    metrics = [total_pesq, total_stoi]
    pesq, stoi = distrib.average([m / total_cnt for m in metrics], total_cnt)
    logger.info(bold(f'Test set performance: PESQ={pesq}, STOI={stoi}.'))
    return pesq, stoi
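
# Programmatic use (sketch; `my_model` and `my_loader` are hypothetical): a caller that
# already holds a model and a data loader can pass them in so nothing is reloaded:
#
#   pesq_score, stoi_score = evaluate(args, model=my_model, data_loader=my_loader)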


def _estimate_and_run_metrics(clean, model, noisy, args):
    # CPU path: run the enhancement model and the metrics inside the same worker process.
    estimate = get_estimate(model, noisy, args)
    return _run_metrics(clean, estimate, args)


def _run_metrics(clean, estimate, args):
    # Keep only the first channel; signals are expected to be mono [B, 1, T] tensors.
    estimate = estimate.numpy()[:, 0]
    clean = clean.numpy()[:, 0]
    if args.pesq:
        pesq_i = get_pesq(clean, estimate, sr=args.sample_rate)
    else:
        pesq_i = 0
    stoi_i = get_stoi(clean, estimate, sr=args.sample_rate)
    return pesq_i, stoi_i


def get_pesq(ref_sig, out_sig, sr):
    """Calculate PESQ.
    Args:
        ref_sig: numpy.ndarray, [B, T]
        out_sig: numpy.ndarray, [B, T]
    Returns:
        PESQ, summed over the batch (evaluate() divides by the utterance count).
    """
    pesq_val = 0
    for i in range(len(ref_sig)):
        pesq_val += pesq(sr, ref_sig[i], out_sig[i], 'wb')
    return pesq_val


def get_stoi(ref_sig, out_sig, sr):
    """Calculate STOI.
    Args:
        ref_sig: numpy.ndarray, [B, T]
        out_sig: numpy.ndarray, [B, T]
    Returns:
        STOI, summed over the batch (evaluate() divides by the utterance count).
    """
    stoi_val = 0
    for i in range(len(ref_sig)):
        stoi_val += stoi(ref_sig[i], out_sig[i], sr, extended=False)
    return stoi_val
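
# Minimal usage sketch (hypothetical file names): both helpers expect numpy arrays of
# shape [B, T] holding time-domain speech and return the sum of per-utterance scores.
#
#   import soundfile as sf
#   clean, sr = sf.read('clean_example.wav')          # mono speech, e.g. 16 kHz for 'wb' PESQ
#   denoised, _ = sf.read('denoised_example.wav')
#   print(get_pesq(clean[None], denoised[None], sr=sr))
#   print(get_stoi(clean[None], denoised[None], sr=sr))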


def main():
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    pesq, stoi = evaluate(args)
    # Report the scores as a single JSON object on stdout so the output is easy to parse.
    json.dump({'pesq': pesq, 'stoi': stoi}, sys.stdout)
    sys.stdout.write('\n')


if __name__ == '__main__':
    main()