Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
from itertools import zip_longest | |
def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
    """Replaces out-of-vocabulary words in source and target text with <unk-N>,
    where N is the position of the word in the source sequence.

    Each source line is split on whitespace. Every token not in
    ``vocabulary`` is written as ``<unk-N>``, where N is the 0-based position
    of that token's first occurrence in the line. On the matching target
    line, any token that was an out-of-vocabulary word in the source is
    written with the same ``<unk-N>`` tag. All other target tokens, including
    OOV words that appear only in the target, are copied unchanged.

    Args:
        source_in: iterable of source lines (e.g. an open text file).
        target_in: iterable of target lines, or ``None`` when there is no
            target side.
        vocabulary: container of in-vocabulary words supporting ``in``.
            A list or tuple is converted to a frozenset internally so each
            lookup is O(1).
        source_out: writable text stream receiving one rewritten line per
            source line.
        target_out: writable text stream for the rewritten target lines, or
            ``None`` to skip target output. When given, one line is written
            per source line (an empty line once ``target_in`` is exhausted).

    Raises:
        ValueError: if ``target_in`` has more lines than ``source_in``.
    """

    def format_unk(pos):
        return "<unk-{}>".format(pos)

    # Membership is tested once per source token; avoid a linear scan when
    # the caller passes the vocabulary as a plain sequence.
    if isinstance(vocabulary, (list, tuple)):
        vocabulary = frozenset(vocabulary)

    if target_in is None:
        target_in = []

    for line_num, (source_seq, target_seq) in enumerate(
        zip_longest(source_in, target_in), start=1
    ):
        if source_seq is None:
            # zip_longest pads the shorter input with None. A missing source
            # line means the two inputs are misaligned, so fail loudly
            # instead of crashing on None.split().
            raise ValueError(
                "target has more lines than source "
                "(first unmatched target line: {})".format(line_num)
            )

        # Maps each OOV source word to the position of its first occurrence.
        word_to_pos = {}
        source_seq_out = []
        for position, token in enumerate(source_seq.split()):
            if token in vocabulary:
                source_seq_out.append(token)
            else:
                # Repeated OOV words reuse the position of their first
                # occurrence.
                oov_pos = word_to_pos.setdefault(token, position)
                source_seq_out.append(format_unk(oov_pos))
        source_out.write(" ".join(source_seq_out) + "\n")

        if target_out is not None:
            target_seq_out = []
            if target_seq is not None:
                for token in target_seq.split():
                    oov_pos = word_to_pos.get(token)
                    target_seq_out.append(
                        token if oov_pos is None else format_unk(oov_pos)
                    )
            target_out.write(" ".join(target_seq_out) + "\n")
def main():
    """Command-line entry point.

    Parses arguments, loads the vocabulary file (one entry per line), opens
    the input/output files and runs ``replace_oovs`` over them.
    """
    from contextlib import ExitStack

    parser = argparse.ArgumentParser(
        description="Replaces out-of-vocabulary words in both source and target "
        "sequences with tokens that indicate the position of the word "
        "in the source sequence."
    )
    parser.add_argument(
        "--source", type=str, help="text file with source sequences", required=True
    )
    parser.add_argument(
        "--target", type=str, help="text file with target sequences", default=None
    )
    parser.add_argument("--vocab", type=str, help="vocabulary file", required=True)
    parser.add_argument(
        "--source-out",
        type=str,
        help="where to write source sequences with <unk-N> entries",
        required=True,
    )
    parser.add_argument(
        "--target-out",
        type=str,
        help="where to write target sequences with <unk-N> entries",
        default=None,
    )
    args = parser.parse_args()

    with open(args.vocab, encoding="utf-8") as vocab:
        # Each line of the vocab file is one entry. A set gives O(1)
        # membership tests, where a list would be scanned once per token.
        vocabulary = set(vocab.read().splitlines())

    # ExitStack closes every file opened so far, including the optional
    # target files, even if a later open() or replace_oovs() raises.
    with ExitStack() as stack:
        target_in = (
            stack.enter_context(open(args.target, "r", encoding="utf-8"))
            if args.target is not None
            else None
        )
        target_out = (
            stack.enter_context(open(args.target_out, "w", encoding="utf-8"))
            if args.target_out is not None
            else None
        )
        source_in = stack.enter_context(open(args.source, "r", encoding="utf-8"))
        source_out = stack.enter_context(
            open(args.source_out, "w", encoding="utf-8")
        )
        replace_oovs(source_in, target_in, vocabulary, source_out, target_out)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()