Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 -u | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Score raw text with a trained model. | |
""" | |
from collections import namedtuple | |
import logging | |
from multiprocessing import Pool | |
import sys | |
import os | |
import random | |
import numpy as np | |
import sacrebleu | |
import torch | |
from fairseq import checkpoint_utils, options, utils | |
# Module-level logger, pinned to INFO.
logger = logging.getLogger("fairseq_cli.drnmt_rerank")
logger.setLevel(logging.INFO)

# Fields copied out of each mini-batch produced by the task's batch iterator.
Batch = namedtuple("Batch", ["ids", "src_tokens", "src_lengths"])

# Filled in by `init_loaded_scores` and read back by `score_target_hypo`.
pool_init_variables = {}
def init_loaded_scores(mt_scores, model_scores, hyp, ref):
    """Store the shared scoring inputs in the module-level ``pool_init_variables``.

    Passed as the ``initializer`` of the worker pool in ``main`` so that
    ``score_target_hypo`` can read these values back instead of receiving
    them as arguments on every call.
    """
    # The dict is mutated in place, so no ``global`` declaration is needed.
    pool_init_variables.update(
        mt_scores=mt_scores,
        model_scores=model_scores,
        hyp=hyp,
        ref=ref,
    )
def parse_fairseq_gen(filename, task):
    """Parse a fairseq generation log into sources, hypotheses and scores.

    Only two kinds of lines are read; everything else is skipped:

    * ``S-<id>\\t<source text>``
    * ``D-<id>\\t<score>\\t<hypothesis text>``

    Args:
        filename: path of the UTF-8 text file to parse.
        task: unused; kept so existing callers keep working.

    Returns:
        A ``(source_out, hypos_out, scores_out)`` tuple. ``source_out`` holds
        one source per sentence id; ``hypos_out`` and ``scores_out`` are flat
        lists with all hypotheses (and their scores) of sentence 0 first,
        then sentence 1, and so on.

    Raises:
        KeyError: if the sentence ids that have hypotheses are not exactly
            ``0 .. N-1`` or one of them has no ``S-`` line.
        ValueError: if an id or score cannot be parsed as a number.
    """
    source = {}
    hypos = {}
    scores = {}

    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line.startswith("S-"):  # source
                # `strip()` above also removes the trailing tab of an empty
                # source, so the text field may be missing altogether.
                uid, _, text = line.partition("\t")
                source[int(uid[2:])] = text
            elif line.startswith("D-"):  # hypo
                # Same for an empty hypothesis: only "<id>\t<score>" remains.
                uid, score, *rest = line.split("\t", 2)
                uid = int(uid[2:])
                hypos.setdefault(uid, []).append(rest[0] if rest else "")
                scores.setdefault(uid, []).append(float(score))

    # Ids double as list positions, hence the lookup by 0..N-1.
    num_sentences = len(hypos)
    source_out = [source[i] for i in range(num_sentences)]
    hypos_out = [h for i in range(num_sentences) for h in hypos[i]]
    scores_out = [s for i in range(num_sentences) for s in scores[i]]

    return source_out, hypos_out, scores_out
def read_target(filename):
    """Return the lines of ``filename`` (UTF-8) with surrounding whitespace stripped."""
    with open(filename, "r", encoding="utf-8") as f:
        return list(map(str.strip, f))
def make_batches(args, src, hyp, task, max_positions, encode_fn):
    """Yield inference batches for the hypotheses (optionally paired with sources).

    Args:
        args: parsed options; ``beam``, ``max_tokens``, ``batch_size`` and
            ``skip_invalid_size_inputs_valid_test`` are read.
        src: source sentences, one per input sentence.
        hyp: flat list of hypotheses, ``args.beam`` consecutive entries per
            source sentence.
        task: task object providing ``source_dictionary``, ``cfg.include_src``,
            ``build_dataset_for_inference`` and ``get_batch_iterator``.
        max_positions: forwarded to the batch iterator.
        encode_fn: callable applied to each raw string before it is
            converted to token ids.

    Yields:
        ``Batch`` tuples built from the ``id`` and ``net_input`` entries of
        each batch returned by the iterator.

    Raises:
        AssertionError: if ``len(hyp) != len(src) * args.beam``. Because this
            is a generator, the check runs when the first batch is requested.
    """
    expected = len(src) * args.beam
    if expected != len(hyp):
        # Raised explicitly rather than via `assert` so the check is kept
        # under `python -O`; the exception type is unchanged.
        raise AssertionError(
            f"Expect {expected} hypotheses for {len(src)} source sentences "
            f"with beam size {args.beam}. Got {len(hyp)} hypotheses instead."
        )

    def to_ids(text):
        return task.source_dictionary.encode_line(
            encode_fn(text), add_if_not_exist=False
        ).long()

    hyp_encode = [to_ids(h) for h in hyp]
    if task.cfg.include_src:
        src_encode = [to_ids(s) for s in src]
        # Hypothesis i belongs to source i // beam.
        tokens = [(src_encode[i // args.beam], h) for i, h in enumerate(hyp_encode)]
        lengths = [(t1.numel(), t2.numel()) for t1, t2 in tokens]
    else:
        tokens = [(h,) for h in hyp_encode]
        lengths = [(h.numel(),) for h in hyp_encode]

    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(tokens, lengths),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        yield Batch(
            ids=batch["id"],
            src_tokens=batch["net_input"]["src_tokens"],
            src_lengths=batch["net_input"]["src_lengths"],
        )
def decode_rerank_scores(args):
    """Load the model from ``args.path`` and score every hypothesis in ``args.in_text``.

    Returns:
        ``(src, hyp, mt_scores, model_scores)`` where the first three come
        from ``parse_fairseq_gen`` and ``model_scores`` is a list ordered by
        the ids reported in each batch (0, 1, 2, ...).
    """
    # Default to one sentence per batch when no batching limit was given.
    if args.max_tokens is None and args.batch_size is None:
        args.batch_size = 1

    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    logger.info("loading model(s) from {}".format(args.path))
    # NOTE(review): `eval` executes arbitrary Python taken from the command
    # line; `ast.literal_eval` would be the safer choice if the overrides are
    # always a plain dict literal -- confirm before changing.
    overrides = eval(args.model_overrides)
    models, _model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path], arg_overrides=overrides,
    )

    for model in models:
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(args)
    bpe = task.build_bpe(args)

    def encode_fn(x):
        # Apply the tokenizer first, then BPE; either may be absent.
        for codec in (tokenizer, bpe):
            if codec is not None:
                x = codec.encode(x)
        return x

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *(m.max_positions() for m in models)
    )

    src, hyp, mt_scores = parse_fairseq_gen(args.in_text, task)
    scores_by_id = {}
    logger.info("decode reranker score")
    for batch in make_batches(args, src, hyp, task, max_positions, encode_fn):
        src_tokens, src_lengths = batch.src_tokens, batch.src_lengths
        if use_cuda:
            src_tokens = src_tokens.cuda()
            src_lengths = src_lengths.cuda()

        sample = {
            "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths},
        }
        scores = task.inference_step(generator, models, sample)

        # Keep the first value of each score row, keyed by the batch's ids.
        for sent_id, row in zip(batch.ids.tolist(), scores.tolist()):
            scores_by_id[sent_id] = row[0]

    # Re-order by id; requires the ids to be exactly 0..N-1.
    ordered_scores = [scores_by_id[i] for i in range(len(scores_by_id))]

    return src, hyp, mt_scores, ordered_scores
def get_score(mt_s, md_s, w1, lp, tgt_len):
    """Combine a length-normalised MT score with a reranker score.

    Computes ``mt_s / tgt_len ** lp * w1 + md_s``.

    Args:
        mt_s: score from the translation model.
        md_s: score from the reranking model.
        w1: weight applied to the normalised translation-model score.
        lp: length-penalty exponent.
        tgt_len: hypothesis length in tokens.

    Returns:
        The combined score.
    """
    # An empty hypothesis has tgt_len == 0, and 0 ** lp is 0 for lp > 0
    # (undefined for lp < 0), which made the division fail. Treat it as
    # length 1 so it is scored without normalisation instead of crashing.
    length_norm = max(tgt_len, 1) ** lp
    return mt_s / length_norm * w1 + md_s
def get_best_hyps(mt_scores, md_scores, hypos, fw_weight, lenpen, beam):
    """Pick the highest-scoring hypothesis out of every ``beam`` consecutive ones.

    Each hypothesis is scored with ``get_score`` using its whitespace token
    count as the target length. Hypotheses left over after the last complete
    group of ``beam`` are scored but never selected.

    Returns:
        ``(best_hypos, best_scores)``, one entry per complete group.
    """
    assert len(mt_scores) == len(md_scores) == len(hypos)

    combined = [
        get_score(mt, md, fw_weight, lenpen, len(text.split()))
        for mt, md, text in zip(mt_scores, md_scores, hypos)
    ]

    best_hypos, best_scores = [], []
    group_start = 0
    for group_end in range(1, len(combined) + 1):
        if group_end % beam != 0:
            continue
        # A full group [group_start, group_end) has just been completed.
        window = combined[group_start:group_end]
        winner = np.argmax(window)
        best_hypos.append(hypos[group_start + winner])
        best_scores.append(window[winner])
        group_start += beam

    return best_hypos, best_scores
def eval_metric(args, hypos, ref):
    """Return the corpus-level score of ``hypos`` against the single reference ``ref``.

    Uses ``sacrebleu.corpus_bleu`` when ``args.metric`` is ``"bleu"`` and
    ``sacrebleu.corpus_ter`` for any other value.
    """
    corpus_metric = (
        sacrebleu.corpus_bleu if args.metric == "bleu" else sacrebleu.corpus_ter
    )
    return corpus_metric(hypos, [ref]).score
def score_target_hypo(args, fw_weight, lp):
    """Evaluate one ``(fw_weight, lp)`` setting against the cached scores.

    Reads the inputs stored by ``init_loaded_scores``, selects the best
    hypothesis per sentence and prints and returns the metric value, or
    ``None`` when no reference is available.
    """
    mt_scores, model_scores, hyp, ref = (
        pool_init_variables[key]
        for key in ("mt_scores", "model_scores", "hyp", "ref")
    )
    best_hypos, _ = get_best_hyps(
        mt_scores, model_scores, hyp, fw_weight, lp, args.beam
    )
    rerank_eval = eval_metric(args, best_hypos, ref) if ref else None
    print(f"fw_weight {fw_weight}, lenpen {lp}, eval {rerank_eval}")

    return rerank_eval
def print_result(best_scores, best_hypos, output_file):
    """Write one ``<index>\\t<score>\\t<hypothesis>`` line per sentence to ``output_file``."""
    for index, (score, hypo) in enumerate(zip(best_scores, best_hypos)):
        print(index, score, hypo, sep="\t", file=output_file)
def _tune_rerank_weights(args, mt_scores, model_scores, hyp, ref):
    """Random-search the forward-score weight and length penalty; return the best pair."""
    logger.info("tune weights for reranking")
    random_params = np.array(
        [
            [
                random.uniform(
                    args.lower_bound_fw_weight, args.upper_bound_fw_weight
                ),
                random.uniform(args.lower_bound_lenpen, args.upper_bound_lenpen),
            ]
            for _ in range(args.num_trials)
        ]
    )

    logger.info("launching pool")
    # The shared inputs go to each worker once through the initializer
    # instead of being sent again with every trial.
    with Pool(
        32,
        initializer=init_loaded_scores,
        initargs=(mt_scores, model_scores, hyp, ref),
    ) as p:
        rerank_scores = p.starmap(
            score_target_hypo,
            [(args, fw_weight, lenpen) for fw_weight, lenpen in random_params],
        )

    # BLEU is maximised; the other metric (TER) is minimised.
    pick_best = np.argmax if args.metric == "bleu" else np.argmin
    best_fw_weight, best_lenpen = random_params[pick_best(rerank_scores)]
    return best_fw_weight, best_lenpen


def main(args):
    """Rerank the hypotheses in ``args.in_text`` and print or write the result.

    With ``args.tune`` the forward-score weight and length penalty are chosen
    by random search against ``args.target_text``; otherwise ``args.fw_weight``
    and ``args.lenpen`` are used as given. Output goes to
    ``<results_path>/generate-<gen_subset>.txt`` when ``args.results_path`` is
    set, else to stdout. When a reference is available, the metric before and
    after reranking is printed.

    Raises:
        AssertionError: on an invalid argument combination, or when the
            numbers of source and reference sentences differ.
    """
    utils.import_user_module(args)

    # Check the argument combination *before* loading the model and scoring
    # every hypothesis, so that a bad command line fails immediately.
    assert (
        not args.tune or args.target_text is not None
    ), "--target-text has to be set when tuning weights"
    if not args.tune:
        assert (
            args.lenpen is not None and args.fw_weight is not None
        ), "--lenpen and --fw-weight should be set"

    src, hyp, mt_scores, model_scores = decode_rerank_scores(args)

    ref = None
    orig_eval = None
    if args.target_text:
        ref = read_target(args.target_text)
        assert len(src) == len(
            ref
        ), f"different numbers of source and target sentences ({len(src)} vs. {len(ref)})"

        # The first hypothesis of every beam is the system's original output.
        orig_best_hypos = hyp[:: args.beam]
        orig_eval = eval_metric(args, orig_best_hypos, ref)

    if args.tune:
        best_fw_weight, best_lenpen = _tune_rerank_weights(
            args, mt_scores, model_scores, hyp, ref
        )
    else:
        best_fw_weight, best_lenpen = args.fw_weight, args.lenpen

    best_hypos, best_scores = get_best_hyps(
        mt_scores, model_scores, hyp, best_fw_weight, best_lenpen, args.beam
    )

    if args.results_path is not None:
        os.makedirs(args.results_path, exist_ok=True)
        output_path = os.path.join(
            args.results_path, "generate-{}.txt".format(args.gen_subset),
        )
        with open(output_path, "w", buffering=1, encoding="utf-8") as o:
            print_result(best_scores, best_hypos, o)
    else:
        print_result(best_scores, best_hypos, sys.stdout)

    if args.target_text:
        rerank_eval = eval_metric(args, best_hypos, ref)
        print(f"before reranking, {args.metric.upper()}:", orig_eval)
        print(
            f"after reranking with fw_weight={best_fw_weight}, lenpen={best_lenpen}, {args.metric.upper()}:",
            rerank_eval,
        )
def cli_main():
    """Build the command-line parser, parse the arguments and run ``main``."""
    parser = options.get_generation_parser(interactive=True)

    parser.add_argument(
        "--in-text",
        default=None,
        required=True,
        help="text from fairseq-interactive output, containing source sentences and hypotheses",
    )
    parser.add_argument("--target-text", default=None, help="reference text")
    parser.add_argument("--metric", type=str, choices=["bleu", "ter"], default="bleu")
    parser.add_argument(
        "--tune",
        action="store_true",
        help="if set, tune weights on fw scores and lenpen instead of applying fixed weights for reranking",
    )

    # Bounds of the random-search space; registered in this order.
    search_bounds = (
        ("--lower-bound-fw-weight", 0.0, "lower bound of search space"),
        ("--upper-bound-fw-weight", 3, "upper bound of search space"),
        ("--lower-bound-lenpen", 0.0, "lower bound of search space"),
        ("--upper-bound-lenpen", 3, "upper bound of search space"),
    )
    for flag, default_value, help_text in search_bounds:
        parser.add_argument(flag, default=default_value, type=float, help=help_text)

    parser.add_argument(
        "--fw-weight", type=float, default=None, help="weight on the fw model score"
    )
    parser.add_argument(
        "--num-trials",
        default=1000,
        type=int,
        help="number of trials to do for random search",
    )

    main(options.parse_args_and_arch(parser))


if __name__ == "__main__":
    cli_main()