Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of | |
candidate hypotheses. | |
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade" | |
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_. | |
""" | |
import argparse | |
import random | |
import sys | |
from itertools import chain | |
import numpy as np | |
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu | |
def main():
    """Command-line entry point.

    ``--sys FILE...``: load system-output files, print pairwise BLEU among the
    hypotheses of each sentence and, if ``--output`` is given, write a
    readable merged dump there.

    ``--ref FILE``: load a multi-reference file. Together with ``--sys`` it
    reports multi-reference statistics for the hypotheses; on its own it
    reports statistics among the references themselves.
    """
    cli = argparse.ArgumentParser(sys.argv[0])
    cli.add_argument(
        "--sys", nargs="*", default="", metavar="FILE", help="path to system output"
    )
    cli.add_argument("--ref", default="", metavar="FILE", help="path to references")
    cli.add_argument(
        "--output",
        default="",
        metavar="FILE",
        help="print outputs into a pretty format",
    )
    opts = cli.parse_args()

    have_sys = bool(opts.sys)
    if have_sys:
        src, tgt, hypos, log_probs = load_sys(opts.sys)
        print("pairwise BLEU: %.2f" % pairwise(hypos))
        if opts.output:
            merge(src, tgt, hypos, log_probs, opts.output)

    if not opts.ref:
        return
    _, _, refs = load_ref(opts.ref)
    if have_sys:
        multi_ref(refs, hypos)
    else:
        intra_ref(refs)
def dictolist(d):
    """Return the values of mapping *d* ordered by ascending key."""
    return [d[key] for key in sorted(d)]
def load_sys(paths):
    """Parse system-output files into per-sentence lists.

    Only tab-separated lines with these prefixes are used; all other lines
    are ignored:

    * ``S-<id>\\t<source>``
    * ``T-<id>\\t<target>``
    * ``D-<id>\\t<log-prob>\\t<detokenized hypothesis>`` (may repeat per id)

    Files are read as UTF-8. A later ``S-``/``T-`` line for the same id
    overwrites an earlier one. ``D-`` lines accumulate in file order across
    all *paths*. An empty text field yields ``""``.

    Args:
        paths: iterable of file paths.

    Returns:
        ``(src, tgt, hypos, log_probs)``, each ordered by ascending id:
        ``src``/``tgt`` are lists of strings, ``hypos`` holds the list of
        hypothesis strings for each id, and ``log_probs`` holds the matching
        lists of floats.

    Raises:
        ValueError: if an id or a log-prob field is not numeric.
        IndexError: if a ``D-`` line has no log-prob field.
    """
    src, tgt, hypos, log_probs = {}, {}, {}, {}
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line in f:
                fields = line.rstrip().split("\t")
                tag = fields[0]
                if not tag.startswith(("S-", "T-", "D-")):
                    continue
                # The id follows the two-character prefix. Taking it from the
                # first field (instead of slicing up to the first tab) keeps it
                # correct when rstrip() has removed the tab before an empty
                # text field.
                i = int(tag[2:])
                if tag.startswith("S-"):
                    src[i] = fields[1] if len(fields) > 1 else ""
                elif tag.startswith("T-"):
                    tgt[i] = fields[1] if len(fields) > 1 else ""
                else:  # "D-"
                    score = float(fields[1])
                    # An empty hypothesis leaves no third field after rstrip().
                    hypo = fields[2] if len(fields) > 2 else ""
                    hypos.setdefault(i, []).append(hypo)
                    log_probs.setdefault(i, []).append(score)
    return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
def load_ref(path):
    """Parse a multi-reference file.

    ``S-`` and ``T-`` lines carry their text in the second tab-separated
    field. Each run of consecutive lines starting with ``R`` forms one group
    of references, again taking the second tab-separated field of every
    line. Any other line (for example a blank line) is skipped. The file is
    read as UTF-8 and trailing whitespace is stripped from every extracted
    string.

    Args:
        path: path of the reference file.

    Returns:
        ``(src, tgt, refs)``: lists of source strings, target strings, and
        reference groups (lists of strings), each in file order.
    """
    with open(path, encoding="utf-8") as f:
        lines = f.readlines()
    src, tgt, refs = [], [], []
    i = 0
    n = len(lines)
    while i < n:
        line = lines[i]
        if line.startswith("S-"):
            src.append(line.split("\t")[1].rstrip())
            i += 1
        elif line.startswith("T-"):
            tgt.append(line.split("\t")[1].rstrip())
            i += 1
        elif line.startswith("R"):
            group = []
            while i < n and lines[i].startswith("R"):
                group.append(lines[i].split("\t")[1].rstrip())
                i += 1
            refs.append(group)
        else:
            # Unrecognised lines must still advance ``i``; otherwise the
            # outer loop would never terminate.
            i += 1
    return src, tgt, refs
def merge(src, tgt, hypos, log_probs, path):
    """Write sources, targets and scored hypotheses to *path* (UTF-8).

    For each sentence the output contains the source line, the target line
    and a blank line. Then comes one ``\\t<log-prob>\\t<hypothesis>`` line per
    hypothesis, with the hypothesis stripped and the log-prob formatted
    ``%f``. A dashed separator line ends the sentence. ``zip`` silently
    truncates to the shortest of the inputs.
    """
    # Explicit UTF-8 so non-ASCII text can be written regardless of locale,
    # matching how the input files are read.
    with open(path, "w", encoding="utf-8") as f:
        for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
            f.write(s + "\n")
            f.write(t + "\n")
            f.write("\n")
            for h, lp in zip(hs, lps):
                f.write("\t%f\t%s\n" % (lp, h.strip()))
            f.write("------------------------------------------------------\n")
def corpus_bleu(sys_stream, ref_streams):
    """Return the corpus BLEU score of *sys_stream* against *ref_streams*.

    sacrebleu is called with ``tokenize="none"``, so the text is scored as
    given; presumably it is already tokenized/detokenized upstream.
    """
    return _corpus_bleu(sys_stream, ref_streams, tokenize="none").score
def sentence_bleu(hypothesis, reference):
    """Return a smoothed BLEU score for a single hypothesis/reference pair.

    Unlike ``corpus_bleu`` in this file, no ``tokenize`` argument is passed,
    so sacrebleu's default tokenizer applies. One is added to the match
    counts and totals at indices 1-3 (presumably the 2- to 4-gram orders).
    The score is then recomputed with ``smooth_method="exp"``.
    """
    raw = _corpus_bleu(hypothesis, reference)
    for order in range(1, 4):
        raw.counts[order] += 1
        raw.totals[order] += 1
    smoothed = compute_bleu(
        raw.counts,
        raw.totals,
        raw.sys_len,
        raw.ref_len,
        smooth_method="exp",
    )
    return smoothed.score
def pairwise(sents):
    """Return pooled pairwise BLEU within each group of sentences.

    Within every group of *sents*, each ordered pair of distinct positions
    contributes the first sentence as reference and the second as
    hypothesis. All pairs from all groups are scored as a single corpus
    against a single reference stream.
    """
    pairs = [
        (ref, hyp)
        for group in sents
        for a, ref in enumerate(group)
        for b, hyp in enumerate(group)
        if a != b
    ]
    references = [ref for ref, _ in pairs]
    hypotheses = [hyp for _, hyp in pairs]
    return corpus_bleu(hypotheses, [references])
def _avg_refs_covered(refs, hypos):
    """Return the mean number of distinct references matched per sentence.

    Each hypothesis is matched to the reference with the highest
    ``sentence_bleu`` score, and ties are broken uniformly at random.
    """
    covered_total = 0
    for rs, hs in zip(refs, hypos):
        covered = set()
        for h in hs:
            scores = [sentence_bleu(h, r) for r in rs]
            top = max(scores)
            best = [idx for idx, sc in enumerate(scores) if sc == top]
            covered.add(random.choice(best))
        covered_total += len(covered)
    return covered_total / len(refs)


def _loo_multi_ref_bleus(refs, hypos):
    """Return one multi-reference corpus BLEU per held-out reference stream."""
    # Transpose to one stream per reference/hypothesis index. zip(*)
    # truncates to the shortest per-sentence list.
    ref_streams = list(zip(*refs))
    hypo_streams = list(zip(*hypos))
    k = len(hypo_streams)
    m = len(ref_streams)
    # Sentence-major flattening: all k hypotheses of sentence 0, then 1, ...
    flat_hypos = [
        hypo_streams[j][i] for i in range(len(hypo_streams[0])) for j in range(k)
    ]
    # Repeat every reference k times so each stream lines up with flat_hypos.
    duplicated_refs = [
        [ref for ref in stream for _ in range(k)] for stream in ref_streams
    ]
    loo_bleus = []
    for held_out in range(m):
        remaining = duplicated_refs[:held_out] + duplicated_refs[held_out + 1 :]
        assert len(remaining) == m - 1
        loo_bleus.append(corpus_bleu(flat_hypos, remaining))
    return loo_bleus


def multi_ref(refs, hypos):
    """Print reference coverage and leave-one-out multi-reference BLEU.

    Args:
        refs: per-sentence lists of reference strings.
        hypos: per-sentence lists of hypothesis strings, with the same
            number of sentences as *refs*.

    Prints ``#refs covered``: the mean, over sentences, of how many distinct
    references are the best sentence-BLEU match for some hypothesis (ties
    broken with ``random.choice``). Then prints the mean corpus BLEU over
    runs that each hold out one reference stream. The leave-one-out design
    keeps the number comparable to ``intra_ref``.
    """
    assert len(refs) == len(hypos)
    print("#refs covered: %.2f" % _avg_refs_covered(refs, hypos))
    loo_bleus = _loo_multi_ref_bleus(refs, hypos)
    print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))
def intra_ref(refs):
    """Print statistics measuring how similar the references are to each other.

    The first statistic is pairwise BLEU among each sentence's references.
    The second is a leave-one-out multi-reference corpus BLEU. Each
    reference stream in turn plays the hypothesis and is scored against the
    remaining streams, and all of these passes are concatenated into one
    corpus.

    Args:
        refs: per-sentence lists of reference strings (``zip(*refs)``
            truncates to the shortest list).
    """
    print("ref pairwise BLEU: %.2f" % pairwise(refs))
    streams = list(zip(*refs))
    n_streams = len(streams)
    hypothesis_side = [sent for stream in streams for sent in stream]
    reference_side = [[] for _ in range(n_streams - 1)]
    for held_out in range(n_streams):
        others = streams[:held_out] + streams[held_out + 1 :]
        for slot, other in enumerate(others):
            reference_side[slot].extend(other)
    bleu = corpus_bleu(hypothesis_side, reference_side)
    print("multi-reference BLEU (leave-one-out): %.2f" % bleu)
if __name__ == "__main__":
    # Only run the CLI when executed as a script, not on import.
    main()