Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest | |
import numpy as np | |
import torch | |
from examples.speech_recognition.data.collaters import Seq2SeqCollater | |
class TestSeq2SeqCollator(unittest.TestCase):
    """Checks the batch that ``Seq2SeqCollater.collate`` builds from two samples."""

    def test_collate(self):
        """Collate two samples of different lengths and verify each batch field."""
        eos_idx = 1
        pad_idx = 0
        collater = Seq2SeqCollater(
            feature_index=0, label_index=1, pad_index=pad_idx, eos_index=eos_idx
        )

        # 2 frames in the first sample and 3 frames in the second one
        frames1 = np.array([[7, 8], [9, 10]])
        frames2 = np.array([[1, 2], [3, 4], [5, 6]])
        target1 = np.array([4, 2, 3, eos_idx])
        target2 = np.array([3, 2, eos_idx])
        sample1 = {"id": 0, "data": [frames1, target1]}
        sample2 = {"id": 1, "data": [frames2, target2]}

        batch = collater.collate([sample1, sample2])

        # collate sorts inputs by frame length before creating the batch,
        # so the 3-frame sample (id 1) is expected ahead of the 2-frame one.
        self.assertTensorEqual(batch["id"], torch.tensor([1, 0]))
        # NOTE(review): 7 matches len(target1) + len(target2) = 4 + 3 --
        # confirm against the collater that this is how ntokens is defined.
        self.assertEqual(batch["ntokens"], 7)

        # The shorter sample's features are padded with pad_idx up to 3 frames.
        self.assertTensorEqual(
            batch["net_input"]["src_tokens"],
            torch.tensor(
                [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [pad_idx, pad_idx]]]
            ),
        )
        # Expected decoder input: the target with eos moved to the front.
        self.assertTensorEqual(
            batch["net_input"]["prev_output_tokens"],
            torch.tensor([[eos_idx, 3, 2, pad_idx], [eos_idx, 4, 2, 3]]),
        )
        self.assertTensorEqual(batch["net_input"]["src_lengths"], torch.tensor([3, 2]))
        # Expected target: the shorter label sequence padded with pad_idx.
        self.assertTensorEqual(
            batch["target"],
            torch.tensor([[3, 2, eos_idx, pad_idx], [4, 2, 3, eos_idx]]),
        )
        self.assertEqual(batch["nsentences"], 2)

    def assertTensorEqual(self, t1, t2):
        """Assert that two tensors have the same size and identical elements.

        Args:
            t1: first tensor.
            t2: second tensor, compared element-wise against ``t1``.

        Raises:
            AssertionError: if the sizes differ or any element differs.
        """
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        # .item() turns the 0-dim count into a plain Python int, so the
        # comparison and any failure report deal with a number, not a tensor.
        num_mismatched = t1.ne(t2).long().sum().item()
        self.assertEqual(num_mismatched, 0, "tensor values differ")
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()