Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Data pre-processing: build vocabularies and binarize training data. | |
""" | |
import logging | |
import os | |
import shutil | |
import sys | |
import typing as tp | |
from argparse import Namespace | |
from itertools import zip_longest | |
from fairseq import options, tasks, utils | |
from fairseq.binarizer import ( | |
AlignmentDatasetBinarizer, | |
FileBinarizer, | |
VocabularyDatasetBinarizer, | |
) | |
from fairseq.data import Dictionary | |
logging.basicConfig( | |
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", | |
datefmt="%Y-%m-%d %H:%M:%S", | |
level=os.environ.get("LOGLEVEL", "INFO").upper(), | |
stream=sys.stdout, | |
) | |
logger = logging.getLogger("fairseq_cli.preprocess") | |
##################################################################### | |
# file name tools | |
##################################################################### | |
def _train_path(lang, trainpref): | |
return "{}{}".format(trainpref, ("." + lang) if lang else "") | |
def _file_name(prefix, lang): | |
fname = prefix | |
if lang is not None: | |
fname += ".{lang}".format(lang=lang) | |
return fname | |
def _dest_path(prefix, lang, destdir): | |
return os.path.join(destdir, _file_name(prefix, lang)) | |
def _dict_path(lang, destdir): | |
return _dest_path("dict", lang, destdir) + ".txt" | |
def dataset_dest_prefix(args, output_prefix, lang): | |
base = os.path.join(args.destdir, output_prefix) | |
if lang is not None: | |
lang_part = f".{args.source_lang}-{args.target_lang}.{lang}" | |
elif args.only_source: | |
lang_part = "" | |
else: | |
lang_part = f".{args.source_lang}-{args.target_lang}" | |
return "{}{}".format(base, lang_part) | |
def dataset_dest_file(args, output_prefix, lang, extension): | |
return "{}.{}".format(dataset_dest_prefix(args, output_prefix, lang), extension) | |
##################################################################### | |
# dictionary tools | |
##################################################################### | |
def _build_dictionary( | |
filenames, | |
task, | |
args, | |
src=False, | |
tgt=False, | |
): | |
assert src ^ tgt | |
return task.build_dictionary( | |
filenames, | |
workers=args.workers, | |
threshold=args.thresholdsrc if src else args.thresholdtgt, | |
nwords=args.nwordssrc if src else args.nwordstgt, | |
padding_factor=args.padding_factor, | |
) | |
##################################################################### | |
# bin file creation logic | |
##################################################################### | |
def _make_binary_dataset( | |
vocab: Dictionary, | |
input_prefix: str, | |
output_prefix: str, | |
lang: tp.Optional[str], | |
num_workers: int, | |
args: Namespace, | |
): | |
logger.info("[{}] Dictionary: {} types".format(lang, len(vocab))) | |
binarizer = VocabularyDatasetBinarizer( | |
vocab, | |
append_eos=True, | |
) | |
input_file = "{}{}".format(input_prefix, ("." + lang) if lang is not None else "") | |
full_output_prefix = dataset_dest_prefix(args, output_prefix, lang) | |
final_summary = FileBinarizer.multiprocess_dataset( | |
input_file, | |
args.dataset_impl, | |
binarizer, | |
full_output_prefix, | |
vocab_size=len(vocab), | |
num_workers=num_workers, | |
) | |
logger.info(f"[{lang}] {input_file}: {final_summary} (by {vocab.unk_word})") | |
def _make_binary_alignment_dataset( | |
input_prefix: str, output_prefix: str, num_workers: int, args: Namespace | |
): | |
binarizer = AlignmentDatasetBinarizer(utils.parse_alignment) | |
input_file = input_prefix | |
full_output_prefix = dataset_dest_prefix(args, output_prefix, lang=None) | |
final_summary = FileBinarizer.multiprocess_dataset( | |
input_file, | |
args.dataset_impl, | |
binarizer, | |
full_output_prefix, | |
vocab_size=None, | |
num_workers=num_workers, | |
) | |
logger.info( | |
"[alignments] {}: parsed {} alignments".format( | |
input_file, final_summary.num_seq | |
) | |
) | |
##################################################################### | |
# routing logic | |
##################################################################### | |
def _make_dataset( | |
vocab: Dictionary, | |
input_prefix: str, | |
output_prefix: str, | |
lang: tp.Optional[str], | |
args: Namespace, | |
num_workers: int, | |
): | |
if args.dataset_impl == "raw": | |
# Copy original text file to destination folder | |
output_text_file = _dest_path( | |
output_prefix + ".{}-{}".format(args.source_lang, args.target_lang), | |
lang, | |
args.destdir, | |
) | |
shutil.copyfile(_file_name(input_prefix, lang), output_text_file) | |
else: | |
_make_binary_dataset( | |
vocab, input_prefix, output_prefix, lang, num_workers, args | |
) | |
def _make_all(lang, vocab, args): | |
if args.trainpref: | |
_make_dataset( | |
vocab, args.trainpref, "train", lang, args=args, num_workers=args.workers | |
) | |
if args.validpref: | |
for k, validpref in enumerate(args.validpref.split(",")): | |
outprefix = "valid{}".format(k) if k > 0 else "valid" | |
_make_dataset( | |
vocab, validpref, outprefix, lang, args=args, num_workers=args.workers | |
) | |
if args.testpref: | |
for k, testpref in enumerate(args.testpref.split(",")): | |
outprefix = "test{}".format(k) if k > 0 else "test" | |
_make_dataset( | |
vocab, testpref, outprefix, lang, args=args, num_workers=args.workers | |
) | |
def _make_all_alignments(args): | |
if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix): | |
_make_binary_alignment_dataset( | |
args.trainpref + "." + args.align_suffix, | |
"train.align", | |
num_workers=args.workers, | |
args=args, | |
) | |
if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix): | |
_make_binary_alignment_dataset( | |
args.validpref + "." + args.align_suffix, | |
"valid.align", | |
num_workers=args.workers, | |
args=args, | |
) | |
if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix): | |
_make_binary_alignment_dataset( | |
args.testpref + "." + args.align_suffix, | |
"test.align", | |
num_workers=args.workers, | |
args=args, | |
) | |
##################################################################### | |
# align | |
##################################################################### | |
def _align_files(args, src_dict, tgt_dict): | |
assert args.trainpref, "--trainpref must be set if --alignfile is specified" | |
src_file_name = _train_path(args.source_lang, args.trainpref) | |
tgt_file_name = _train_path(args.target_lang, args.trainpref) | |
freq_map = {} | |
with open(args.alignfile, "r", encoding="utf-8") as align_file: | |
with open(src_file_name, "r", encoding="utf-8") as src_file: | |
with open(tgt_file_name, "r", encoding="utf-8") as tgt_file: | |
for a, s, t in zip_longest(align_file, src_file, tgt_file): | |
si = src_dict.encode_line(s, add_if_not_exist=False) | |
ti = tgt_dict.encode_line(t, add_if_not_exist=False) | |
ai = list(map(lambda x: tuple(x.split("-")), a.split())) | |
for sai, tai in ai: | |
srcidx = si[int(sai)] | |
tgtidx = ti[int(tai)] | |
if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk(): | |
assert srcidx != src_dict.pad() | |
assert srcidx != src_dict.eos() | |
assert tgtidx != tgt_dict.pad() | |
assert tgtidx != tgt_dict.eos() | |
if srcidx not in freq_map: | |
freq_map[srcidx] = {} | |
if tgtidx not in freq_map[srcidx]: | |
freq_map[srcidx][tgtidx] = 1 | |
else: | |
freq_map[srcidx][tgtidx] += 1 | |
align_dict = {} | |
for srcidx in freq_map.keys(): | |
align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get) | |
with open( | |
os.path.join( | |
args.destdir, | |
"alignment.{}-{}.txt".format(args.source_lang, args.target_lang), | |
), | |
"w", | |
encoding="utf-8", | |
) as f: | |
for k, v in align_dict.items(): | |
print("{} {}".format(src_dict[k], tgt_dict[v]), file=f) | |
##################################################################### | |
# MAIN | |
##################################################################### | |
def main(args): | |
# setup some basic things | |
utils.import_user_module(args) | |
os.makedirs(args.destdir, exist_ok=True) | |
logger.addHandler( | |
logging.FileHandler( | |
filename=os.path.join(args.destdir, "preprocess.log"), | |
) | |
) | |
logger.info(args) | |
assert ( | |
args.dataset_impl != "huffman" | |
), "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly." | |
# build dictionaries | |
target = not args.only_source | |
if not args.srcdict and os.path.exists(_dict_path(args.source_lang, args.destdir)): | |
raise FileExistsError(_dict_path(args.source_lang, args.destdir)) | |
if ( | |
target | |
and not args.tgtdict | |
and os.path.exists(_dict_path(args.target_lang, args.destdir)) | |
): | |
raise FileExistsError(_dict_path(args.target_lang, args.destdir)) | |
task = tasks.get_task(args.task) | |
if args.joined_dictionary: | |
assert ( | |
not args.srcdict or not args.tgtdict | |
), "cannot use both --srcdict and --tgtdict with --joined-dictionary" | |
if args.srcdict: | |
src_dict = task.load_dictionary(args.srcdict) | |
elif args.tgtdict: | |
src_dict = task.load_dictionary(args.tgtdict) | |
else: | |
assert ( | |
args.trainpref | |
), "--trainpref must be set if --srcdict is not specified" | |
src_dict = _build_dictionary( | |
{ | |
_train_path(lang, args.trainpref) | |
for lang in [args.source_lang, args.target_lang] | |
}, | |
task=task, | |
args=args, | |
src=True, | |
) | |
tgt_dict = src_dict | |
else: | |
if args.srcdict: | |
src_dict = task.load_dictionary(args.srcdict) | |
else: | |
assert ( | |
args.trainpref | |
), "--trainpref must be set if --srcdict is not specified" | |
src_dict = _build_dictionary( | |
[_train_path(args.source_lang, args.trainpref)], | |
task=task, | |
args=args, | |
src=True, | |
) | |
if target: | |
if args.tgtdict: | |
tgt_dict = task.load_dictionary(args.tgtdict) | |
else: | |
assert ( | |
args.trainpref | |
), "--trainpref must be set if --tgtdict is not specified" | |
tgt_dict = _build_dictionary( | |
[_train_path(args.target_lang, args.trainpref)], | |
task=task, | |
args=args, | |
tgt=True, | |
) | |
else: | |
tgt_dict = None | |
# save dictionaries | |
src_dict.save(_dict_path(args.source_lang, args.destdir)) | |
if target and tgt_dict is not None: | |
tgt_dict.save(_dict_path(args.target_lang, args.destdir)) | |
if args.dict_only: | |
return | |
_make_all(args.source_lang, src_dict, args) | |
if target: | |
_make_all(args.target_lang, tgt_dict, args) | |
# align the datasets if needed | |
if args.align_suffix: | |
_make_all_alignments(args) | |
logger.info("Wrote preprocessed data to {}".format(args.destdir)) | |
if args.alignfile: | |
_align_files(args, src_dict=src_dict, tgt_dict=tgt_dict) | |
def cli_main(): | |
parser = options.get_preprocessing_parser() | |
args = parser.parse_args() | |
main(args) | |
if __name__ == "__main__": | |
cli_main() | |