JustinLin610
update
10b0761
raw
history blame
11 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict, defaultdict
import json
import os
import logging
from argparse import ArgumentError
from fairseq import options, models
from fairseq.data import (
data_utils,
Dictionary,
LanguagePairDataset,
IndexedDataset,
FairseqDataset,
)
from .multitask_data_utils import (
MultitaskDatasetWrapper,
MultidatasetEpochBatchIterator,
)
from fairseq.tasks import LegacyFairseqTask, register_task
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
@register_task("laser")
class LaserTask(LegacyFairseqTask):
    """Multi-corpus bilingual task registered as ``laser``.

    The data layout comes from a JSON file given as the positional
    ``configfile`` argument. This class reads its ``src_vocab`` and
    ``tgt_vocab`` entries (dictionary paths) and one list of entries per
    split. Each entry has a ``src`` path, an optional ``tgt`` path, and
    optional ``id`` (task id) and ``sample`` (repeat count) fields.

    Unlike most fairseq tasks, ``self.datasets[split]`` is an ``OrderedDict``
    of per-corpus :class:`MultitaskDatasetWrapper` objects. Each member is
    batched separately by :meth:`get_batch_iterator`.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "configfile", metavar="PATH", help="dataset configuration file in json"
        )
        parser.add_argument(
            "--weighting-alpha",
            type=float,
            default=None,
            help="alpha for automatic weighting",
        )
        parser.add_argument(
            "--raw-text", action="store_true", help="load raw text dataset"
        )
        parser.add_argument(
            "--left-pad-source",
            default="True",
            type=str,
            metavar="BOOL",
            help="pad the source on the left (default: True)",
        )
        parser.add_argument(
            "--left-pad-target",
            default="False",
            type=str,
            metavar="BOOL",
            help="pad the target on the left (default: False)",
        )
        # These options might have already been defined elsewhere; argparse
        # then raises ArgumentError. Each option is tried on its own so that a
        # conflict on the first one does not silently skip the second.
        # Once this moves to hydra the try/except should no longer be needed.
        for option, side in (
            ("--max-source-positions", "source"),
            ("--max-target-positions", "target"),
        ):
            try:
                parser.add_argument(
                    option,
                    default=1024,
                    type=int,
                    metavar="N",
                    help=f"max number of tokens in the {side} sequence",
                )
            except ArgumentError:
                pass

    def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
        super().__init__(args)
        # Parsed JSON configuration file.
        self.config = config
        self.src_dictionary = src_dictionary
        self.tgt_dictionary = tgt_dictionary
        # One more than the largest ``id`` among the ``train`` entries.
        self.num_tasks = num_tasks

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Read the JSON config, load both dictionaries and build the task.

        Every entry of ``config["train"]`` must define ``id``. ``num_tasks``
        is set to the largest such id plus one. ``args.left_pad_source`` and
        ``args.left_pad_target`` are converted from strings to booleans in
        place.
        """
        # JSON text is UTF-8; do not depend on the platform locale.
        with open(args.configfile, "r", encoding="utf-8") as f:
            config = json.load(f)
        num_tasks = max(dataset["id"] for dataset in config["train"]) + 1

        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        src_dictionary = Dictionary.load(config["src_vocab"])
        tgt_dictionary = Dictionary.load(config["tgt_vocab"])

        logger.info(
            "| src Dictionary {} : {} types".format(
                config["src_vocab"], len(src_dictionary)
            )
        )
        logger.info(
            "| tgt Dictionary {} : {} types".format(
                config["tgt_vocab"], len(tgt_dictionary)
            )
        )

        return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)

    # Experimental overriding for backtranslation
    def build_model(self, args):
        model = models.build_model(args, self)
        return model

    def dataset(self, split):
        """Return the loaded ``OrderedDict`` for ``split``.

        Raises:
            KeyError: if ``split`` has not been loaded yet.
        """
        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)
        return self.datasets[split]

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split into ``self.datasets[split]``.

        The stored value is an ``OrderedDict``. It maps ``"<corpus>-<pair>"``
        to a :class:`MultitaskDatasetWrapper`, where ``<corpus>`` and
        ``<pair>`` are the last two directory names of the entry's ``src``
        path.

        Each entry is repeated ``round(sample)`` times, and every extra copy
        gets ``"-up"`` appended to its key. An entry without ``sample``
        defaults to one copy. When ``--weighting-alpha`` is set, such an
        entry uses the ratio from :meth:`_alpha_upsampling_ratios` instead.
        A rounded count of 0 drops the entry.

        The ``valid`` split is not loaded; it is stored as an empty
        ``OrderedDict``.

        Raises:
            FileNotFoundError: if ``split`` has no list in the config file.
        """

        def indexed_dataset(path, dictionary):
            # ``dictionary`` is not used: data is read already binarized.
            if self.args.raw_text:
                raise Exception("Unable to handle raw text.")
            return IndexedDataset(path, fix_lua_indexing=True)

        pair_datasets = OrderedDict()

        if split == "valid":
            self.datasets[split] = pair_datasets
            return

        if split not in self.config:
            raise FileNotFoundError(
                "Dataset not found in config file: {}".format(split)
            )

        init_pair_datasets = {}
        for dataset_config in self.config[split]:
            # Key derived from ``.../<corpus>/<language pair>/<file>``.
            src_path = os.path.dirname(dataset_config["src"])
            corpus_name = src_path.split("/")[-2]
            language_pair_name = src_path.split("/")[-1]
            pair_datasets_key = corpus_name + "-" + language_pair_name

            logger.info(f"loading... {pair_datasets_key}")

            # Reject duplicates before opening any file, so ignored entries
            # cost no I/O.
            if pair_datasets_key in init_pair_datasets:
                logger.warning(
                    f"Ignoring already added {pair_datasets_key}. "
                    f"Consider using `sample` key in order to upsample."
                )
                continue

            # ``src`` is mandatory (it is read unconditionally above).
            src_dataset = indexed_dataset(
                dataset_config["src"], self.src_dictionary
            )
            if "tgt" in dataset_config:
                tgt_dataset = indexed_dataset(
                    dataset_config["tgt"], self.tgt_dictionary
                )
                tgt_sizes = tgt_dataset.sizes
            else:
                tgt_dataset = None
                tgt_sizes = None

            dataset = LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.src_dictionary,
                tgt_dataset,
                tgt_sizes,
                self.tgt_dictionary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )

            init_pair_datasets[pair_datasets_key] = {
                "dataset": dataset,
                "sample": dataset_config.get("sample", None),
                # Task id handed to MultitaskDatasetWrapper; defaults to 0.
                "id": dataset_config.get("id", 0),
                "len": len(dataset),
                "corpus": corpus_name,
            }

        upsampling_ratios = self._alpha_upsampling_ratios(init_pair_datasets)

        size_by_corpus = defaultdict(int)
        size_sum = 0  # original sizes of the datasets that were kept
        size_sum_with_subsampling = 0  # sizes counting every repeated copy
        for pair_datasets_key, entry in init_pair_datasets.items():
            dataset = entry["dataset"]
            sample = entry["sample"]
            if sample is None:
                sample = upsampling_ratios.get(pair_datasets_key, 1.0)
            sample = round(sample)

            initial_sample = sample
            copy_key = pair_datasets_key
            copies = 0
            while sample >= 1.0:
                assert copy_key not in pair_datasets, f"{copy_key} already in"
                pair_datasets[copy_key] = MultitaskDatasetWrapper(
                    dataset, entry["id"], 1.0, name=copy_key
                )
                size_sum_with_subsampling += len(dataset)
                copies += 1
                sample -= 1.0
                copy_key += "-up"
            assert sample < 1e-6, f"sample remains > 0 {copy_key}"

            logger.info(
                f"added pair {pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}"
            )
            if copies:
                size_sum += len(dataset)
                size_by_corpus[entry["corpus"]] += len(dataset)

        self.datasets[split] = pair_datasets
        logger.info(
            f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}"
        )
        logger.info(f"Sizes by corpus = {dict(size_by_corpus)}")

    def _alpha_upsampling_ratios(self, init_pair_datasets):
        """Return ``{key: ratio}`` upsampling ratios from ``--weighting-alpha``.

        Only entries with no explicit ``sample`` and at least one example
        take part. Each entry's share of their total length, ``p``, is raised
        to ``alpha`` and renormalised to ``q``. The ratio is ``max(q) / q``,
        so the entry with the largest ``q`` gets ratio 1.

        Returns an empty dict when weighting is disabled (``alpha`` unset or
        0) or no entry qualifies.
        """
        alpha = self.args.weighting_alpha
        if not alpha:
            return {}

        # Zero-length datasets are excluded: their weight would be 0 and the
        # ratio would divide by zero.
        lengths = {
            key: entry["len"]
            for key, entry in init_pair_datasets.items()
            if entry["sample"] is None and entry["len"] > 0
        }
        length_sum = sum(lengths.values())
        if length_sum == 0:
            return {}

        freq_per_dataset = {
            key: float(length) / length_sum for key, length in lengths.items()
        }
        weighted_freqs_sum = 0
        for freq in freq_per_dataset.values():
            weighted_freqs_sum += freq ** alpha
        weighted_freq_per_dataset = {
            key: freq ** alpha / weighted_freqs_sum
            for key, freq in freq_per_dataset.items()
        }
        vmax = max(weighted_freq_per_dataset.values())
        return {
            key: vmax / weighted
            for key, weighted in weighted_freq_per_dataset.items()
        }

    @property
    def source_dictionary(self):
        return self.src_dictionary

    @property
    def target_dictionary(self):
        return self.tgt_dictionary

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """Build a :class:`MultidatasetEpochBatchIterator` over ``dataset``.

        ``dataset`` must be a non-empty ``OrderedDict`` of ``FairseqDataset``
        objects, as produced by :meth:`load_dataset`. Each member is ordered
        under a ``seed + epoch`` numpy seed. It is then optionally filtered
        by ``max_positions`` and bucketed into batches on its own.

        ``ignore_invalid_inputs``, ``data_buffer_size`` and
        ``disable_iterator_cache`` are accepted for interface compatibility
        but not used. Over-long examples are always dropped, and a warning is
        logged.
        """
        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)

        # initialize the dataset with the correct starting epoch
        for dt in dataset.values():
            dt.set_epoch(epoch)

        indices = OrderedDict()
        batch_sampler = OrderedDict()

        with data_utils.numpy_seed(seed + epoch):
            for key, dt in dataset.items():
                logger.info(f"\t ordered_indices {key}")
                indices[key] = dt.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            for key, dt in dataset.items():
                logger.info(f"\t filter_by_size {key}")
                indices[key], ignored = dt.filter_indices_by_size(
                    indices[key], max_positions
                )
                if len(ignored) > 0:
                    logger.warning(
                        f"\t {key}: {len(ignored)} examples exceed "
                        f"max_positions={max_positions} and will be skipped"
                    )

        for key, dt in dataset.items():
            logger.info(f"\t batch_by_size {key}")
            batch_sampler[key] = data_utils.batch_by_size(
                indices[key],
                dt.num_tokens,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )

        epoch_iter = MultidatasetEpochBatchIterator(
            dataset=dataset,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )

        return epoch_iter