# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict, defaultdict
import json
import os
import logging
from argparse import ArgumentError

from fairseq import options, models
from fairseq.data import (
    data_utils,
    Dictionary,
    LanguagePairDataset,
    IndexedDataset,
    FairseqDataset,
)
from .multitask_data_utils import (
    MultitaskDatasetWrapper,
    MultidatasetEpochBatchIterator,
)

from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)

@register_task("laser")
class LaserTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "configfile", metavar="PATH", help="dataset configuration file in json"
        )
        parser.add_argument(
            "--weighting-alpha",
            type=float,
            default=None,
            help="alpha for automatic weighting",
        )
        parser.add_argument(
            "--raw-text", action="store_true", help="load raw text dataset"
        )
        parser.add_argument(
            "--left-pad-source",
            default="True",
            type=str,
            metavar="BOOL",
            help="pad the source on the left (default: True)",
        )
        parser.add_argument(
            "--left-pad-target",
            default="False",
            type=str,
            metavar="BOOL",
            help="pad the target on the left (default: False)",
        )
        try:
            parser.add_argument(
                "--max-source-positions",
                default=1024,
                type=int,
                metavar="N",
                help="max number of tokens in the source sequence",
            )
            parser.add_argument(
                "--max-target-positions",
                default=1024,
                type=int,
                metavar="N",
                help="max number of tokens in the target sequence",
            )
        except ArgumentError:
            # this might have already been defined. Once we transition this to hydra it should be fine to add it here.
            pass

    def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
        super().__init__(args)
        self.config = config
        self.src_dictionary = src_dictionary
        self.tgt_dictionary = tgt_dictionary
        self.num_tasks = num_tasks

    @classmethod
    def setup_task(cls, args, **kwargs):
        with open(args.configfile, "r") as f:
            config = json.load(f)

        num_tasks = max(dataset["id"] for dataset in config["train"]) + 1

        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        src_dictionary = Dictionary.load(config["src_vocab"])
        tgt_dictionary = Dictionary.load(config["tgt_vocab"])

        logger.info(
            "| src Dictionary {} : {} types".format(
                config["src_vocab"], len(src_dictionary)
            )
        )
        logger.info(
            "| tgt Dictionary {} : {} types".format(
                config["tgt_vocab"], len(tgt_dictionary)
            )
        )

        return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)
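
    # A minimal sketch of the JSON layout that setup_task and load_dataset read,
    # inferred from the key lookups in this file; the file names and paths below
    # are illustrative placeholders only:
    #
    #   {
    #     "src_vocab": "dict.src.txt",
    #     "tgt_vocab": "dict.tgt.txt",
    #     "train": [
    #       {"id": 0, "src": "corpusA/en-fr/train.en", "tgt": "corpusA/en-fr/train.fr"},
    #       {"id": 1, "src": "corpusB/en-de/train.en", "tgt": "corpusB/en-de/train.de", "sample": 2}
    #     ]
    #   }
    #
    # "sample" is optional and up-samples a pair; the corpus name and language-pair
    # name are recovered from the directory structure of "src" in load_dataset below.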

    # Experimental overriding for backtranslation
    def build_model(self, args):
        model = models.build_model(args, self)
        return model

    def dataset(self, split):
        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)

        return self.datasets[split]

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                raise Exception("Unable to handle raw text.")
            dataset = IndexedDataset(path, fix_lua_indexing=True)

            return dataset

        pair_datasets = OrderedDict()

        if split == "valid":
            self.datasets[split] = pair_datasets
            return

        if split not in self.config:
            raise FileNotFoundError(
                "Dataset not found in config file: {}".format(split)
            )

        size_by_corpus = defaultdict(int)
        size_sum = 0
        size_sum_with_subsampling = 0
        init_pair_datasets = {}

        for dataset_config in self.config[split]:
            src_path = os.path.dirname(dataset_config["src"])
            corpus_name = src_path.split("/")[-2]
            language_pair_name = src_path.split("/")[-1]
            pair_datasets_key = corpus_name + "-" + language_pair_name

            logger.info(f"loading... {pair_datasets_key}")
            if "src" in dataset_config:
                src_dataset = indexed_dataset(
                    dataset_config["src"], self.src_dictionary
                )
            else:
                src_dataset = None

            if "tgt" in dataset_config:
                tgt_dataset = indexed_dataset(
                    dataset_config["tgt"], self.tgt_dictionary
                )
            else:
                tgt_dataset = None

            dataset = LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.src_dictionary,
                tgt_dataset,
                tgt_dataset.sizes,
                self.tgt_dictionary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )

            if pair_datasets_key in init_pair_datasets:
                logger.warning(
                    f"Ignoring already added {pair_datasets_key}. "
                    f"Consider using `sample` key in order to upsample."
                )
            else:
                init_pair_datasets[pair_datasets_key] = {
                    "dataset": dataset,
                    "sample": dataset_config.get("sample", None),
                    "id": dataset_config.get("id", None),
                    "len": len(dataset),
                }

        length_sum = 0
        weighted_freqs_sum = 0
        freq_per_dataset = {}
        vmax = 0
        vmin = 1
        weighted_freq_per_dataset = {}
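
        # When --weighting-alpha is set, sampling weights are computed only over
        # pairs without an explicit "sample" value: each pair's relative frequency
        # is raised to the power alpha and renormalized (temperature-style
        # re-weighting), and vmax / vmin track the largest and smallest resulting
        # weights.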
        if self.args.weighting_alpha:
            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    length_sum += len(init_pair_datasets[key]["dataset"])

            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    val = float(init_pair_datasets[key]["len"]) / length_sum
                    freq_per_dataset[key] = val
                    weighted_freqs_sum += val ** self.args.weighting_alpha

            for key in freq_per_dataset:
                val = (
                    freq_per_dataset[key] ** self.args.weighting_alpha
                    / weighted_freqs_sum
                )
                vmin = min(vmin, val)
                vmax = max(vmax, val)
                weighted_freq_per_dataset[key] = val
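
        # Each pair is then added round(sample) times (sample defaults to 1, or to
        # vmax / weight when automatic weighting is on); replicas beyond the first
        # are stored under keys with "-up" appended, and every copy is wrapped in
        # MultitaskDatasetWrapper, so low-resource pairs are up-sampled relative
        # to the largest pair.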

        for pair_datasets_key in init_pair_datasets:
            dataset_config = init_pair_datasets[pair_datasets_key]
            dataset = dataset_config["dataset"]
            sample = dataset_config["sample"]
            if sample is None:
                sample = 1.0

            if pair_datasets_key in weighted_freq_per_dataset:
                w = vmax / weighted_freq_per_dataset[pair_datasets_key]
                sample = w

            sample = round(sample)

            initial_sample = sample
            initial_pair_datasets_key = pair_datasets_key

            while sample >= 1.0:
                assert (
                    pair_datasets_key not in pair_datasets
                ), f"{pair_datasets_key} already in"

                size_sum_with_subsampling += len(dataset)
                pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
                    dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key
                )
                size_sum += len(dataset)
                sample -= 1.0
                pair_datasets_key += "-up"

            assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"

            logger.info(
                f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}"
            )
            size_by_corpus[corpus_name] += len(dataset)

        self.datasets[split] = pair_datasets
        logger.info(
            f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}"
        )

    @property
    def source_dictionary(self):
        return self.src_dictionary

    @property
    def target_dictionary(self):
        return self.tgt_dictionary

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):

        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)

        # initialize the dataset with the correct starting epoch
        for _, dt in dataset.items():
            dt.set_epoch(epoch)

        indices = OrderedDict()
        batch_sampler = OrderedDict()
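
        # Ordering, size filtering, and batching are done independently for each
        # pair dataset under a seed derived from (seed + epoch), so shuffling is
        # reproducible across epochs; the per-key batch samplers are then merged
        # by MultidatasetEpochBatchIterator below.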
        with data_utils.numpy_seed(seed + epoch):
            for key, dt in dataset.items():
                logger.info(f"\t ordered_indices {key}")
                indices[key] = dt.ordered_indices()

            # filter examples that are too large
            if max_positions is not None:
                for key, dt in dataset.items():
                    logger.info(f"\t filter_by_size {key}")
                    indices[key], ignored = dt.filter_indices_by_size(
                        indices[key], max_positions
                    )

            for key, dt in dataset.items():
                logger.info(f"\t batch_by_size {key}")
                batch_sampler[key] = data_utils.batch_by_size(
                    indices[key],
                    dt.num_tokens,
                    max_tokens=max_tokens,
                    max_sentences=max_sentences,
                    required_batch_size_multiple=required_batch_size_multiple,
                )

        epoch_iter = MultidatasetEpochBatchIterator(
            dataset=dataset,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )

        return epoch_iter
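

# Illustrative usage sketch (an assumption, not part of the module itself): once
# this file is importable, e.g. through fairseq's --user-dir mechanism, the task
# can be driven programmatically roughly as follows, given parsed args with
# args.task == "laser" and args.configfile pointing at a JSON config like the
# one sketched above:
#
#   from fairseq import tasks
#   task = tasks.setup_task(args)
#   task.load_dataset("train")
#   itr = task.get_batch_iterator(task.dataset("train"), max_tokens=4000)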