Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 -u | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
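Compute the unit error rate (UER) of a hypothesis transcription file against a
reference transcription file, e.g. as the UER term of the unsupervised metric
for decoding hyperparameter selection:
    $$ alpha * LM_PPL + ViterbiUER(%) * 100 $$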
""" | |
import argparse | |
import logging | |
import sys | |
import editdistance | |
# Send log records to stdout and make INFO the root threshold; the explicit
# setLevel keeps the threshold at INFO even when basicConfig is a no-op
# because the root logger already has handlers.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.root.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
def get_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with two required options:
        ``-s/--hypo`` and ``-r/--reference``.
    """
    required_options = (
        ("-s", "--hypo", "hypo transcription"),
        ("-r", "--reference", "reference transcription"),
    )
    cli = argparse.ArgumentParser()
    for short_flag, long_flag, description in required_options:
        cli.add_argument(short_flag, long_flag, help=description, required=True)
    return cli
def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p):
    """Compute the corpus-level error rate of hypotheses against references.

    Args:
        ref_uid_to_tra: mapping from utterance id to the reference
            transcription string; it is split on whitespace into units.
        hyp_uid_to_tra: mapping from utterance id to the hypothesis
            transcription string. Only the ids present here are scored.
        g2p: optional callable applied to each hypothesis string; its result
            is iterated as a sequence of string tokens. When ``None`` the
            hypothesis is split on whitespace instead.

    Returns:
        float: summed edit distance divided by the summed number of
        reference units (a fraction, not a percentage).

    Raises:
        KeyError: if a hypothesis id is missing from ``ref_uid_to_tra``.
        ZeroDivisionError: if the scored references contain no units, in
            which case the rate is undefined.
    """
    d_cnt = 0
    w_cnt = 0
    w_cnt_h = 0
    for uid in hyp_uid_to_tra:
        ref = ref_uid_to_tra[uid].split()
        if g2p is not None:
            # Discard apostrophe and space tokens produced by g2p.
            hyp = [p for p in g2p(hyp_uid_to_tra[uid]) if p != "'" and p != " "]
            # Strip one trailing digit from each token. Slicing with p[-1:]
            # (rather than indexing p[-1]) keeps an empty token from raising
            # IndexError.
            hyp = [p[:-1] if p[-1:].isnumeric() else p for p in hyp]
        else:
            hyp = hyp_uid_to_tra[uid].split()
        d_cnt += editdistance.eval(ref, hyp)
        w_cnt += len(ref)
        w_cnt_h += len(hyp)

    if w_cnt == 0:
        # Same exception type as the bare division would raise, but with a
        # message that says what is actually wrong with the input.
        raise ZeroDivisionError(
            "cannot compute WER: the scored reference transcriptions "
            "contain no words"
        )
    wer = float(d_cnt) / w_cnt
    logger.debug(
        (
            f"wer = {wer * 100:.2f}%; num. of ref words = {w_cnt}; "
            f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(ref_uid_to_tra)}"
        )
    )
    return wer
def main():
    """Compute and log the unit error rate (UER) between two text files.

    The hypothesis and reference files given on the command line are read in
    lockstep, one utterance per line. Each line is split on whitespace, the
    edit distance between the two unit sequences is accumulated, and the
    total is divided by the total number of reference units. The result is
    logged at INFO level as a percentage.

    Raises:
        ZeroDivisionError: if the reference file contributes no units, in
            which case the UER is undefined.
    """
    # Local import so this function does not depend on the file's import
    # block being changed.
    from itertools import zip_longest

    args = get_parser().parse_args()

    errs = 0
    count = 0
    num_pairs = 0
    length_mismatch = False
    with open(args.hypo, "r") as hf, open(args.reference, "r") as rf:
        # zip_longest (instead of zip) makes it possible to notice when one
        # file runs out of lines before the other.
        for h, r in zip_longest(hf, rf):
            if h is None or r is None:
                length_mismatch = True
                break
            h = h.split()
            r = r.split()
            errs += editdistance.eval(r, h)
            count += len(r)
            num_pairs += 1

    if length_mismatch:
        # Keep scoring the aligned prefix as before, but make the
        # truncation visible instead of silently reporting a partial UER.
        logger.warning(
            f"{args.hypo} and {args.reference} have different numbers of "
            f"lines; only the first {num_pairs} line pairs were scored"
        )

    if count == 0:
        raise ZeroDivisionError(
            f"cannot compute UER: no reference units found in {args.reference}"
        )
    logger.info(f"UER: {errs / count * 100:.2f}%")
# Script entry point: run the UER computation only when this file is executed
# directly, not when it is imported. Definitions placed after this guard are
# not yet bound while main() runs as a script.
if __name__ == "__main__":
    main()
def load_tra(tra_path):
    """Load a transcription file into a dict keyed by utterance id.

    Each non-blank line is expected to look like ``<uid> <transcription>``:
    the first whitespace-delimited field is the utterance id and the rest of
    the line (including its trailing newline) is stored as the transcription.

    Args:
        tra_path: path of the transcription file to read.

    Returns:
        dict: utterance id -> transcription string. A line that holds only an
        utterance id maps to the empty string. If an id appears more than
        once, the last occurrence wins.
    """
    uid_to_tra = {}
    with open(tra_path, "r") as f:
        for line in f:
            fields = line.split(None, 1)
            if not fields:
                # Blank line: nothing to record.
                continue
            # A line with an id but no text is an empty transcription, not a
            # malformed line; tuple-unpacking the split would raise ValueError.
            uid_to_tra[fields[0]] = fields[1] if len(fields) > 1 else ""
    logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}")
    return uid_to_tra