Tzktz's picture
Upload 7664 files
6fc683c verified
raw
history blame
12.2 kB
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Data pre-processing: build vocabularies and binarize training data.
"""
import logging
import os
import shutil
import sys
import typing as tp
from argparse import Namespace
from itertools import zip_longest
from fairseq import options, tasks, utils
from fairseq.binarizer import (
AlignmentDatasetBinarizer,
FileBinarizer,
VocabularyDatasetBinarizer,
)
from fairseq.data import Dictionary
# Send log records to stdout in a "time | level | logger | message" layout.
# The threshold is taken from the LOGLEVEL environment variable (upper-cased)
# and falls back to INFO when the variable is not set.
logging.basicConfig(
    stream=sys.stdout,
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger("fairseq_cli.preprocess")
#####################################################################
# file name tools
#####################################################################
def _train_path(lang, trainpref):
    """Return the training file path for ``lang``.

    The path is ``trainpref`` followed by ``.<lang>``. A falsy ``lang``
    (``None`` or the empty string) yields ``trainpref`` with no suffix.
    """
    suffix = "." + lang if lang else ""
    return f"{trainpref}{suffix}"
def _file_name(prefix, lang):
    """Return ``prefix`` with a ``.<lang>`` suffix appended.

    Only ``lang is None`` suppresses the suffix; an empty-string ``lang``
    still produces a trailing dot.
    """
    if lang is None:
        return prefix
    return prefix + ".{lang}".format(lang=lang)
def _dest_path(prefix, lang, destdir):
    """Return the path of ``prefix[.lang]`` inside the directory ``destdir``."""
    leaf = _file_name(prefix, lang)
    return os.path.join(destdir, leaf)
def _dict_path(lang, destdir):
    """Return the dictionary file path for ``lang``: ``destdir/dict[.lang].txt``."""
    stem = _dest_path("dict", lang, destdir)
    return stem + ".txt"
def dataset_dest_prefix(args, output_prefix, lang):
    """Return the extension-less output path for one dataset split.

    The result is ``args.destdir/output_prefix`` followed by:

    * ``.<source_lang>-<target_lang>.<lang>`` when ``lang`` is not ``None``;
    * nothing when ``lang`` is ``None`` and ``args.only_source`` is truthy;
    * ``.<source_lang>-<target_lang>`` otherwise.
    """
    stem = os.path.join(args.destdir, output_prefix)
    # ``only_source`` is only consulted when no language is given.
    if lang is None and args.only_source:
        return f"{stem}"
    pair = f"{args.source_lang}-{args.target_lang}"
    if lang is None:
        return f"{stem}.{pair}"
    return f"{stem}.{pair}.{lang}"
def dataset_dest_file(args, output_prefix, lang, extension):
    """Return the dataset output prefix with ``.<extension>`` appended."""
    stem = dataset_dest_prefix(args, output_prefix, lang)
    return f"{stem}.{extension}"
#####################################################################
# dictionary tools
#####################################################################
def _build_dictionary(
    filenames,
    task,
    args,
    src=False,
    tgt=False,
):
    """Build a dictionary from ``filenames`` via ``task.build_dictionary``.

    Exactly one of ``src`` / ``tgt`` must be true; it selects whether the
    source-side (``thresholdsrc`` / ``nwordssrc``) or target-side
    (``thresholdtgt`` / ``nwordstgt``) limits from ``args`` are passed on.
    ``args.workers`` and ``args.padding_factor`` are forwarded unchanged.
    """
    assert src ^ tgt
    if src:
        threshold, nwords = args.thresholdsrc, args.nwordssrc
    else:
        threshold, nwords = args.thresholdtgt, args.nwordstgt
    return task.build_dictionary(
        filenames,
        workers=args.workers,
        threshold=threshold,
        nwords=nwords,
        padding_factor=args.padding_factor,
    )
#####################################################################
# bin file creation logic
#####################################################################
def _make_binary_dataset(
    vocab: Dictionary,
    input_prefix: str,
    output_prefix: str,
    lang: tp.Optional[str],
    num_workers: int,
    args: Namespace,
):
    """Binarize one text file with ``vocab`` and log the resulting summary.

    The input is ``input_prefix`` plus ``.<lang>`` (no suffix when ``lang`` is
    ``None``); the output location comes from ``dataset_dest_prefix``. The
    binarizer is created with ``append_eos=True`` and the work is delegated to
    ``FileBinarizer.multiprocess_dataset`` using ``num_workers`` workers and
    the ``args.dataset_impl`` format.
    """
    vocab_size = len(vocab)
    logger.info("[{}] Dictionary: {} types".format(lang, vocab_size))

    lang_suffix = "" if lang is None else "." + lang
    input_file = "{}{}".format(input_prefix, lang_suffix)
    destination = dataset_dest_prefix(args, output_prefix, lang)

    final_summary = FileBinarizer.multiprocess_dataset(
        input_file,
        args.dataset_impl,
        VocabularyDatasetBinarizer(vocab, append_eos=True),
        destination,
        vocab_size=vocab_size,
        num_workers=num_workers,
    )

    logger.info(f"[{lang}] {input_file}: {final_summary} (by {vocab.unk_word})")
def _make_binary_alignment_dataset(
    input_prefix: str, output_prefix: str, num_workers: int, args: Namespace
):
    """Binarize an alignment file and log how many alignments were parsed.

    ``input_prefix`` is used as the input file path as-is. Lines are parsed
    with ``utils.parse_alignment`` through an ``AlignmentDatasetBinarizer``,
    and the output location is ``dataset_dest_prefix`` with ``lang=None``.
    """
    destination = dataset_dest_prefix(args, output_prefix, lang=None)

    final_summary = FileBinarizer.multiprocess_dataset(
        input_prefix,
        args.dataset_impl,
        AlignmentDatasetBinarizer(utils.parse_alignment),
        destination,
        vocab_size=None,
        num_workers=num_workers,
    )

    parsed = final_summary.num_seq
    logger.info("[alignments] {}: parsed {} alignments".format(input_prefix, parsed))
#####################################################################
# routing logic
#####################################################################
def _make_dataset(
    vocab: Dictionary,
    input_prefix: str,
    output_prefix: str,
    lang: tp.Optional[str],
    args: Namespace,
    num_workers: int,
):
    """Write one split in the format selected by ``args.dataset_impl``.

    For ``"raw"`` the input text file is copied verbatim to
    ``destdir/<output_prefix>.<source_lang>-<target_lang>[.<lang>]``; every
    other value is handed to ``_make_binary_dataset``.
    """
    if args.dataset_impl != "raw":
        _make_binary_dataset(
            vocab, input_prefix, output_prefix, lang, num_workers, args
        )
        return

    # Copy original text file to destination folder
    pair_prefix = output_prefix + ".{}-{}".format(args.source_lang, args.target_lang)
    output_text_file = _dest_path(pair_prefix, lang, args.destdir)
    shutil.copyfile(_file_name(input_prefix, lang), output_text_file)
def _make_all(lang, vocab, args):
    """Convert the train, valid and test splits of ``lang``, in that order.

    ``args.validpref`` and ``args.testpref`` may hold several comma-separated
    prefixes. The first one is written as ``valid`` / ``test`` and the k-th
    (k >= 1) as ``valid<k>`` / ``test<k>``. ``args.trainpref`` is used as a
    single prefix. Splits whose prefix is falsy are skipped.
    """

    def _numbered(split, prefixes):
        # Pair each comma-separated prefix with its numbered output name.
        for idx, prefix in enumerate(prefixes.split(",")):
            yield prefix, ("{}{}".format(split, idx) if idx > 0 else split)

    jobs = []
    if args.trainpref:
        jobs.append((args.trainpref, "train"))
    if args.validpref:
        jobs.extend(_numbered("valid", args.validpref))
    if args.testpref:
        jobs.extend(_numbered("test", args.testpref))

    for input_prefix, output_prefix in jobs:
        _make_dataset(
            vocab,
            input_prefix,
            output_prefix,
            lang,
            args=args,
            num_workers=args.workers,
        )
def _make_all_alignments(args):
    """Binarize the alignment file that accompanies each configured split.

    For every split prefix the file ``<prefix>.<args.align_suffix>`` is
    binarized if it exists on disk; missing files are skipped silently.

    ``args.validpref`` and ``args.testpref`` are split on commas and numbered
    exactly as ``_make_all`` numbers the datasets (``valid.align``,
    ``valid1.align``, ...). Previously the whole comma-joined string was
    treated as one path, so alignment files were silently ignored whenever
    more than one valid/test prefix was given. With a single prefix the
    behaviour is unchanged. ``args.trainpref`` is a single prefix and always
    maps to ``train.align``.
    """
    suffix = "." + args.align_suffix

    def _binarize_if_present(prefix, output_prefix):
        # Alignment files are optional per split, hence the existence check.
        align_path = prefix + suffix
        if os.path.exists(align_path):
            _make_binary_alignment_dataset(
                align_path,
                output_prefix,
                num_workers=args.workers,
                args=args,
            )

    if args.trainpref:
        _binarize_if_present(args.trainpref, "train.align")

    for split, prefixes in (("valid", args.validpref), ("test", args.testpref)):
        if not prefixes:
            continue
        for k, prefix in enumerate(prefixes.split(",")):
            split_name = "{}{}".format(split, k) if k > 0 else split
            _binarize_if_present(prefix, split_name + ".align")
#####################################################################
# align
#####################################################################
def _align_files(args, src_dict, tgt_dict):
    """Write a source-token -> most-frequently-aligned-target-token table.

    Reads ``args.alignfile`` in lockstep with the source and target training
    files. Each alignment line holds whitespace-separated ``i-j`` pairs
    (source position ``i``, target position ``j``). Co-occurrences are counted
    for pairs where neither token is the dictionary's unknown symbol. For
    every source token the target token with the highest count is written as
    ``"<src> <tgt>"`` to ``destdir/alignment.<source_lang>-<target_lang>.txt``.

    Raises:
        AssertionError: if ``args.trainpref`` is not set, or if an aligned
            position resolves to a pad/eos index.
        ValueError: if the three input files do not have the same number of
            lines, or an alignment entry is not of the form ``i-j``.
    """
    assert args.trainpref, "--trainpref must be set if --alignfile is specified"
    src_file_name = _train_path(args.source_lang, args.trainpref)
    tgt_file_name = _train_path(args.target_lang, args.trainpref)
    src_unk, tgt_unk = src_dict.unk(), tgt_dict.unk()

    # freq_map[src_index][tgt_index] -> number of times the pair was aligned
    freq_map = {}
    with open(args.alignfile, "r", encoding="utf-8") as align_file, open(
        src_file_name, "r", encoding="utf-8"
    ) as src_file, open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
        lockstep = zip_longest(align_file, src_file, tgt_file)
        for line_no, (a, s, t) in enumerate(lockstep, start=1):
            # zip_longest pads exhausted files with None; report the length
            # mismatch explicitly instead of crashing on the None later.
            if a is None or s is None or t is None:
                raise ValueError(
                    "{}, {} and {} must have the same number of lines "
                    "(mismatch at line {})".format(
                        args.alignfile, src_file_name, tgt_file_name, line_no
                    )
                )
            si = src_dict.encode_line(s, add_if_not_exist=False)
            ti = tgt_dict.encode_line(t, add_if_not_exist=False)
            for pair in a.split():
                sai, tai = pair.split("-")
                # NOTE(review): encode_line presumably returns a tensor, so
                # indexing it would give a 0-dim tensor, which hashes by
                # identity and would never aggregate as a dict key. int() is
                # a no-op for plain ints -- confirm against Dictionary.
                srcidx = int(si[int(sai)])
                tgtidx = int(ti[int(tai)])
                if srcidx == src_unk or tgtidx == tgt_unk:
                    continue
                assert srcidx != src_dict.pad()
                assert srcidx != src_dict.eos()
                assert tgtidx != tgt_dict.pad()
                assert tgtidx != tgt_dict.eos()
                counts = freq_map.setdefault(srcidx, {})
                counts[tgtidx] = counts.get(tgtidx, 0) + 1

    out_path = os.path.join(
        args.destdir,
        "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
    )
    with open(out_path, "w", encoding="utf-8") as f:
        for srcidx, counts in freq_map.items():
            # Ties resolve to the target index that was seen first.
            best_tgt = max(counts, key=counts.get)
            print("{} {}".format(src_dict[srcidx], tgt_dict[best_tgt]), file=f)
#####################################################################
# MAIN
#####################################################################
def main(args):
    """Run the preprocessing pipeline described by ``args``.

    Steps, in order: import user modules; create ``args.destdir`` and attach
    a ``preprocess.log`` file handler to the module logger; resolve the
    source/target dictionaries (loaded from ``srcdict`` / ``tgtdict`` or
    built from the training files); save them; then, unless
    ``args.dict_only`` is set, convert the train/valid/test splits, the
    per-split alignment files (when ``args.align_suffix`` is set) and the
    alignment table (when ``args.alignfile`` is set).

    Raises:
        AssertionError: if ``dataset_impl`` is ``"huffman"``, if both
            ``srcdict`` and ``tgtdict`` are given with ``joined_dictionary``,
            or if a dictionary has to be built while ``trainpref`` is unset.
        FileExistsError: if a dictionary is not supplied and its output file
            already exists in ``args.destdir``.
    """
    # setup some basic things
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)
    log_path = os.path.join(args.destdir, "preprocess.log")
    logger.addHandler(logging.FileHandler(filename=log_path))
    logger.info(args)

    assert (
        args.dataset_impl != "huffman"
    ), "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly."

    # build dictionaries
    target = not args.only_source

    # Never overwrite a dictionary file that is already in the output folder.
    if not args.srcdict:
        existing = _dict_path(args.source_lang, args.destdir)
        if os.path.exists(existing):
            raise FileExistsError(existing)
    if target and not args.tgtdict:
        existing = _dict_path(args.target_lang, args.destdir)
        if os.path.exists(existing):
            raise FileExistsError(existing)

    task = tasks.get_task(args.task)

    if not args.joined_dictionary:
        if not args.srcdict:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = _build_dictionary(
                [_train_path(args.source_lang, args.trainpref)],
                task=task,
                args=args,
                src=True,
            )
        else:
            src_dict = task.load_dictionary(args.srcdict)

        tgt_dict = None
        if target:
            if not args.tgtdict:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = _build_dictionary(
                    [_train_path(args.target_lang, args.trainpref)],
                    task=task,
                    args=args,
                    tgt=True,
                )
            else:
                tgt_dict = task.load_dictionary(args.tgtdict)
    else:
        assert not (
            args.srcdict and args.tgtdict
        ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        # A joined dictionary may be seeded from either side's file.
        preset = args.srcdict or args.tgtdict
        if preset:
            src_dict = task.load_dictionary(preset)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = _build_dictionary(
                {
                    _train_path(lang, args.trainpref)
                    for lang in [args.source_lang, args.target_lang]
                },
                task=task,
                args=args,
                src=True,
            )
        tgt_dict = src_dict

    # save dictionaries
    src_dict.save(_dict_path(args.source_lang, args.destdir))
    if target and tgt_dict is not None:
        tgt_dict.save(_dict_path(args.target_lang, args.destdir))

    if args.dict_only:
        return

    per_language = [(args.source_lang, src_dict)]
    if target:
        per_language.append((args.target_lang, tgt_dict))
    for lang, vocab in per_language:
        _make_all(lang, vocab, args)

    # align the datasets if needed
    if args.align_suffix:
        _make_all_alignments(args)

    logger.info("Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        _align_files(args, src_dict=src_dict, tgt_dict=tgt_dict)
def cli_main():
    """Command-line entry point: parse the preprocessing options and run ``main``."""
    main(options.get_preprocessing_parser().parse_args())


if __name__ == "__main__":
    cli_main()