JustinLin610
update
10b0761
raw
history blame
2.05 kB
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import numpy as np
import torch
from examples.speech_recognition.data.collaters import Seq2SeqCollater
class TestSeq2SeqCollator(unittest.TestCase):
    """Checks the batch that ``Seq2SeqCollater.collate`` builds from two samples."""

    def test_collate(self):
        """Collate two samples of unequal length and verify each batch field."""
        pad_idx, eos_idx = 0, 1
        collater = Seq2SeqCollater(
            feature_index=0,
            label_index=1,
            pad_index=pad_idx,
            eos_index=eos_idx,
        )

        # The 2-frame sample is handed over first, the 3-frame sample second.
        short_frames = np.array([[7, 8], [9, 10]])
        long_frames = np.array([[1, 2], [3, 4], [5, 6]])
        short_target = np.array([4, 2, 3, eos_idx])
        long_target = np.array([3, 2, eos_idx])
        samples = [
            {"id": 0, "data": [short_frames, short_target]},
            {"id": 1, "data": [long_frames, long_target]},
        ]

        # Expected tensors: the collater sorts its inputs by frame length
        # before building the batch, so the 3-frame sample (id 1) leads and
        # the shorter entries are filled up with pad_idx.
        expected_ids = torch.tensor([1, 0])
        expected_src_tokens = torch.tensor(
            [
                [[1, 2], [3, 4], [5, 6]],
                [[7, 8], [9, 10], [pad_idx, pad_idx]],
            ]
        )
        expected_prev_output = torch.tensor(
            [
                [eos_idx, 3, 2, pad_idx],
                [eos_idx, 4, 2, 3],
            ]
        )
        expected_src_lengths = torch.tensor([3, 2])
        expected_target = torch.tensor(
            [
                [3, 2, eos_idx, pad_idx],
                [4, 2, 3, eos_idx],
            ]
        )

        batch = collater.collate(samples)

        self.assertTensorEqual(batch["id"], expected_ids)
        # 7 matches the combined length of the two targets (4 + 3).
        self.assertEqual(batch["ntokens"], 7)

        net_input = batch["net_input"]
        self.assertTensorEqual(net_input["src_tokens"], expected_src_tokens)
        self.assertTensorEqual(net_input["prev_output_tokens"], expected_prev_output)
        self.assertTensorEqual(net_input["src_lengths"], expected_src_lengths)

        self.assertTensorEqual(batch["target"], expected_target)
        self.assertEqual(batch["nsentences"], 2)

    def assertTensorEqual(self, t1, t2):
        """Fail unless ``t1`` and ``t2`` have the same size and equal elements."""
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        # Count the positions where the tensors differ; none are allowed.
        mismatch_count = t1.ne(t2).long().sum()
        self.assertEqual(mismatch_count, 0)
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()