Spaces:
Runtime error
Runtime error
File size: 5,433 Bytes
ee21b96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from fairseq.data import iterators
class TestIterators(unittest.TestCase):
    """Unit tests for the iterator wrappers in ``fairseq.data.iterators``.

    Covers ``CountingIterator``, ``GroupedIterator``, ``ShardedIterator`` and
    ``CountingIterator`` layered over ``BufferedIterator``.
    """

    def test_counting_iterator_index(self, ref=None, itr=None):
        """Check position tracking (``n``), ``skip`` and ``has_next``.

        Runs on its own against a fresh ``CountingIterator`` over
        ``range(10)``, and is also invoked by other tests in this class to
        verify that their iterators behave like a ``CountingIterator``.

        Args:
            ref: The 10 items ``itr`` is expected to yield, in order. When
                ``None``, a default fixture is built and ``itr`` must be
                ``None`` as well.
            itr: The iterator under test; required whenever ``ref`` is given.
        """
        if ref is None:
            # unittest assertions are used for these fixture preconditions
            # because plain ``assert`` statements are removed under
            # ``python -O``.
            self.assertIsNone(itr)
            ref = list(range(10))
            itr = iterators.CountingIterator(ref)
        else:
            # The index arithmetic below is written for exactly 10 items.
            self.assertEqual(len(ref), 10)
            self.assertIsNotNone(itr)

        self.assertTrue(itr.has_next())
        self.assertEqual(itr.n, 0)
        self.assertEqual(next(itr), ref[0])
        self.assertEqual(itr.n, 1)
        self.assertEqual(next(itr), ref[1])
        self.assertEqual(itr.n, 2)

        # Skipping advances the counter by the number of skipped items.
        itr.skip(3)
        self.assertEqual(itr.n, 5)
        self.assertEqual(next(itr), ref[5])
        itr.skip(2)
        self.assertEqual(itr.n, 8)

        # Draining the iterator yields the remaining two items and ends it.
        self.assertEqual(list(itr), [ref[8], ref[9]])
        self.assertFalse(itr.has_next())

    def test_counting_iterator_length_mismatch(self):
        """Check behaviour when ``total`` disagrees with the iterable length."""
        ref = list(range(10))

        # When the underlying iterable is longer than the CountingIterator,
        # the remaining items in the iterable should be ignored.
        itr = iterators.CountingIterator(ref, total=8)
        self.assertEqual(list(itr), ref[:8])

        # When the underlying iterable is shorter than the CountingIterator,
        # an IndexError is expected once the underlying iterable is exhausted.
        itr = iterators.CountingIterator(ref, total=12)
        with self.assertRaises(IndexError):
            list(itr)

    def test_counting_iterator_take(self):
        """Check that ``take`` truncates a ``CountingIterator``."""
        ref = list(range(10))

        # After take(5) the reported length and the number of yielded items
        # must agree, and both must be 5.
        itr = iterators.CountingIterator(ref)
        itr.take(5)
        self.assertEqual(len(itr), len(list(iter(itr))))
        self.assertEqual(len(itr), 5)

        # take() combined with next()/skip(): two items consumed, two
        # skipped, one more consumed -> the five taken items are used up.
        itr = iterators.CountingIterator(ref)
        itr.take(5)
        self.assertEqual(next(itr), ref[0])
        self.assertEqual(next(itr), ref[1])
        itr.skip(2)
        self.assertEqual(next(itr), ref[4])
        self.assertFalse(itr.has_next())

    def test_grouped_iterator(self):
        """Check ``GroupedIterator`` chunking and its counting behaviour."""
        # Correctness: items are grouped into lists of the requested size;
        # the last group may be shorter.
        x = list(range(10))
        itr = iterators.GroupedIterator(x, 1)
        self.assertEqual(
            list(itr), [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]
        )
        itr = iterators.GroupedIterator(x, 4)
        self.assertEqual(list(itr), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]])
        itr = iterators.GroupedIterator(x, 5)
        self.assertEqual(list(itr), [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

        # The GroupedIterator must also work as a CountingIterator:
        # 30 items in groups of 3 give the 10 groups the shared check needs.
        x = list(range(30))
        ref = list(iterators.GroupedIterator(x, 3))
        itr = iterators.GroupedIterator(x, 3)
        self.test_counting_iterator_index(ref, itr)

    def test_sharded_iterator(self):
        """Check ``ShardedIterator`` item selection and counting behaviour."""
        # Correctness: each shard receives every num_shards-th item starting
        # at its shard_id.
        x = list(range(10))
        itr = iterators.ShardedIterator(x, num_shards=1, shard_id=0)
        self.assertEqual(list(itr), x)
        itr = iterators.ShardedIterator(x, num_shards=2, shard_id=0)
        self.assertEqual(list(itr), [0, 2, 4, 6, 8])
        itr = iterators.ShardedIterator(x, num_shards=2, shard_id=1)
        self.assertEqual(list(itr), [1, 3, 5, 7, 9])
        itr = iterators.ShardedIterator(x, num_shards=3, shard_id=0)
        self.assertEqual(list(itr), [0, 3, 6, 9])

        # 10 items do not divide evenly into 3 shards: the shards that would
        # come up short are expected to end with a None filler item.
        itr = iterators.ShardedIterator(x, num_shards=3, shard_id=1)
        self.assertEqual(list(itr), [1, 4, 7, None])
        itr = iterators.ShardedIterator(x, num_shards=3, shard_id=2)
        self.assertEqual(list(itr), [2, 5, 8, None])

        # CountingIterator functionality: 30 items over 3 shards give the
        # 10 items per shard that the shared check needs.
        x = list(range(30))
        ref = list(iterators.ShardedIterator(x, num_shards=3, shard_id=0))
        itr = iterators.ShardedIterator(x, num_shards=3, shard_id=0)
        self.test_counting_iterator_index(ref, itr)

    def test_counting_iterator_buffered_iterator_take(self):
        """Check ``take`` on a ``CountingIterator`` over a ``BufferedIterator``."""
        ref = list(range(10))

        # take(5) on the outer iterator limits what the outer iterator yields.
        buffered_itr = iterators.BufferedIterator(2, ref)
        itr = iterators.CountingIterator(buffered_itr)
        itr.take(5)
        self.assertEqual(len(itr), len(list(iter(itr))))
        self.assertEqual(len(itr), 5)

        # The limit is also expected to reach the wrapped buffered iterator.
        buffered_itr = iterators.BufferedIterator(2, ref)
        itr = iterators.CountingIterator(buffered_itr)
        itr.take(5)
        self.assertEqual(len(buffered_itr), 5)
        self.assertEqual(len(list(iter(buffered_itr))), 5)

        # Mixing next() and skip() uses up the five taken items; afterwards
        # both the outer and the wrapped iterator must be exhausted.
        buffered_itr = iterators.BufferedIterator(2, ref)
        itr = iterators.CountingIterator(buffered_itr)
        itr.take(5)
        self.assertEqual(next(itr), ref[0])
        self.assertEqual(next(itr), ref[1])
        itr.skip(2)
        self.assertEqual(next(itr), ref[4])
        self.assertFalse(itr.has_next())
        with self.assertRaises(StopIteration):
            next(buffered_itr)

        # With the counter starting at 4, take(5) leaves exactly one item to
        # read from the wrapped iterator.
        ref = list(range(4, 10))
        buffered_itr = iterators.BufferedIterator(2, ref)
        itr = iterators.CountingIterator(buffered_itr, start=4)
        itr.take(5)
        self.assertEqual(len(itr), 5)
        self.assertEqual(len(buffered_itr), 1)
        self.assertEqual(next(itr), ref[0])
        self.assertFalse(itr.has_next())
        with self.assertRaises(StopIteration):
            next(buffered_itr)
# Script entry point: run every test in this module via unittest's runner
# when the file is executed directly (not when it is imported).
if __name__ == "__main__":
    unittest.main()
|