Spaces:
Running
Running
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for `samplers.py`."""
from absl.testing import absltest | |
from absl.testing import parameterized | |
import chex | |
from clrs._src import probing | |
from clrs._src import samplers | |
from clrs._src import specs | |
import jax | |
import numpy as np | |
class SamplersTest(parameterized.TestCase):
  """Tests for sampler construction, batching and hint padding."""

  # NOTE(review): the parameter list is assumed to be `specs.CLRS_30_ALGS`;
  # the test takes a `name` argument, so it must be parameterized to run.
  # Confirm the list against upstream.
  @parameterized.parameters(*specs.CLRS_30_ALGS)
  def test_sampler_determinism(self, name):
    """Two full-batch `next()` calls agree despite different global seeds.

    Args:
      name: algorithm name forwarded to `samplers.build_sampler`.
    """
    num_samples = 3
    num_nodes = 10
    sampler, _ = samplers.build_sampler(name, num_samples, num_nodes)

    np.random.seed(47)  # Set seed
    feedback = sampler.next()
    expected = feedback.outputs[0].data.copy()

    np.random.seed(48)  # Set a different seed
    feedback = sampler.next()
    actual = feedback.outputs[0].data.copy()

    # Validate that datasets are the same.
    np.testing.assert_array_equal(expected, actual)

  # NOTE(review): parameter list assumed, as above -- confirm.
  @parameterized.parameters(*specs.CLRS_30_ALGS)
  def test_sampler_batch_determinism(self, name):
    """Two samplers built with the same seed yield identical batches.

    Args:
      name: algorithm name forwarded to `samplers.build_sampler`.
    """
    num_samples = 10
    batch_size = 5
    num_nodes = 10
    seed = 0
    sampler_1, _ = samplers.build_sampler(
        name, num_samples, num_nodes, seed=seed)
    sampler_2, _ = samplers.build_sampler(
        name, num_samples, num_nodes, seed=seed)

    feedback_1 = sampler_1.next(batch_size)
    feedback_2 = sampler_2.next(batch_size)

    # Validate that datasets are the same.
    jax.tree_util.tree_map(np.testing.assert_array_equal, feedback_1,
                           feedback_2)

  def test_end_to_end(self):
    """Checks names, counts and shapes of the "bfs" inputs and outputs."""
    num_samples = 7
    num_nodes = 3
    sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes)
    feedback = sampler.next()

    inputs = feedback.features.inputs
    self.assertLen(inputs, 4)
    self.assertEqual(inputs[0].name, "pos")
    self.assertEqual(inputs[0].data.shape, (num_samples, num_nodes))

    outputs = feedback.outputs
    self.assertLen(outputs, 1)
    self.assertEqual(outputs[0].name, "pi")
    self.assertEqual(outputs[0].data.shape, (num_samples, num_nodes))

  def _assert_batch_dimension(self, feedback, expected_size):
    """Asserts every probe in `feedback` has `expected_size` batch entries."""
    for dp in feedback.features.inputs:  # [B, ...]
      self.assertEqual(dp.data.shape[0], expected_size)

    for dp in feedback.outputs:  # [B, ...]
      self.assertEqual(dp.data.shape[0], expected_size)

    for dp in feedback.features.hints:  # [T, B, ...]
      self.assertEqual(dp.data.shape[1], expected_size)

    self.assertLen(feedback.features.lengths, expected_size)

  def test_batch_size(self):
    """`next()` returns all samples; `next(k)` returns exactly `k`."""
    num_samples = 7
    num_nodes = 3
    sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes)

    # Full-batch.
    self._assert_batch_dimension(sampler.next(), num_samples)

    # Specified batch.
    batch_size = 5
    self._assert_batch_dimension(sampler.next(batch_size), batch_size)

  def test_batch_io(self):
    """`_batch_io` stacks four single-sample lists along the first axis."""
    sample = [
        probing.DataPoint(
            name="x",
            location=specs.Location.NODE,
            type_=specs.Type.SCALAR,
            data=np.zeros([1, 3]),
        ),
        probing.DataPoint(
            name="y",
            location=specs.Location.EDGE,
            type_=specs.Type.MASK,
            data=np.zeros([1, 3, 3]),
        ),
    ]

    # Shallow list copies: the four entries share the same DataPoint objects.
    trajectory = [sample.copy() for _ in range(4)]
    batched = samplers._batch_io(trajectory)

    np.testing.assert_array_equal(batched[0].data, np.zeros([4, 3]))
    np.testing.assert_array_equal(batched[1].data, np.zeros([4, 3, 3]))

  def test_batch_hint(self):
    """`_batch_hints` merges trajectories of length 2 and 1 and reports both.

    The second argument is exercised with 0 and 5; with 5 the leading
    dimension of the batched data is 5 (presumably a minimum number of
    steps -- confirm against `samplers._batch_hints`).
    """
    sample0 = [
        probing.DataPoint(
            name="x",
            location=specs.Location.NODE,
            type_=specs.Type.MASK,
            data=np.zeros([2, 1, 3]),
        ),
        probing.DataPoint(
            name="y",
            location=specs.Location.NODE,
            type_=specs.Type.POINTER,
            data=np.zeros([2, 1, 3]),
        ),
    ]

    sample1 = [
        probing.DataPoint(
            name="x",
            location=specs.Location.NODE,
            type_=specs.Type.MASK,
            data=np.zeros([1, 1, 3]),
        ),
        probing.DataPoint(
            name="y",
            location=specs.Location.NODE,
            type_=specs.Type.POINTER,
            data=np.zeros([1, 1, 3]),
        ),
    ]

    trajectory = [sample0, sample1]

    batched, lengths = samplers._batch_hints(trajectory, 0)
    np.testing.assert_array_equal(batched[0].data, np.zeros([2, 2, 3]))
    np.testing.assert_array_equal(batched[1].data, np.zeros([2, 2, 3]))
    np.testing.assert_array_equal(lengths, np.array([2, 1]))

    batched, lengths = samplers._batch_hints(trajectory, 5)
    np.testing.assert_array_equal(batched[0].data, np.zeros([5, 2, 3]))
    np.testing.assert_array_equal(batched[1].data, np.zeros([5, 2, 3]))
    np.testing.assert_array_equal(lengths, np.array([2, 1]))

  def test_padding(self):
    """Batched hints keep the original ones and are zero past each length."""
    # Ten random trajectory lengths, each in [1, 10].
    lens = np.random.choice(10, (10,), replace=True) + 1
    trajectory = [
        [
            probing.DataPoint(
                name="x",
                location=specs.Location.NODE,
                type_=specs.Type.MASK,
                data=np.ones([len_, 1, 3]),
            )
        ]
        for len_ in lens
    ]

    batched, lengths = samplers._batch_hints(trajectory, 0)
    np.testing.assert_array_equal(lengths, lens)

    for i, len_ in enumerate(lens):
      ones = batched[0].data[:len_, i, :]
      zeros = batched[0].data[len_:, i, :]
      np.testing.assert_array_equal(ones, np.ones_like(ones))
      np.testing.assert_array_equal(zeros, np.zeros_like(zeros))
class ProcessRandomPosTest(parameterized.TestCase):
  """Tests for `samplers.process_random_pos`."""

  # NOTE(review): the test needs an `algorithm_name` argument, so it must be
  # parameterized. One string-matching and one non-string algorithm are used
  # so that both `expected` branches below run -- confirm the exact list
  # against upstream.
  @parameterized.parameters("insertion_sort", "naive_string_matcher")
  def test_random_pos(self, algorithm_name):
    """Randomized `pos` differs from the fixed one; all else is unchanged.

    Args:
      algorithm_name: algorithm name forwarded to `samplers.build_sampler`.
    """
    batch_size, length = 12, 10

    def _make_sampler():
      """Yields batches of `batch_size` from a freshly built sampler."""
      sampler, _ = samplers.build_sampler(
          algorithm_name,
          seed=0,
          num_samples=100,
          length=length,
      )
      while True:
        yield sampler.next(batch_size)

    sampler_1 = _make_sampler()
    sampler_2 = _make_sampler()
    sampler_2 = samplers.process_random_pos(sampler_2, np.random.RandomState(0))

    batch_without_rand_pos = next(sampler_1)
    batch_with_rand_pos = next(sampler_2)
    pos_idx = [x.name for x in batch_without_rand_pos.features.inputs].index(
        "pos")
    fixed_pos = batch_without_rand_pos.features.inputs[pos_idx]
    rand_pos = batch_with_rand_pos.features.inputs[pos_idx]

    self.assertEqual(rand_pos.location, specs.Location.NODE)
    self.assertEqual(rand_pos.type_, specs.Type.SCALAR)
    self.assertEqual(rand_pos.data.shape, (batch_size, length))
    self.assertEqual(rand_pos.data.shape, fixed_pos.data.shape)
    self.assertEqual(rand_pos.type_, fixed_pos.type_)
    self.assertEqual(rand_pos.location, fixed_pos.location)

    # Use unittest assertions rather than bare `assert`, which is stripped
    # under `python -O`.
    # Random positions vary across the batch; fixed positions do not.
    self.assertTrue((rand_pos.data.std(axis=0) > 1e-3).all())
    self.assertTrue((fixed_pos.data.std(axis=0) < 1e-9).all())

    if "string" in algorithm_name:
      # Two position ramps back to back: the first 4/5 of the length, then
      # the remaining 1/5, each restarting at 0.
      expected = np.concatenate([np.arange(4*length//5)/(4*length//5),
                                 np.arange(length//5)/(length//5)])
    else:
      expected = np.arange(length)/length
    np.testing.assert_array_equal(
        fixed_pos.data, np.broadcast_to(expected, (batch_size, length)))

    # With `pos` swapped back, the two batches must match exactly.
    batch_with_rand_pos.features.inputs[pos_idx] = fixed_pos
    chex.assert_trees_all_equal(batch_with_rand_pos, batch_without_rand_pos)
# Run the absl test runner when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()