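# Spaces: Runtime error
# (The line above appears to be a status banner from the hosting page that was
# captured together with the file; it is not part of the program.)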
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from argparse import Namespace
from pathlib import Path

import torch
from fairseq.data import (
    encoders,
    Dictionary,
    ResamplingDataset,
    TransformEosLangPairDataset,
    ConcatDataset,
)
from fairseq.data.iterators import GroupedEpochBatchIterator
from fairseq.data.audio.multi_modality_dataset import (
    MultiModalityDataset,
    LangPairMaskDataset,
    ModalityDatasetItem,
)
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
)
from fairseq.data.audio.speech_to_text_joint_dataset import (
    S2TJointDataConfig,
    SpeechToTextJointDatasetCreator,
)
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.tasks.translation import load_langpair_dataset

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Format template for language-tag tokens; ``LANG_TAG_TEMPLATE.format("de")``
# yields ``"<lang:de>"``.
LANG_TAG_TEMPLATE = "<lang:{}>"


# NOTE(review): the registration name below is taken from the upstream fairseq
# example this file derives from -- confirm it matches the ``--task`` value
# used by your configs.
@register_task("speech_text_joint_to_text")
class SpeechTextJointToTextTask(SpeechToTextTask):
    """Task for joint training speech and text to text.

    On top of the speech-to-text task it keeps a separate source dictionary
    and, for training splits, can pair the speech dataset with a parallel
    text dataset inside a ``MultiModalityDataset``.
    """

    @classmethod
    def add_args(cls, parser):
        """Add task-specific arguments to the parser.

        The parent task's arguments are registered first, then the options
        that control the parallel text data, sampling ratios and text masking.
        """
        super(SpeechTextJointToTextTask, cls).add_args(parser)
        # --- options specific to joint speech/text training ---
        parser.add_argument(
            "--parallel-text-data",
            default="",
            help="path to parallel text data directory",
        )
        parser.add_argument(
            "--max-tokens-text",
            type=int,
            metavar="N",
            help="maximum tokens for encoder text input ",
        )
        parser.add_argument(
            "--max-positions-text",
            type=int,
            metavar="N",
            default=400,
            help="maximum tokens for per encoder text input ",
        )
        parser.add_argument(
            "--langpairs",
            default=None,
            metavar="S",
            help='language pairs for text training, separated with ","',
        )
        parser.add_argument(
            "--speech-sample-ratio",
            default=1,
            type=float,
            metavar="N",
            help="Multiple Ratio for speech dataset with transcripts ",
        )
        parser.add_argument(
            "--text-sample-ratio",
            default=1,
            type=float,
            metavar="N",
            help="Multiple Ratio for text set ",
        )
        parser.add_argument(
            "--update-mix-data",
            action="store_true",
            help="use mixed data in one update when update-freq > 1",
        )
        parser.add_argument(
            "--load-speech-only",
            action="store_true",
            help="load speech data only",
        )
        parser.add_argument(
            "--mask-text-ratio",
            type=float,
            metavar="V",
            default=0.0,
            help="mask V source tokens for text only mode",
        )
        parser.add_argument(
            "--mask-text-type",
            default="random",
            choices=["random", "tail"],
            help="mask text typed",
        )
        parser.add_argument(
            "--noise-token",
            default="",
            help="noise token for masking src text tokens if mask-text-ratio > 0",
        )
        parser.add_argument(
            "--infer-target-lang",
            default="",
            metavar="S",
            help="target language for inference",
        )

    def __init__(self, args, src_dict, tgt_dict, infer_tgt_lang_id=None):
        """Create the task.

        Args:
            args: parsed command-line namespace.
            src_dict: dictionary for the source text.
            tgt_dict: dictionary for the target text.
            infer_tgt_lang_id: optional token index passed as ``bos_token``
                to the generator in :meth:`inference_step`.
        """
        super().__init__(args, tgt_dict)
        self.src_dict = src_dict
        self.data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
        # Both dictionaries must agree on the special symbols.
        assert self.tgt_dict.pad() == self.src_dict.pad()
        assert self.tgt_dict.eos() == self.src_dict.eos()
        self.speech_only = args.load_speech_only
        self._infer_tgt_lang_id = infer_tgt_lang_id

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Raises:
            FileNotFoundError: if the source or target dictionary file named
                by the data config does not exist under ``args.data``.
            ValueError: if parallel text data is given without
                ``--langpairs``, or if the tag built from
                ``--infer-target-lang`` is not in the target dictionary.
        """
        data_root = Path(args.data)
        data_cfg = S2TJointDataConfig(data_root / args.config_yaml)
        tgt_dict_path = data_root / data_cfg.vocab_filename
        src_dict_path = data_root / data_cfg.src_vocab_filename
        if not (os.path.isfile(src_dict_path) and os.path.isfile(tgt_dict_path)):
            raise FileNotFoundError("Dict not found: {}".format(args.data))
        src_dict = Dictionary.load(src_dict_path.as_posix())
        tgt_dict = Dictionary.load(tgt_dict_path.as_posix())

        logger.info("| src dictionary: %d types", len(src_dict))
        logger.info("| tgt dictionary: %d types", len(tgt_dict))

        if args.parallel_text_data != "":
            # A relative text-data path is resolved against the data directory.
            if not os.path.isabs(args.parallel_text_data):
                args.parallel_text_data = os.path.join(
                    args.data, args.parallel_text_data
                )
            if args.langpairs is None:
                # ValueError is still an Exception, so existing handlers keep working.
                raise ValueError(
                    "Could not infer language pair, please provide it explicitly"
                )

        infer_tgt_lang_id = None
        if args.infer_target_lang != "" and data_cfg.prepend_tgt_lang_tag_no_change:
            tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format(
                args.infer_target_lang
            )
            infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag)
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            if infer_tgt_lang_id == tgt_dict.unk():
                raise ValueError(
                    "target language tag {!r} is not in the target dictionary".format(
                        tgt_lang_tag
                    )
                )
        return cls(args, src_dict, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id)

    def load_langpair_dataset(
        self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0
    ):
        """Load the parallel text "train" data for every pair in ``--langpairs``.

        Args:
            prepend_tgt_lang_tag: when true, each pair's dataset is wrapped so
                the target BOS (the EOS symbol) is replaced by the index of
                that pair's target-language tag.
            sampling_alpha: when several pairs are loaded and this is not 1.0,
                the pairs are resampled with ratios from
                ``SpeechToTextDatasetCreator.get_size_ratios``.
            epoch: epoch passed to the resampling datasets.

        Returns:
            A ``ConcatDataset`` when more than one pair is configured,
            otherwise the single pair's dataset.
        """
        pair_names = self.args.langpairs.split(",")
        lang_pairs = []
        text_dataset = None
        split = "train"
        for lp in pair_names:
            src, tgt = lp.split("-")
            # Inside a method this name resolves to the module-level function
            # imported from fairseq.tasks.translation, not to this method.
            text_dataset = load_langpair_dataset(
                self.args.parallel_text_data,
                split,
                src,
                self.src_dict,
                tgt,
                self.tgt_dict,
                combine=True,
                dataset_impl=None,
                upsample_primary=1,
                left_pad_source=False,
                left_pad_target=False,
                max_source_positions=self.args.max_positions_text,
                max_target_positions=self.args.max_target_positions,
                load_alignments=False,
                truncate_source=False,
            )
            if prepend_tgt_lang_tag:
                # TODO (carried over from the original; no detail was given)
                text_dataset = TransformEosLangPairDataset(
                    text_dataset,
                    src_eos=self.src_dict.eos(),
                    tgt_bos=self.tgt_dict.eos(),  # 'prev_output_tokens' starts with eos
                    new_tgt_bos=self.tgt_dict.index(LANG_TAG_TEMPLATE.format(tgt)),
                )
            lang_pairs.append(text_dataset)

        if len(lang_pairs) > 1:
            if sampling_alpha != 1.0:
                size_ratios = SpeechToTextDatasetCreator.get_size_ratios(
                    pair_names,
                    [len(s) for s in lang_pairs],
                    alpha=sampling_alpha,
                )
                lang_pairs = [
                    ResamplingDataset(d, size_ratio=r, epoch=epoch, replace=(r >= 1.0))
                    for d, r in zip(lang_pairs, size_ratios)
                ]
            return ConcatDataset(lang_pairs)
        return text_dataset

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        """Run generation without gradients.

        The target-language id resolved in :meth:`setup_task` (or ``None``)
        is forwarded to the generator as ``bos_token``.
        """
        with torch.no_grad():
            return generator.generate(
                models,
                sample,
                prefix_tokens=prefix_tokens,
                constraints=constraints,
                bos_token=self._infer_tgt_lang_id,
            )

    def build_src_tokenizer(self, args):
        """Build the source pre-tokenizer from ``data_cfg.src_pre_tokenizer``."""
        logger.info(f"src-pre-tokenizer: {self.data_cfg.src_pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.src_pre_tokenizer))

    def build_src_bpe(self, args):
        """Build the source BPE tokenizer from ``data_cfg.src_bpe_tokenizer``."""
        logger.info(f"tokenizer: {self.data_cfg.src_bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.src_bpe_tokenizer))

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        For training splits with ``--parallel-text-data`` set, the speech
        dataset is combined with the (optionally masked) text dataset in a
        ``MultiModalityDataset``; otherwise only the speech dataset is stored.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): epoch forwarded to the dataset creators.
            combine (bool): accepted for interface compatibility; not used here.
        """
        is_train_split = split.startswith("train")
        ast_dataset = SpeechToTextJointDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.tgt_dict,
            src_dict=None if self.speech_only else self.src_dict,
            pre_tokenizer=self.build_tokenizer(self.args),
            bpe_tokenizer=self.build_bpe(self.args),
            src_pre_tokenizer=self.build_src_tokenizer(self.args),
            src_bpe_tokenizer=self.build_src_bpe(self.args),
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
        )

        text_dataset = None
        if self.args.parallel_text_data != "" and is_train_split:
            text_dataset = self.load_langpair_dataset(
                self.data_cfg.prepend_tgt_lang_tag_no_change,
                1.0,
                epoch=epoch,
            )
            if self.args.mask_text_ratio > 0:
                # Mask source text tokens; without --noise-token the unknown
                # symbol is used as the noise token.
                noise_token_id = (
                    self.src_dict.unk()
                    if self.args.noise_token == ""
                    else self.src_dict.index(self.args.noise_token)
                )
                text_dataset = LangPairMaskDataset(
                    text_dataset,
                    src_bos=self.src_dict.bos(),
                    src_eos=self.src_dict.eos(),
                    noise_id=noise_token_id,
                    mask_ratio=self.args.mask_text_ratio,
                    mask_type=self.args.mask_text_type,
                )

        if text_dataset is not None:
            # The text modality falls back to the speech token budget when
            # --max-tokens-text is not given.
            text_max_tokens = (
                self.args.max_tokens_text
                if self.args.max_tokens_text is not None
                else self.args.max_tokens
            )
            mdsets = [
                ModalityDatasetItem(
                    "sup_speech",
                    ast_dataset,
                    (self.args.max_source_positions, self.args.max_target_positions),
                    self.args.max_tokens,
                    self.args.batch_size,
                ),
                ModalityDatasetItem(
                    "text",
                    text_dataset,
                    (self.args.max_positions_text, self.args.max_target_positions),
                    text_max_tokens,
                    self.args.batch_size,
                ),
            ]
            ast_dataset = MultiModalityDataset(mdsets)
        self.datasets[split] = ast_dataset

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.tgt_dict

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task); ``None`` when only speech data is loaded."""
        return None if self.speech_only else self.src_dict

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=0,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """Return a batch iterator over ``dataset``.

        Anything that is not a ``MultiModalityDataset`` is delegated to the
        parent task unchanged. For a ``MultiModalityDataset`` the batch
        samplers are built with the speech and text sample ratios and wrapped
        in a ``GroupedEpochBatchIterator``; in that branch ``max_tokens``,
        ``max_sentences``, ``max_positions`` and ``ignore_invalid_inputs``
        are not consulted.
        """
        if not isinstance(dataset, MultiModalityDataset):
            return super(SpeechTextJointToTextTask, self).get_batch_iterator(
                dataset,
                max_tokens,
                max_sentences,
                max_positions,
                ignore_invalid_inputs,
                required_batch_size_multiple,
                seed,
                num_shards,
                shard_id,
                num_workers,
                epoch,
                data_buffer_size,
                disable_iterator_cache,
            )

        # One ratio per modality, in the order the datasets were added
        # by load_dataset (speech first, then text).
        mult_ratio = [self.args.speech_sample_ratio, self.args.text_sample_ratio]
        assert len(dataset.datasets) == 2

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        batch_samplers = dataset.get_batch_samplers(
            mult_ratio, required_batch_size_multiple, seed
        )

        # return a reusable, sharded iterator
        epoch_iter = GroupedEpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_samplers=batch_samplers,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            mult_rate=1 if self.args.update_mix_data else max(self.args.update_freq),
            buffer_size=data_buffer_size,
        )
        self.dataset_to_epoch_iter[dataset] = {}  # refresh it every epoch
        return epoch_iter