File size: 4,872 Bytes
ee21b96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python

import argparse
from multiprocessing import Pool
from pathlib import Path

import sacrebleu
import sentencepiece as spm


def read_text_file(filename):
    """Read a text file and return its lines with surrounding whitespace removed.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding (the inputs may contain non-ASCII text).

    Args:
        filename: path (``str`` or ``pathlib.Path``) of the file to read.

    Returns:
        list[str]: one entry per line of the file, each passed through
        ``str.strip()`` (leading/trailing whitespace and the newline removed).
    """
    with open(filename, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def get_bleu(in_sent, target_sent):
    """Return BLEU statistics for one hypothesis against one reference.

    BLEU is computed with ``sacrebleu.corpus_bleu`` on a one-sentence corpus.
    The result is a single space-separated string holding, in order: the
    score, the system length, the reference length, the n-gram match counts,
    and the n-gram totals.
    """
    result = sacrebleu.corpus_bleu([in_sent], [[target_sent]])
    fields = [result.score, result.sys_len, result.ref_len] + result.counts + result.totals
    return " ".join(str(value) for value in fields)


def get_ter(in_sent, target_sent):
    """Return TER statistics for one hypothesis against one reference.

    TER is computed with ``sacrebleu.corpus_ter`` on a one-sentence corpus.
    The result is a space-separated string: score, number of edits,
    reference length.
    """
    result = sacrebleu.corpus_ter([in_sent], [[target_sent]])
    fields = (result.score, result.num_edits, result.ref_length)
    return " ".join(str(value) for value in fields)


def init(sp_model):
    """Load the SentencePiece model into the module-level ``sp`` processor.

    Used as the ``multiprocessing.Pool`` initializer (and called directly in
    the single-process path) so that ``process`` can use ``sp``.
    """
    global sp
    # Bind the global first, then load, matching the original statement order.
    sp = processor = spm.SentencePieceProcessor()
    processor.Load(sp_model)


def process(source_sent, target_sent, hypo_sent, metric):
    """Encode one source sentence and its hypotheses, and score each hypothesis.

    Args:
        source_sent: raw source sentence.
        target_sent: raw reference sentence used for scoring.
        hypo_sent: list of raw hypothesis strings for this source.
        metric: ``"bleu"`` selects ``get_bleu``; any other value selects
            ``get_ter``.

    Returns:
        tuple: ``(source_bpe, hypo_bpe, score_str)``. Here ``source_bpe`` is the
        space-joined SentencePiece pieces of the source, and ``hypo_bpe`` holds
        the same for each hypothesis. ``score_str`` holds one metric string per
        hypothesis, computed on the raw (un-encoded) hypothesis text.
    """

    def to_pieces(text):
        return " ".join(sp.EncodeAsPieces(text))

    source_bpe = to_pieces(source_sent)
    hypo_bpe = list(map(to_pieces, hypo_sent))

    scorer = get_bleu if metric == "bleu" else get_ter  # non-"bleu" means TER
    score_str = [scorer(h, target_sent) for h in hypo_sent]

    return source_bpe, hypo_bpe, score_str


def _require(condition, message):
    """Raise ``AssertionError(message)`` when *condition* is false.

    Used instead of ``assert`` so input validation still runs under
    ``python -O``. The exception type stays AssertionError, as before.
    """
    if not condition:
        raise AssertionError(message)


def main(args):
    """Build reranker data: encode inputs and score every n-best hypothesis.

    Reads the source/target files, plus a hypothesis file holding ``args.beam``
    consecutive hypotheses per source sentence. For each shard ``k`` in
    ``1..args.num_shards`` it writes::

        <output_dir>/<metric>/split<k>/input_src/<split>.bpe
        <output_dir>/<metric>/split<k>/input_tgt/<split>.bpe
        <output_dir>/<metric>/split<k>/<metric>/<split>.<metric>

    Sentence ``i`` is written to ``split<(i % num_shards) + 1>``. Each
    hypothesis contributes one line to all three files; the encoded source is
    repeated once per hypothesis.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        AssertionError: if the split name, shard count or input sizes are
            inconsistent. Only ``train*`` splits may use more than one shard.
    """
    _require(
        args.split.startswith("train") or args.num_shards == 1,
        "--num-shards should be set to 1 for valid and test sets",
    )
    _require(
        args.split.startswith(("train", "valid", "test")),
        "--split should be set to train[n]/valid[n]/test[n]",
    )

    source_sents = read_text_file(args.input_source)
    target_sents = read_text_file(args.input_target)

    num_sents = len(source_sents)
    _require(
        num_sents == len(target_sents),
        f"{args.input_source} and {args.input_target} should have the same number of sentences.",
    )

    hypo_sents = read_text_file(args.input_hypo)
    _require(
        len(hypo_sents) % args.beam == 0,
        f"Number of hypotheses ({len(hypo_sents)}) cannot be divided by beam size ({args.beam}).",
    )

    # Group the flat hypothesis list into one list of `beam` hypotheses per source.
    hypo_sents = [
        hypo_sents[i : i + args.beam] for i in range(0, len(hypo_sents), args.beam)
    ]
    _require(
        num_sents == len(hypo_sents),
        f"{args.input_hypo} should contain {num_sents * args.beam} hypotheses but has {len(hypo_sents) * args.beam}. (--beam={args.beam})",
    )

    output_dir = args.output_dir / args.metric

    # Start the workers (or load the model in-process) once for all shards.
    # Every shard uses the same SentencePiece model, so re-spawning the pool
    # and reloading the model per shard is wasted work.
    pool = None
    if args.n_proc > 1:
        pool = Pool(
            args.n_proc, initializer=init, initargs=(args.sentencepiece_model,)
        )
    else:
        init(args.sentencepiece_model)

    try:
        for ns in range(args.num_shards):
            print(f"processing shard {ns+1}/{args.num_shards}")
            shard_output_dir = output_dir / f"split{ns+1}"
            source_output_dir = shard_output_dir / "input_src"
            hypo_output_dir = shard_output_dir / "input_tgt"
            metric_output_dir = shard_output_dir / args.metric
            for directory in (source_output_dir, hypo_output_dir, metric_output_dir):
                directory.mkdir(parents=True, exist_ok=True)

            jobs = [
                (source_sents[i], target_sents[i], hypo_sents[i], args.metric)
                for i in range(ns, num_sents, args.num_shards)
            ]
            if pool is not None:
                output = pool.starmap(process, jobs)
            else:
                output = [process(*job) for job in jobs]

            # Write as UTF-8 explicitly: SentencePiece pieces contain "\u2581",
            # which cannot be encoded under non-UTF-8 locale encodings.
            with open(
                source_output_dir / f"{args.split}.bpe", "w", encoding="utf-8"
            ) as s_o, open(
                hypo_output_dir / f"{args.split}.bpe", "w", encoding="utf-8"
            ) as h_o, open(
                metric_output_dir / f"{args.split}.{args.metric}", "w", encoding="utf-8"
            ) as m_o:
                for source_bpe, hypo_bpe, score_str in output:
                    # Internal invariant: process() returns one score per hypothesis.
                    assert len(hypo_bpe) == len(score_str)
                    for h, m in zip(hypo_bpe, score_str):
                        s_o.write(f"{source_bpe}\n")
                        h_o.write(f"{h}\n")
                        m_o.write(f"{m}\n")
    finally:
        if pool is not None:
            # Same teardown as the original `with Pool(...)` (its __exit__ terminates).
            pool.terminate()


if __name__ == "__main__":
    # Command-line entry point. The input/output paths, split name, beam size
    # and SentencePiece model are required; metric, shard count and worker
    # count have defaults.
    cli = argparse.ArgumentParser()
    required_options = (
        ("--input-source", Path),
        ("--input-target", Path),
        ("--input-hypo", Path),
        ("--output-dir", Path),
        ("--split", str),
        ("--beam", int),
        ("--sentencepiece-model", str),
    )
    for flag, value_type in required_options:
        cli.add_argument(flag, type=value_type, required=True)
    cli.add_argument("--metric", type=str, choices=["bleu", "ter"], default="bleu")
    cli.add_argument("--num-shards", type=int, default=1)
    cli.add_argument("--n-proc", type=int, default=8)

    args = cli.parse_args()
    main(args)