Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest | |
import tests.utils as test_utils | |
import torch | |
from fairseq.data import ( | |
BacktranslationDataset, | |
LanguagePairDataset, | |
TransformEosDataset, | |
) | |
from fairseq.sequence_generator import SequenceGenerator | |
class TestBacktranslationDataset(unittest.TestCase):
    """Batch-level checks for ``BacktranslationDataset``.

    Each test wraps the fixture tokens in a dataset, back-translates them with
    a ``SequenceGenerator`` built around the fixture model, and compares the
    first collated batch with hand-written expected tensors.
    """

    def setUp(self):
        """Load the shared sequence-generator fixture and build the dataset."""
        (
            tgt_dict,
            w1,
            w2,
            src_tokens,
            src_lengths,
            model,
        ) = test_utils.sequence_generator_setup()

        self.tgt_dict = tgt_dict
        self.w1 = w1
        self.w2 = w2
        self.src_tokens = src_tokens
        self.src_lengths = src_lengths
        self.model = model

        # The fixture's source tokens double as the data being back-translated.
        self.tgt_dataset = test_utils.TestDataset(data=self.src_tokens)
        self.cuda = torch.cuda.is_available()

    def _backtranslation_dataset_helper(
        self,
        remove_eos_from_input_src,
        remove_eos_from_output_src,
    ):
        """Build a back-translation pipeline and verify its first batch.

        Args:
            remove_eos_from_input_src: forwarded as ``remove_eos_from_src`` to
                the dataset fed into back-translation, and as
                ``append_eos_to_tgt`` to the output collater.
            remove_eos_from_output_src: forwarded as ``remove_eos_from_src``
                to the output collater; when true, the expected source tensor
                loses its last column.
        """
        vocab = self.tgt_dict

        monolingual = LanguagePairDataset(
            src=self.tgt_dataset,
            src_sizes=self.tgt_dataset.sizes,
            src_dict=vocab,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )

        generator = SequenceGenerator(
            [self.model],
            tgt_dict=vocab,
            max_len_a=0,
            max_len_b=200,
            beam_size=2,
            unk_penalty=0,
        )

        def backtranslate(sample):
            return generator.generate([self.model], sample)

        # Remove eos from the input src when requested.
        generation_input = TransformEosDataset(
            dataset=monolingual,
            eos=vocab.eos(),
            remove_eos_from_src=remove_eos_from_input_src,
        )

        # If eos was removed from the input src, it has to be added back to
        # the output tgt.
        output_transform = TransformEosDataset(
            dataset=monolingual,
            eos=vocab.eos(),
            append_eos_to_tgt=remove_eos_from_input_src,
            remove_eos_from_src=remove_eos_from_output_src,
        )

        bt_dataset = BacktranslationDataset(
            tgt_dataset=generation_input,
            src_dict=vocab,
            backtranslation_fn=backtranslate,
            output_collater=output_transform.collater,
            cuda=self.cuda,
        )
        loader = torch.utils.data.DataLoader(
            bt_dataset,
            batch_size=2,
            collate_fn=bt_dataset.collater,
        )
        first_batch = next(iter(loader))

        eos = vocab.eos()
        pad = vocab.pad()
        w1 = self.w1
        w2 = self.w2

        # Samples are sorted by src_lengths and left-padded, so the ids end up
        # in the order [1, 0].
        src_rows = [[w1, w2, w1, eos], [pad, pad, w1, eos]]
        expected_src = torch.LongTensor(src_rows)
        if remove_eos_from_output_src:
            expected_src = expected_src[:, :-1]

        tgt_row = [w1, w2, eos]
        expected_tgt = torch.LongTensor([tgt_row, tgt_row])

        actual_src = first_batch["net_input"]["src_tokens"]
        actual_tgt = first_batch["target"]

        self.assertTensorEqual(expected_src, actual_src)
        self.assertTensorEqual(expected_tgt, actual_tgt)

    def test_backtranslation_dataset_no_eos_in_output_src(self):
        # eos kept on the generation input, stripped from the output src.
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=False,
            remove_eos_from_output_src=True,
        )

    def test_backtranslation_dataset_with_eos_in_output_src(self):
        # eos kept on both the generation input and the output src.
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=False,
            remove_eos_from_output_src=False,
        )

    def test_backtranslation_dataset_no_eos_in_input_src(self):
        # eos stripped from the generation input, kept on the output src.
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=True,
            remove_eos_from_output_src=False,
        )

    def assertTensorEqual(self, t1, t2):
        """Assert that two tensors have the same size and equal elements."""
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        # Count the positions where the tensors disagree; it must be zero.
        differing = t1.ne(t2).long().sum()
        self.assertEqual(differing, 0)
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()