Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import math | |
import os | |
import subprocess | |
import sys | |
import tempfile | |
from collections import defaultdict | |
from itertools import combinations | |
def read_translations(path, n_repeats): | |
segment_counter = 0 | |
segment_translations = [] | |
translations = defaultdict(list) | |
for line in open(path): | |
segment_translations.append(" ".join(line.split())) | |
if len(segment_translations) == n_repeats: | |
translations[segment_counter] = segment_translations | |
segment_translations = [] | |
segment_counter += 1 | |
return translations | |
def generate_input(translations, n_repeats): | |
_, ref_path = tempfile.mkstemp() | |
_, mt_path = tempfile.mkstemp() | |
ref_fh = open(ref_path, "w") | |
mt_fh = open(mt_path, "w") | |
for segid in sorted(translations.keys()): | |
assert len(translations[segid]) == n_repeats | |
indexes = combinations(range(n_repeats), 2) | |
for idx1, idx2 in indexes: | |
mt_fh.write(translations[segid][idx1].strip() + "\n") | |
ref_fh.write(translations[segid][idx2].strip() + "\n") | |
sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path)) | |
return ref_path, mt_path | |
def run_meteor(ref_path, mt_path, metric_path, lang="en"): | |
_, out_path = tempfile.mkstemp() | |
subprocess.call( | |
[ | |
"java", | |
"-Xmx2G", | |
"-jar", | |
metric_path, | |
mt_path, | |
ref_path, | |
"-p", | |
"0.5 0.2 0.6 0.75", # default parameters, only changed alpha to give equal weight to P and R | |
"-norm", | |
"-l", | |
lang, | |
], | |
stdout=open(out_path, "w"), | |
) | |
os.remove(ref_path) | |
os.remove(mt_path) | |
sys.stderr.write("\nSaved Meteor output to %s" % out_path) | |
return out_path | |
def read_output(meteor_output_path, n_repeats): | |
n_combinations = math.factorial(n_repeats) / ( | |
math.factorial(2) * math.factorial(n_repeats - 2) | |
) | |
raw_scores = [] | |
average_scores = [] | |
for line in open(meteor_output_path): | |
if not line.startswith("Segment "): | |
continue | |
score = float(line.strip().split("\t")[1]) | |
raw_scores.append(score) | |
if len(raw_scores) == n_combinations: | |
average_scores.append(sum(raw_scores) / n_combinations) | |
raw_scores = [] | |
os.remove(meteor_output_path) | |
return average_scores | |
def main(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument("-i", "--infile") | |
parser.add_argument("-n", "--repeat_times", type=int) | |
parser.add_argument("-m", "--meteor") | |
parser.add_argument("-o", "--output") | |
args = parser.parse_args() | |
translations = read_translations(args.infile, args.repeat_times) | |
sys.stderr.write("\nGenerating input for Meteor...") | |
ref_path, mt_path = generate_input(translations, args.repeat_times) | |
sys.stderr.write("\nRunning Meteor...") | |
out_path = run_meteor(ref_path, mt_path, args.meteor) | |
sys.stderr.write("\nReading output...") | |
scores = read_output(out_path, args.repeat_times) | |
sys.stderr.write("\nWriting results...") | |
with open(args.output, "w") as o: | |
for scr in scores: | |
o.write("{}\n".format(scr)) | |
o.close() | |
if __name__ == "__main__": | |
main() | |