# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Unit tests for `samplers.py`.""" from absl.testing import absltest from absl.testing import parameterized import chex from clrs._src import probing from clrs._src import samplers from clrs._src import specs import jax import numpy as np class SamplersTest(parameterized.TestCase): @parameterized.parameters(*specs.CLRS_30_ALGS) def test_sampler_determinism(self, name): num_samples = 3 num_nodes = 10 sampler, _ = samplers.build_sampler(name, num_samples, num_nodes) np.random.seed(47) # Set seed feedback = sampler.next() expected = feedback.outputs[0].data.copy() np.random.seed(48) # Set a different seed feedback = sampler.next() actual = feedback.outputs[0].data.copy() # Validate that datasets are the same. np.testing.assert_array_equal(expected, actual) @parameterized.parameters(*specs.CLRS_30_ALGS) def test_sampler_batch_determinism(self, name): num_samples = 10 batch_size = 5 num_nodes = 10 seed = 0 sampler_1, _ = samplers.build_sampler( name, num_samples, num_nodes, seed=seed) sampler_2, _ = samplers.build_sampler( name, num_samples, num_nodes, seed=seed) feedback_1 = sampler_1.next(batch_size) feedback_2 = sampler_2.next(batch_size) # Validate that datasets are the same. jax.tree_util.tree_map(np.testing.assert_array_equal, feedback_1, feedback_2) def test_end_to_end(self): num_samples = 7 num_nodes = 3 sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes) feedback = sampler.next() inputs = feedback.features.inputs self.assertLen(inputs, 4) self.assertEqual(inputs[0].name, "pos") self.assertEqual(inputs[0].data.shape, (num_samples, num_nodes)) outputs = feedback.outputs self.assertLen(outputs, 1) self.assertEqual(outputs[0].name, "pi") self.assertEqual(outputs[0].data.shape, (num_samples, num_nodes)) def test_batch_size(self): num_samples = 7 num_nodes = 3 sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes) # Full-batch. feedback = sampler.next() for dp in feedback.features.inputs: # [B, ...] self.assertEqual(dp.data.shape[0], num_samples) for dp in feedback.outputs: # [B, ...] self.assertEqual(dp.data.shape[0], num_samples) for dp in feedback.features.hints: # [T, B, ...] self.assertEqual(dp.data.shape[1], num_samples) self.assertLen(feedback.features.lengths, num_samples) # Specified batch. batch_size = 5 feedback = sampler.next(batch_size) for dp in feedback.features.inputs: # [B, ...] self.assertEqual(dp.data.shape[0], batch_size) for dp in feedback.outputs: # [B, ...] self.assertEqual(dp.data.shape[0], batch_size) for dp in feedback.features.hints: # [T, B, ...] self.assertEqual(dp.data.shape[1], batch_size) self.assertLen(feedback.features.lengths, batch_size) def test_batch_io(self): sample = [ probing.DataPoint( name="x", location=specs.Location.NODE, type_=specs.Type.SCALAR, data=np.zeros([1, 3]), ), probing.DataPoint( name="y", location=specs.Location.EDGE, type_=specs.Type.MASK, data=np.zeros([1, 3, 3]), ), ] trajectory = [sample.copy(), sample.copy(), sample.copy(), sample.copy()] batched = samplers._batch_io(trajectory) np.testing.assert_array_equal(batched[0].data, np.zeros([4, 3])) np.testing.assert_array_equal(batched[1].data, np.zeros([4, 3, 3])) def test_batch_hint(self): sample0 = [ probing.DataPoint( name="x", location=specs.Location.NODE, type_=specs.Type.MASK, data=np.zeros([2, 1, 3]), ), probing.DataPoint( name="y", location=specs.Location.NODE, type_=specs.Type.POINTER, data=np.zeros([2, 1, 3]), ), ] sample1 = [ probing.DataPoint( name="x", location=specs.Location.NODE, type_=specs.Type.MASK, data=np.zeros([1, 1, 3]), ), probing.DataPoint( name="y", location=specs.Location.NODE, type_=specs.Type.POINTER, data=np.zeros([1, 1, 3]), ), ] trajectory = [sample0, sample1] batched, lengths = samplers._batch_hints(trajectory, 0) np.testing.assert_array_equal(batched[0].data, np.zeros([2, 2, 3])) np.testing.assert_array_equal(batched[1].data, np.zeros([2, 2, 3])) np.testing.assert_array_equal(lengths, np.array([2, 1])) batched, lengths = samplers._batch_hints(trajectory, 5) np.testing.assert_array_equal(batched[0].data, np.zeros([5, 2, 3])) np.testing.assert_array_equal(batched[1].data, np.zeros([5, 2, 3])) np.testing.assert_array_equal(lengths, np.array([2, 1])) def test_padding(self): lens = np.random.choice(10, (10,), replace=True) + 1 trajectory = [] for len_ in lens: trajectory.append([ probing.DataPoint( name="x", location=specs.Location.NODE, type_=specs.Type.MASK, data=np.ones([len_, 1, 3]), ) ]) batched, lengths = samplers._batch_hints(trajectory, 0) np.testing.assert_array_equal(lengths, lens) for i in range(len(lens)): ones = batched[0].data[:lens[i], i, :] zeros = batched[0].data[lens[i]:, i, :] np.testing.assert_array_equal(ones, np.ones_like(ones)) np.testing.assert_array_equal(zeros, np.zeros_like(zeros)) class ProcessRandomPosTest(parameterized.TestCase): @parameterized.parameters(["insertion_sort", "naive_string_matcher"]) def test_random_pos(self, algorithm_name): batch_size, length = 12, 10 def _make_sampler(): sampler, _ = samplers.build_sampler( algorithm_name, seed=0, num_samples=100, length=length, ) while True: yield sampler.next(batch_size) sampler_1 = _make_sampler() sampler_2 = _make_sampler() sampler_2 = samplers.process_random_pos(sampler_2, np.random.RandomState(0)) batch_without_rand_pos = next(sampler_1) batch_with_rand_pos = next(sampler_2) pos_idx = [x.name for x in batch_without_rand_pos.features.inputs].index( "pos") fixed_pos = batch_without_rand_pos.features.inputs[pos_idx] rand_pos = batch_with_rand_pos.features.inputs[pos_idx] self.assertEqual(rand_pos.location, specs.Location.NODE) self.assertEqual(rand_pos.type_, specs.Type.SCALAR) self.assertEqual(rand_pos.data.shape, (batch_size, length)) self.assertEqual(rand_pos.data.shape, fixed_pos.data.shape) self.assertEqual(rand_pos.type_, fixed_pos.type_) self.assertEqual(rand_pos.location, fixed_pos.location) assert (rand_pos.data.std(axis=0) > 1e-3).all() assert (fixed_pos.data.std(axis=0) < 1e-9).all() if "string" in algorithm_name: expected = np.concatenate([np.arange(4*length//5)/(4*length//5), np.arange(length//5)/(length//5)]) else: expected = np.arange(length)/length np.testing.assert_array_equal( fixed_pos.data, np.broadcast_to(expected, (batch_size, length))) batch_with_rand_pos.features.inputs[pos_idx] = fixed_pos chex.assert_trees_all_equal(batch_with_rand_pos, batch_without_rand_pos) if __name__ == "__main__": absltest.main()