JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
2.26 kB
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
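Compute the unit error rate (UER) of a hypothesis transcription file against a
reference transcription file. UER is the error term of the unsupervised metric
for decoding hyperparameter selection:
$$ alpha * LM_PPL + ViterbiUER(%) * 100 $$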
"""
import argparse
import logging
import sys
import editdistance
# Configure logging as an import-time side effect: the root logger level is
# set to INFO and basicConfig is asked to send log output to stdout.
# NOTE: at INFO level the logger.debug(...) calls in this file are not shown
# unless the level is lowered elsewhere.
logging.root.setLevel(logging.INFO)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Module-level logger shared by the functions below.
logger = logging.getLogger(__name__)
def get_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: a parser with two required options,
        ``-s/--hypo`` (hypothesis transcription) and ``-r/--reference``
        (reference transcription).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Compute the unit error rate (UER) of a hypothesis transcription "
            "against a reference transcription."
        )
    )
    parser.add_argument("-s", "--hypo", help="hypo transcription", required=True)
    parser.add_argument(
        "-r", "--reference", help="reference transcription", required=True
    )
    return parser
def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p):
    """Compute the corpus-level error rate of hypotheses against references.

    Edit distances and reference lengths are summed over every utterance id
    in ``hyp_uid_to_tra``; the result is ``total distance / total reference
    tokens``.

    Args:
        ref_uid_to_tra: mapping from utterance id to reference transcription
            string; each reference is tokenised on whitespace.
        hyp_uid_to_tra: mapping from utterance id to hypothesis transcription
            string. Only the ids present here are scored.
        g2p: optional callable applied to each hypothesis string. When given,
            its output is iterated as tokens: ``"'"`` and ``" "`` tokens are
            dropped and one trailing numeric character is stripped from each
            remaining token (presumably a stress digit on a phoneme --
            confirm against the g2p in use). When ``None`` the hypothesis is
            split on whitespace instead.

    Returns:
        float: the error rate as a fraction (not a percentage). When the
        scored references contain no tokens at all, ``0.0`` is returned if
        there are no errors either and ``float("inf")`` otherwise.

    Raises:
        KeyError: if a hypothesis id has no entry in ``ref_uid_to_tra``.
    """
    d_cnt = 0
    w_cnt = 0
    w_cnt_h = 0
    for uid, hyp_tra in hyp_uid_to_tra.items():
        ref = ref_uid_to_tra[uid].split()
        if g2p is not None:
            hyp = [p for p in g2p(hyp_tra) if p != "'" and p != " "]
            # Slice (p[-1:]) rather than index (p[-1]) so that an empty token
            # is passed through instead of raising IndexError.
            hyp = [p[:-1] if p[-1:].isnumeric() else p for p in hyp]
        else:
            hyp = hyp_tra.split()
        d_cnt += editdistance.eval(ref, hyp)
        w_cnt += len(ref)
        w_cnt_h += len(hyp)

    if w_cnt == 0:
        # No reference tokens to normalise by: avoid ZeroDivisionError.
        return 0.0 if d_cnt == 0 else float("inf")

    wer = float(d_cnt) / w_cnt
    logger.debug(
        (
            f"wer = {wer * 100:.2f}%; num. of ref words = {w_cnt}; "
            f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(ref_uid_to_tra)}"
        )
    )
    return wer
def main():
    """Compute and log the UER between two line-aligned transcription files.

    Line *i* of the hypothesis file is compared with line *i* of the
    reference file. Both are split on whitespace, the edit distances are
    summed, and the total is reported as a percentage of the total number of
    reference tokens.

    If the two files have different numbers of lines, only the aligned prefix
    is scored and a warning is logged. If the reference contains no tokens, a
    warning is logged and the UER is reported as 0 (no errors) or inf.
    """
    # Local import keeps this block self-contained; zip_longest lets us detect
    # a line-count mismatch that plain zip() would silently hide.
    from itertools import zip_longest

    args = get_parser().parse_args()

    errs = 0
    count = 0
    n_lines = 0
    with open(args.hypo, "r") as hf, open(args.reference, "r") as rf:
        for h, r in zip_longest(hf, rf):
            if h is None or r is None:
                logger.warning(
                    f"{args.hypo} and {args.reference} have different numbers "
                    f"of lines; only the first {n_lines} aligned lines are scored"
                )
                break
            h = h.split()
            r = r.split()
            errs += editdistance.eval(r, h)
            count += len(r)
            n_lines += 1

    if count == 0:
        # Nothing to normalise by: avoid ZeroDivisionError.
        logger.warning(f"no reference tokens found in {args.reference}")
        uer = 0.0 if errs == 0 else float("inf")
    else:
        uer = errs / count

    logger.info(f"UER: {uer * 100:.2f}%")
# Run the CLI only when this file is executed as a script, not when imported.
# NOTE: this guard sits above load_tra, so main() runs before load_tra is
# defined; that is fine only as long as main() does not call it.
if __name__ == "__main__":
    main()
def load_tra(tra_path):
    """Load a transcription file into a ``{uid: transcription}`` dict.

    Each line is expected to hold an utterance id followed by whitespace and
    the transcription. The transcription is stored exactly as read, including
    any trailing newline.

    Args:
        tra_path: path of the transcription file to read.

    Returns:
        dict: mapping from utterance id to transcription string. A line that
        has an id but no transcription maps to ``""``; blank lines are
        skipped. If an id occurs more than once, the last occurrence wins.
    """
    uid_to_tra = {}
    with open(tra_path, "r") as f:
        for line in f:
            fields = line.split(None, 1)
            if not fields:
                # Blank line: nothing to record.
                continue
            # An empty transcription yields a single field; unpacking two
            # values directly would raise ValueError.
            uid_to_tra[fields[0]] = fields[1] if len(fields) > 1 else ""
    logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}")
    return uid_to_tra