JustinLin610's picture
first commit
ee21b96
raw
history blame
2.95 kB
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Extracts random constraints from reference files."""
import argparse
import random
import sys
from sacrebleu import extract_ngrams
def get_phrase(words, index, length):
assert index < len(words) - length + 1
phr = " ".join(words[index : index + length])
for i in range(index, index + length):
words.pop(index)
return phr
def main(args):
if args.seed:
random.seed(args.seed)
for line in sys.stdin:
constraints = []
def add_constraint(constraint):
constraints.append(constraint)
source = line.rstrip()
if "\t" in line:
source, target = line.split("\t")
if args.add_sos:
target = f"<s> {target}"
if args.add_eos:
target = f"{target} </s>"
if len(target.split()) >= args.len:
words = [target]
num = args.number
choices = {}
for i in range(num):
if len(words) == 0:
break
segmentno = random.choice(range(len(words)))
segment = words.pop(segmentno)
tokens = segment.split()
phrase_index = random.choice(range(len(tokens)))
choice = " ".join(
tokens[phrase_index : min(len(tokens), phrase_index + args.len)]
)
for j in range(
phrase_index, min(len(tokens), phrase_index + args.len)
):
tokens.pop(phrase_index)
if phrase_index > 0:
words.append(" ".join(tokens[0:phrase_index]))
if phrase_index + 1 < len(tokens):
words.append(" ".join(tokens[phrase_index:]))
choices[target.find(choice)] = choice
# mask out with spaces
target = target.replace(choice, " " * len(choice), 1)
for key in sorted(choices.keys()):
add_constraint(choices[key])
print(source, *constraints, sep="\t")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--number", "-n", type=int, default=1, help="number of phrases")
parser.add_argument("--len", "-l", type=int, default=1, help="phrase length")
parser.add_argument(
"--add-sos", default=False, action="store_true", help="add <s> token"
)
parser.add_argument(
"--add-eos", default=False, action="store_true", help="add </s> token"
)
parser.add_argument("--seed", "-s", default=0, type=int)
args = parser.parse_args()
main(args)