# Spaces: Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest | |
import numpy as np | |
from fairseq.data.data_utils_fast import batch_by_size_fn | |
from fairseq.data.data_utils_fast import batch_by_size_vec | |
class TestBatchBySize(unittest.TestCase):
    """Cross-check batch_by_size implementations against a naive baseline.

    Subclasses call ``_run_compare_with_baseline_sweep`` with the
    implementation under test.
    """

    @classmethod
    def batch_by_size_baseline(
        cls,
        indices,
        num_tokens_vec,
        max_tokens,
        max_sentences,
        bsz_mult,
    ):
        """Simple, reliable and slow implementation of batch by size.

        Each batch is grown greedily from ``start``. It closes just before the
        sentence that would make ``max(num_tokens) * count`` exceed
        ``max_tokens`` or ``count`` exceed ``max_sentences``. Each limit is
        only enforced when it is > 0. A batch larger than ``bsz_mult`` is then
        trimmed down to a multiple of ``bsz_mult``.

        Args:
            indices: sequence of sample indices to split, kept in order.
            num_tokens_vec: token counts indexed by *position* in ``indices``.
            max_tokens: token budget per batch; <= 0 disables the check.
            max_sentences: sentence budget per batch; <= 0 disables the check.
            bsz_mult: batch sizes above this are rounded down to a multiple.

        Returns:
            A list of slices of ``indices``, one per batch.

        Raises:
            ValueError: if a single sample already exceeds a limit, because no
                non-empty batch can start at that position.
        """
        batches = []
        start = 0
        while start < len(indices):
            for end in range(start + 1, len(indices) + 1):
                max_val = max(num_tokens_vec[pos] for pos in range(start, end))
                sent_count = end - start
                num_tokens = max_val * sent_count
                # Chained comparisons: a limit only applies when it is > 0.
                overflow = num_tokens > max_tokens > 0 or sent_count > max_sentences > 0
                terminate = overflow or end == len(indices)
                if overflow:
                    # The sample at position ``end - 1`` does not fit; close
                    # the batch right before it.
                    sent_count -= 1
                    if sent_count == 0:
                        # Without this check ``start`` would never advance and
                        # the outer loop would spin forever on empty batches.
                        raise ValueError(
                            f"sample at position {start} "
                            f"(num_tokens={num_tokens_vec[start]}) does not fit "
                            f"in a batch with max_tokens={max_tokens}, "
                            f"max_sentences={max_sentences}"
                        )
                if terminate:
                    if sent_count > bsz_mult:
                        sent_count = sent_count - sent_count % bsz_mult
                    batches.append(indices[start : start + sent_count])
                    start = start + sent_count
                    break
        return batches

    @classmethod
    def _get_error_message(
        cls, max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
    ):
        """Build the assertion message describing a baseline/impl mismatch."""
        return (
            "Reference batch_by_size implementation should produce\n"
            "same output as the baseline method.\n"
            "Params:\n"
            f"max_sentences={max_sentences},\n"
            f"max_tokens={max_tokens},\n"
            f"bsz_mult={bsz_mult},\n"
            f"num_tokens_vec={num_tokens_vec},\n"
            f"expected_batches={validation},\n"
            f"returned_batches={results}"
        )

    def _compare_results(
        self,
        indices_len,
        batch_by_size_impl,
        max_sentences,
        max_tokens,
        bsz_mult,
        num_tokens_vec,
    ):
        """Assert ``batch_by_size_impl`` matches the baseline on one input.

        ``indices`` is ``0..indices_len-1``, so positions and index values
        coincide. The implementation is called with the limits as keyword
        arguments.
        """
        indices = np.arange(indices_len)
        validation = self.batch_by_size_baseline(
            indices,
            num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            bsz_mult=bsz_mult,
        )
        results = batch_by_size_impl(
            indices,
            num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            bsz_mult=bsz_mult,
        )
        error_msg = self._get_error_message(
            max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
        )
        self.assertEqual(len(validation), len(results), error_msg)
        for first, second in zip(validation, results):
            self.assertTrue(np.array_equal(first, second), error_msg)

    def _run_compare_with_baseline_sweep(self, batch_by_size_impl, seed=0):
        """Compare reference batch_by_size implementation with batch_by_size_baseline
        across a dense grid of hyperparam values.

        Token counts are drawn from ``[0, max_tokens]``, so no single sample
        ever exceeds ``max_tokens`` on its own. A seeded generator makes any
        failing grid point reproducible across runs.
        """
        MAX_MAX_TOKENS = 10
        NUM_TOKENS_VECS_COUNT = 5
        rng = np.random.RandomState(seed)
        for indices_len in [10, 11]:  # try odd and even len of indices
            for max_sentences in range(0, indices_len + 2):
                for max_tokens in range(0, MAX_MAX_TOKENS):
                    for bsz_mult in range(1, max(MAX_MAX_TOKENS, indices_len) + 2):
                        for _ in range(NUM_TOKENS_VECS_COUNT):
                            num_tokens_vec = rng.randint(
                                0, max_tokens + 1, size=indices_len
                            )
                            self._compare_results(
                                indices_len,
                                batch_by_size_impl,
                                max_sentences,
                                max_tokens,
                                bsz_mult,
                                num_tokens_vec,
                            )
class TestBatchBySizeVec(TestBatchBySize):
    """Run the baseline sweep against ``batch_by_size_vec``."""

    def test_compare_with_baseline(self):
        implementation = batch_by_size_vec
        self._run_compare_with_baseline_sweep(implementation)
class TestBatchBySizeFn(TestBatchBySize):
    """Run the baseline sweep against the callback-driven ``batch_by_size_fn``."""

    def test_compare_with_baseline(self):
        def batch_by_size_fn_wrapper(
            indices, num_tokens_vec, max_tokens, max_sentences, bsz_mult
        ):
            # batch_by_size_fn receives a per-index callback instead of the
            # vector, so expose the precomputed lengths through a lookup.
            return batch_by_size_fn(
                indices,
                lambda idx: num_tokens_vec[idx],
                max_tokens,
                max_sentences,
                bsz_mult,
            )

        self._run_compare_with_baseline_sweep(batch_by_size_fn_wrapper)
if __name__ == "__main__":
    # Allow running this module directly with the stdlib test runner.
    unittest.main()