OFA-OCR / fairseq /tests /test_data_utils.py
JustinLin610's picture
first commit
ee21b96
raw
history blame
4.87 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import numpy as np
from fairseq.data.data_utils_fast import batch_by_size_fn
from fairseq.data.data_utils_fast import batch_by_size_vec
class TestBatchBySize(unittest.TestCase):
@classmethod
def batch_by_size_baseline(
cls,
indices,
num_tokens_vec,
max_tokens,
max_sentences,
bsz_mult,
):
"""Simple, reliable and slow implementation of batch by size """
batches = []
start = 0
while start < len(indices):
for end in range(start + 1, len(indices) + 1):
max_val = max(num_tokens_vec[pos] for pos in range(start, end))
sent_count = end - start
num_tokens = max_val * sent_count
overflow = num_tokens > max_tokens > 0 or sent_count > max_sentences > 0
terminate = overflow or end == len(indices)
if overflow:
sent_count -= 1
if terminate:
if sent_count > bsz_mult:
sent_count = sent_count - sent_count % bsz_mult
batches.append(indices[start : start + sent_count])
start = start + sent_count
break
return batches
@classmethod
def _get_error_message(
cls, max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
):
return f"""Reference batch_by_size implementation should produce
same output as the baseline method.
Params:
max_sentences={max_sentences},
max_tokens={max_tokens},
bsz_mult={bsz_mult},
num_tokens_vec={num_tokens_vec},
expected_batches={validation},
returned_batches={results}"""
def _compare_results(
self,
indices_len,
batch_by_size_impl,
max_sentences,
max_tokens,
bsz_mult,
num_tokens_vec,
):
indices = np.array(list(range(indices_len)))
validation = self.batch_by_size_baseline(
indices,
num_tokens_vec,
max_tokens=max_tokens,
max_sentences=max_sentences,
bsz_mult=bsz_mult,
)
results = batch_by_size_impl(
indices,
num_tokens_vec,
max_tokens=max_tokens,
max_sentences=max_sentences,
bsz_mult=bsz_mult,
)
error_msg = self._get_error_message(
max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
)
self.assertEqual(len(validation), len(results), error_msg)
for first, second in zip(validation, results):
self.assertTrue(np.array_equal(first, second), error_msg)
def _run_compare_with_baseline_sweep(self, batch_by_size_impl):
"""Compare reference batch_by_size implementation with batch_by_size_baseline
across a dense grid of hyperparam values"""
MAX_MAX_TOKENS = 10
NUM_TOKENS_VECS_COUNT = 5
for indices_len in [10, 11]: # try odd and even len of indices
for max_sentences in range(0, indices_len + 2):
for max_tokens in range(0, MAX_MAX_TOKENS):
for bsz_mult in range(1, max(MAX_MAX_TOKENS, indices_len) + 2):
for _ in range(NUM_TOKENS_VECS_COUNT):
num_tokens_vec = np.random.randint(
0, max_tokens + 1, size=indices_len
)
self._compare_results(
indices_len,
batch_by_size_impl,
max_sentences,
max_tokens,
bsz_mult,
num_tokens_vec,
)
class TestBatchBySizeVec(TestBatchBySize):
def test_compare_with_baseline(self):
self._run_compare_with_baseline_sweep(batch_by_size_vec)
class TestBatchBySizeFn(TestBatchBySize):
def test_compare_with_baseline(self):
def batch_by_size_fn_wrapper(
indices,
num_tokens_vec,
max_tokens,
max_sentences,
bsz_mult,
):
def num_tokens_fn(idx):
return num_tokens_vec[idx]
return batch_by_size_fn(
indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult
)
self._run_compare_with_baseline_sweep(batch_by_size_fn_wrapper)
if __name__ == "__main__":
unittest.main()