# Page header captured together with the file (not part of the program):
# JustinLin610 / update / 10b0761 / raw / history blame / 4.15 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Signal processing-based evaluation using waveforms
"""
import csv
import numpy as np
import os.path as op
import torch
import tqdm
from tabulate import tabulate
import torchaudio
from examples.speech_synthesis.utils import batch_mel_spectral_distortion
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
def load_eval_spec(path):
    """Load an evaluation spec from a tab-separated file.

    The first row of the file is taken as the header; every following row
    becomes one dict keyed by the header names.

    Args:
        path: Path of the TSV file to read.

    Returns:
        A list with one dict per data row, mapping column name to the string
        value of that cell.
    """
    # newline="" hands line-ending handling to the csv module, as its
    # documentation requires; without it, quoted fields containing "\r\n"
    # are silently rewritten. The encoding is pinned to UTF-8 so parsing does
    # not depend on the platform's default locale encoding.
    with open(path, newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f, delimiter="\t"))
def eval_distortion(samples, distortion_fn, device="cuda"):
    """Score each (reference, synthesized) waveform pair with `distortion_fn`.

    Args:
        samples: Iterable of mappings whose "ref" and "syn" entries are the
            paths of the reference and the synthesized audio file.
        distortion_fn: Callable invoked as
            ``distortion_fn([yref], [ysyn], sr, None)``. The first element of
            its return value must be a ``(distortion, extra)`` pair, where
            ``extra`` is a 6-tuple whose last item is the alignment path map.
        device: Torch device the waveforms are moved to before scoring.

    Returns:
        A list with one entry per sample, in input order. An entry is ``None``
        when either audio file does not exist; otherwise it is the tuple
        ``(distortion, ref_frames, syn_frames, path_length, insertions,
        deletions)``.

    Raises:
        AssertionError: If the two files of a pair have different sample
            rates.
    """
    import warnings

    nmiss = 0
    results = []
    for sample in tqdm.tqdm(samples):
        ref_path, syn_path = sample["ref"], sample["syn"]
        if not (op.isfile(ref_path) and op.isfile(syn_path)):
            nmiss += 1
            results.append(None)
            continue

        yref, sr = torchaudio.load(ref_path)
        ysyn, _sr = torchaudio.load(syn_path)
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type and message are the
        # same as before. Checked before the device transfer so a mismatch
        # fails without first copying the audio.
        if sr != _sr:
            raise AssertionError(f"{sr} != {_sr}")

        # assume single channel: only the first channel of each file is scored
        yref, ysyn = yref[0].to(device), ysyn[0].to(device)

        distortion, extra = distortion_fn([yref], [ysyn], sr, None)[0]
        # NOTE(review): pathmap is presumably a 0/1 alignment matrix of shape
        # (ref frames, syn frames) -- confirm against distortion_fn.
        _, _, _, _, _, pathmap = extra
        nins = torch.sum(pathmap.sum(dim=1) - 1)  # extra frames in syn
        ndel = torch.sum(pathmap.sum(dim=0) - 1)  # missing frames from syn
        results.append(
            (
                distortion.item(),  # path distortion
                pathmap.size(0),  # yref num frames
                pathmap.size(1),  # ysyn num frames
                pathmap.sum().item(),  # path length
                nins.item(),  # insertion
                ndel.item(),  # deletion
            )
        )

    # The skipped pairs were previously counted but never reported, so an
    # evaluation over a partial set looked identical to a complete one.
    if nmiss:
        warnings.warn(
            f"{nmiss} of {len(results)} samples skipped: "
            "missing ref or syn audio file"
        )
    return results
def eval_mel_cepstral_distortion(samples, device="cuda"):
    """Run `eval_distortion` with `batch_mel_cepstral_distortion` as metric.

    Args:
        samples: Passed through unchanged to `eval_distortion`.
        device: Passed through unchanged to `eval_distortion`.

    Returns:
        Whatever `eval_distortion` returns for these arguments.
    """
    return eval_distortion(
        samples,
        distortion_fn=batch_mel_cepstral_distortion,
        device=device,
    )
def eval_mel_spectral_distortion(samples, device="cuda"):
    """Run `eval_distortion` with `batch_mel_spectral_distortion` as metric.

    Args:
        samples: Passed through unchanged to `eval_distortion`.
        device: Passed through unchanged to `eval_distortion`.

    Returns:
        Whatever `eval_distortion` returns for these arguments.
    """
    return eval_distortion(
        samples,
        distortion_fn=batch_mel_spectral_distortion,
        device=device,
    )
def print_results(results, show_bin):
    """Print aggregate distortion statistics as a table.

    Args:
        results: Sequence of per-utterance entries. Each entry is either
            ``None`` (ignored) or a 6-item row laid out as
            ``(distortion, ref_frames, syn_frames, path_length, insertions,
            deletions)``.
        show_bin: When truthy, additionally print one table per bin of
            utterances grouped by their reference frame count (column 1).

    Returns:
        None. All output goes to stdout.
    """
    # Kept from the original implementation: this changes NumPy's global
    # print precision.
    np.set_printoptions(precision=3)

    valid = [row for row in results if row is not None]
    if not valid:
        # With no rows the array would be 1-D and its axis-0 sum a scalar,
        # which cannot be unpacked into the six totals below.
        print("No valid results to report.")
        return
    table = np.array(valid)

    def _print_result(rows):
        """Print one table with the column totals and per-frame ratios."""
        dist, dur_ref, dur_syn, dur_ali, nins, ndel = rows.sum(axis=0)
        res = {
            "nutt": len(rows),
            "dist": dist,
            "dur_ref": int(dur_ref),
            "dur_syn": int(dur_syn),
            "dur_ali": int(dur_ali),
            "dist_per_ref_frm": dist / dur_ref,
            "dist_per_syn_frm": dist / dur_syn,
            "dist_per_ali_frm": dist / dur_ali,
            "ins": nins / dur_ref,
            "del": ndel / dur_ref,
        }
        print(tabulate([list(res.values())], list(res.keys()), floatfmt=".4f"))

    print(">>>> ALL")
    _print_result(table)

    if show_bin:
        # Bins are half-open: lower edge included, upper edge excluded. The
        # final open-ended bin keeps utterances with 4000 or more reference
        # frames from being silently left out of the per-bin report.
        edges = [0, 200, 400, 600, 800, 1000, 2000, 4000, float("inf")]
        ref_frames = table[:, 1]
        for lo, hi in zip(edges[:-1], edges[1:]):
            mask = np.logical_and(ref_frames >= lo, ref_frames < hi)
            if not mask.any():
                continue
            print(f">>>> ({lo}, {hi})")
            _print_result(table[mask])
def main(eval_spec, mcd, msd, show_bin, device="cpu"):
    """Load an evaluation spec and print the requested distortion metrics.

    Args:
        eval_spec: Path of the spec file, passed to `load_eval_spec`.
        mcd: When truthy, evaluate and print mel cepstral distortion.
        msd: When truthy, evaluate and print mel spectral distortion.
        show_bin: Forwarded to `print_results`.
        device: Device string forwarded to the evaluation functions. Defaults
            to "cpu", the value that was previously hard-coded.

    Returns:
        None. Results are printed to stdout.
    """
    samples = load_eval_spec(eval_spec)
    if mcd:
        # Banner corrected from "Mean" to "Mel" to match the metric that
        # eval_mel_cepstral_distortion computes.
        print("===== Evaluate Mel Cepstral Distortion =====")
        results = eval_mel_cepstral_distortion(samples, device)
        print_results(results, show_bin)
    if msd:
        # Same "Mean" -> "Mel" correction for the spectral metric.
        print("===== Evaluate Mel Spectral Distortion =====")
        results = eval_mel_spectral_distortion(samples, device)
        print_results(results, show_bin)
if __name__ == "__main__":
    # Command-line entry point: one positional spec path plus three boolean
    # switches, all handed to main() positionally.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("eval_spec")
    for switch in ("--mcd", "--msd", "--show-bin"):
        cli.add_argument(switch, action="store_true")
    opts = cli.parse_args()
    main(opts.eval_spec, opts.mcd, opts.msd, opts.show_bin)