Spaces:
Sleeping
Sleeping
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
"""Tests for common.utils. | |
""" | |
from collections import Counter | |
import random | |
import tempfile | |
import numpy as np | |
import tensorflow as tf | |
from common import utils # brain coder | |
class UtilsTest(tf.test.TestCase):
    """Unit tests for the helpers exposed by `common.utils`."""

    def _assert_sampling_matches_weights(self, wheel, n=1000000):
        """Check that samples drawn from `wheel` are valid and correctly weighted.

        Draws one sample via `sample()` and 100 via `sample_many(100)`, and
        checks that each is an `(obj, weight)` pair contained in `wheel` whose
        weight is a Python float. Then draws `n` samples and checks that every
        entry's empirical frequency, rescaled by `wheel.total_weight`, is
        within `atol=1e-3` of its weight.

        Args:
          wheel: A non-empty `utils.RouletteWheel`.
          n: Number of samples used for the distribution check.
        """
        # Check that sampling is correct.
        obj, weight = wheel.sample()
        self.assertIsInstance(weight, float)
        self.assertIn((obj, weight), wheel)
        for obj, weight in wheel.sample_many(100):
            self.assertIsInstance(weight, float)
            self.assertIn((obj, weight), wheel)

        # Check that the sampling distribution is correct.
        counts = Counter(wheel.sample_many(n))
        for obj, w in wheel:
            estimated_w = counts[(obj, w)] / float(n) * wheel.total_weight
            self.assertTrue(
                np.isclose(w, estimated_w, atol=1e-3),
                'Expected %s, got %s, for object %s' % (w, estimated_w, obj))

    def testStackPad(self):
        """Ragged inputs are zero-padded along `pad_axes` to `pad_to_lengths`."""
        # 1D.
        tensors = [[1, 2, 3], [4, 5, 6, 7, 8], [9]]
        result = utils.stack_pad(tensors, pad_axes=0, pad_to_lengths=6)
        expected = np.asarray([[1, 2, 3, 0, 0, 0],
                               [4, 5, 6, 7, 8, 0],
                               [9, 0, 0, 0, 0, 0]], dtype=np.float32)
        np.testing.assert_array_equal(result, expected)

        # 3D.
        tensors = [[[[1, 2, 3], [4, 5, 6]]],
                   [[[7, 8, 9], [0, 1, 2]], [[3, 4, 5], [6, 7, 8]]],
                   [[[0, 1, 2]], [[3, 4, 5]]]]
        result = utils.stack_pad(tensors, pad_axes=[0, 1], pad_to_lengths=[2, 2])
        expected = np.asarray([[[[1, 2, 3], [4, 5, 6]],
                                [[0, 0, 0], [0, 0, 0]]],
                               [[[7, 8, 9], [0, 1, 2]],
                                [[3, 4, 5], [6, 7, 8]]],
                               [[[0, 1, 2], [0, 0, 0]],
                                [[3, 4, 5], [0, 0, 0]]]], dtype=np.float32)
        np.testing.assert_array_equal(result, expected)

    def testStackPadNoAxes(self):
        """Without `pad_axes`, equally-shaped inputs are simply stacked."""
        # 2D.
        tensors = [[[1, 2, 3], [4, 5, 6]],
                   [[7, 8, 9], [1, 2, 3]],
                   [[4, 5, 6], [7, 8, 9]]]
        result = utils.stack_pad(tensors)
        np.testing.assert_array_equal(result, np.asarray(tensors))

    def testStackPadNoneLength(self):
        """A `None` pad length pads to the longest input along that axis."""
        # 1D.
        tensors = [[1, 2, 3], [4, 5, 6, 7, 8], [9]]
        result = utils.stack_pad(tensors, pad_axes=0, pad_to_lengths=None)
        expected = np.asarray([[1, 2, 3, 0, 0],
                               [4, 5, 6, 7, 8],
                               [9, 0, 0, 0, 0]], dtype=np.float32)
        np.testing.assert_array_equal(result, expected)

        # 3D.
        tensors = [[[[1, 2, 3], [4, 5, 6]]],
                   [[[7, 8, 9], [0, 1, 2]], [[3, 4, 5], [6, 7, 8]]],
                   [[[0, 1, 2]], [[3, 4, 5]]]]
        result = utils.stack_pad(tensors, pad_axes=[0, 1], pad_to_lengths=None)
        expected = np.asarray([[[[1, 2, 3], [4, 5, 6]],
                                [[0, 0, 0], [0, 0, 0]]],
                               [[[7, 8, 9], [0, 1, 2]],
                                [[3, 4, 5], [6, 7, 8]]],
                               [[[0, 1, 2], [0, 0, 0]],
                                [[3, 4, 5], [0, 0, 0]]]], dtype=np.float32)
        np.testing.assert_array_equal(result, expected)

        # 3D with partial pad_to_lengths.
        tensors = [[[[1, 2, 3], [4, 5, 6]]],
                   [[[7, 8, 9], [0, 1, 2]], [[3, 4, 5], [6, 7, 8]]],
                   [[[0, 1, 2]], [[3, 4, 5]]]]
        result = utils.stack_pad(
            tensors, pad_axes=[0, 1], pad_to_lengths=[None, 3])
        expected = np.asarray([[[[1, 2, 3], [4, 5, 6], [0, 0, 0]],
                                [[0, 0, 0], [0, 0, 0], [0, 0, 0]]],
                               [[[7, 8, 9], [0, 1, 2], [0, 0, 0]],
                                [[3, 4, 5], [6, 7, 8], [0, 0, 0]]],
                               [[[0, 1, 2], [0, 0, 0], [0, 0, 0]],
                                [[3, 4, 5], [0, 0, 0], [0, 0, 0]]]],
                              dtype=np.float32)
        np.testing.assert_array_equal(result, expected)

    def testStackPadValueError(self):
        """Mismatched sizes along a non-padded axis raise ValueError."""
        # 3D.
        tensors = [[[[1, 2, 3], [4, 5, 6]]],
                   [[[7, 8, 9], [0, 1, 2]], [[3, 4, 5], [6, 7, 8]]],
                   [[[0, 1, 2]], [[3, 4, 5]]],
                   [[[1, 2, 3, 4]]]]
        # Not all tensors have the same shape along axis 2.
        with self.assertRaises(ValueError):
            utils.stack_pad(tensors, pad_axes=[0, 1], pad_to_lengths=[2, 2])

    def testRecord(self):
        """Records support attribute/index access, mutation, and defaults."""
        my_record = utils.make_record('my_record', ['a', 'b', 'c'], {'b': 55})
        inst = my_record(a=1, b=2, c=3)
        self.assertEqual(1, inst.a)
        self.assertEqual(2, inst.b)
        self.assertEqual(3, inst.c)
        self.assertEqual(1, inst[0])
        self.assertEqual(2, inst[1])
        self.assertEqual(3, inst[2])
        self.assertEqual([1, 2, 3], list(iter(inst)))
        self.assertEqual(3, len(inst))

        inst.b = 999
        self.assertEqual(999, inst.b)
        self.assertEqual(999, inst[1])

        # `==` is exercised explicitly here to test the record's __eq__.
        inst2 = my_record(1, 999, 3)
        self.assertTrue(inst == inst2)
        inst2[1] = 3
        self.assertFalse(inst == inst2)

        # 'b' is omitted, so inst3 takes the default b=55.
        inst3 = my_record(a=1, c=3)
        inst.b = 55
        self.assertEqual(inst, inst3)

    def testRecordUnique(self):
        """Records of different types never compare equal."""
        record1 = utils.make_record('record1', ['a', 'b', 'c'])
        record2 = utils.make_record('record2', ['a', 'b', 'c'])
        self.assertNotEqual(record1(1, 2, 3), record2(1, 2, 3))
        self.assertEqual(record1(1, 2, 3), record1(1, 2, 3))

    def testTupleToRecord(self):
        """`tuple_to_record` builds a record from positional values."""
        my_record = utils.make_record('my_record', ['a', 'b', 'c'])
        inst = utils.tuple_to_record((5, 6, 7), my_record)
        self.assertEqual(my_record(5, 6, 7), inst)

    def testRecordErrors(self):
        """Missing required fields or extra arguments raise ValueError."""
        my_record = utils.make_record('my_record', ['a', 'b', 'c'], {'b': 10})
        with self.assertRaises(ValueError):
            my_record(c=5)  # Did not provide required argument 'a'.
        with self.assertRaises(ValueError):
            my_record(1, 2, 3, 4)  # Too many arguments.

    def testRandomQueue(self):
        """A full RandomQueue drops its oldest item; samples come from the rest."""
        np.random.seed(567890)
        queue = utils.RandomQueue(5)
        for item in (5, 6, 7, 8, 9, 10):
            queue.push(item)
        # Six pushes into a capacity-5 queue: the first item (5) is expected
        # to have been evicted.
        self.assertNotIn(5, queue)
        sample = queue.random_sample(1000)
        self.assertEqual(1000, len(sample))
        self.assertEqual([6, 7, 8, 9, 10], sorted(np.unique(sample).tolist()))

    def testMaxUniquePriorityQueue(self):
        """The queue keeps the top-`capacity` items and pops the minimum."""
        queue = utils.MaxUniquePriorityQueue(5)
        queue.push(1.0, 'string 1')
        queue.push(-0.5, 'string 2')
        queue.push(0.5, 'string 3')
        self.assertEqual((-0.5, 'string 2', None), queue.pop())
        queue.push(0.1, 'string 4')
        queue.push(1.5, 'string 5')
        queue.push(0.0, 'string 6')
        queue.push(0.2, 'string 7')
        self.assertEqual((1.5, 'string 5', None), queue.get_max())
        self.assertEqual((0.1, 'string 4', None), queue.get_min())
        self.assertEqual(
            [('string 5', None), ('string 1', None), ('string 3', None),
             ('string 7', None), ('string 4', None)],
            list(queue.iter_in_order()))

    def testMaxUniquePriorityQueue_Duplicates(self):
        """Equal priorities pop in insertion order; duplicate items are ignored."""
        queue = utils.MaxUniquePriorityQueue(5)
        queue.push(0.0, 'string 1')
        queue.push(0.0, 'string 2')
        queue.push(0.0, 'string 3')
        self.assertEqual((0.0, 'string 1', None), queue.pop())
        self.assertEqual((0.0, 'string 2', None), queue.pop())
        self.assertEqual((0.0, 'string 3', None), queue.pop())
        self.assertEqual(0, len(queue))
        queue.push(0.1, 'string 4')
        queue.push(1.5, 'string 5')
        queue.push(0.3, 'string 6')
        queue.push(0.2, 'string 7')
        queue.push(0.0, 'string 8')
        queue.push(1.5, 'string 5')
        queue.push(1.5, 'string 5')
        self.assertEqual((1.5, 'string 5', None), queue.get_max())
        self.assertEqual((0.0, 'string 8', None), queue.get_min())
        self.assertEqual(
            [('string 5', None), ('string 6', None), ('string 7', None),
             ('string 4', None), ('string 8', None)],
            list(queue.iter_in_order()))

    def testMaxUniquePriorityQueue_ExtraData(self):
        """Extra data is stored with the first push of an item and returned."""
        queue = utils.MaxUniquePriorityQueue(5)
        queue.push(1.0, 'string 1', [1, 2, 3])
        queue.push(0.5, 'string 2', [4, 5, 6])
        queue.push(0.5, 'string 3', [7, 8, 9])
        queue.push(0.5, 'string 2', [10, 11, 12])
        self.assertEqual((0.5, 'string 2', [4, 5, 6]), queue.pop())
        self.assertEqual((0.5, 'string 3', [7, 8, 9]), queue.pop())
        self.assertEqual((1.0, 'string 1', [1, 2, 3]), queue.pop())
        self.assertEqual(0, len(queue))
        queue.push(0.5, 'string 2', [10, 11, 12])
        self.assertEqual((0.5, 'string 2', [10, 11, 12]), queue.pop())

    def testRouletteWheel(self):
        """Items added one by one are tracked and sampled by weight."""
        random.seed(12345678987654321)
        r = utils.RouletteWheel()
        self.assertTrue(r.is_empty())
        with self.assertRaises(RuntimeError):
            r.sample()  # Cannot sample when empty.
        self.assertEqual(0, r.total_weight)

        # total_weight is a running float sum, so compare approximately.
        self.assertEqual(True, r.add('a', 0.1))
        self.assertFalse(r.is_empty())
        self.assertAlmostEqual(0.1, r.total_weight)
        self.assertEqual(True, r.add('b', 0.01))
        self.assertAlmostEqual(0.11, r.total_weight)
        self.assertEqual(True, r.add('c', 0.5))
        self.assertEqual(True, r.add('d', 0.1))
        self.assertEqual(True, r.add('e', 0.05))
        self.assertEqual(True, r.add('f', 0.03))
        self.assertEqual(True, r.add('g', 0.001))
        self.assertAlmostEqual(0.791, r.total_weight)
        self.assertFalse(r.is_empty())

        self._assert_sampling_matches_weights(r)

    def testRouletteWheel_AddMany(self):
        """Items added in bulk are tracked and sampled by weight."""
        random.seed(12345678987654321)
        r = utils.RouletteWheel()
        self.assertTrue(r.is_empty())
        with self.assertRaises(RuntimeError):
            r.sample()  # Cannot sample when empty.
        self.assertEqual(0, r.total_weight)

        count = r.add_many(
            ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
            [0.1, 0.01, 0.5, 0.1, 0.05, 0.03, 0.001])
        self.assertEqual(7, count)
        self.assertFalse(r.is_empty())
        self.assertAlmostEqual(0.791, r.total_weight)

        # Adding no items is allowed.
        count = r.add_many([], [])
        self.assertEqual(0, count)
        self.assertFalse(r.is_empty())
        self.assertAlmostEqual(0.791, r.total_weight)

        self._assert_sampling_matches_weights(r)

    def testRouletteWheel_AddZeroWeights(self):
        """Zero-weight items are accepted and stored as 0.0."""
        r = utils.RouletteWheel()
        self.assertEqual(True, r.add('a', 0))
        self.assertFalse(r.is_empty())
        self.assertEqual(4, r.add_many(['b', 'c', 'd', 'e'], [0, 0.1, 0, 0]))
        self.assertEqual(
            [('a', 0.0), ('b', 0.0), ('c', 0.1), ('d', 0.0), ('e', 0.0)],
            list(r))

    def testRouletteWheel_UniqueMode(self):
        """Unique mode dedups by key and requires keys; non-unique forbids them."""
        random.seed(12345678987654321)
        r = utils.RouletteWheel(unique_mode=True)
        self.assertEqual(True, r.add([1, 2, 3], 1, 'a'))
        self.assertEqual(True, r.add([4, 5], 0.5, 'b'))
        self.assertEqual(False, r.add([1, 2, 3], 1.5, 'a'))
        self.assertEqual(
            [([1, 2, 3], 1.0), ([4, 5], 0.5)],
            list(r))
        self.assertEqual(1.5, r.total_weight)
        self.assertEqual(
            2,
            r.add_many(
                [[5, 6, 2, 3], [1, 2, 3], [8], [1, 2, 3]],
                [0.1, 0.2, 0.1, 2.0],
                ['c', 'a', 'd', 'a']))
        self.assertEqual(
            [([1, 2, 3], 1.0), ([4, 5], 0.5), ([5, 6, 2, 3], 0.1), ([8], 0.1)],
            list(r))
        self.assertAlmostEqual(1.7, r.total_weight)
        self.assertEqual(0, r.add_many([], [], []))  # Adding no items is allowed.
        with self.assertRaises(ValueError):
            # Key not given.
            r.add([7, 8, 9], 2.0)
        with self.assertRaises(ValueError):
            # Keys not given.
            r.add_many([[7, 8, 9], [10]], [2.0, 2.0])
        self.assertEqual(True, r.has_key('a'))
        self.assertEqual(True, r.has_key('b'))
        self.assertEqual(False, r.has_key('z'))
        self.assertEqual(1.0, r.get_weight('a'))
        self.assertEqual(0.5, r.get_weight('b'))

        r = utils.RouletteWheel(unique_mode=False)
        self.assertEqual(True, r.add([1, 2, 3], 1))
        self.assertEqual(True, r.add([4, 5], 0.5))
        self.assertEqual(True, r.add([1, 2, 3], 1.5))
        self.assertEqual(
            [([1, 2, 3], 1.0), ([4, 5], 0.5), ([1, 2, 3], 1.5)],
            list(r))
        self.assertEqual(3, r.total_weight)
        self.assertEqual(
            4,
            r.add_many(
                [[5, 6, 2, 3], [1, 2, 3], [8], [1, 2, 3]],
                [0.1, 0.2, 0.1, 0.2]))
        self.assertEqual(
            [([1, 2, 3], 1.0), ([4, 5], 0.5), ([1, 2, 3], 1.5),
             ([5, 6, 2, 3], 0.1), ([1, 2, 3], 0.2), ([8], 0.1), ([1, 2, 3], 0.2)],
            list(r))
        self.assertAlmostEqual(3.6, r.total_weight)
        with self.assertRaises(ValueError):
            # Key is given.
            r.add([7, 8, 9], 2.0, 'a')
        with self.assertRaises(ValueError):
            # Keys are given.
            r.add_many([[7, 8, 9], [10]], [2.0, 2.0], ['a', 'b'])

    def testRouletteWheel_IncrementalSave(self):
        """Each `incremental_save` persists entries readable by a new wheel."""
        entries = [
            ([1, 2, 3], 0.1, 'a'),
            ([4, 5], 0.2, 'b'),
            ([6], 0.3, 'c'),
            ([7, 8, 9, 10], 0.25, 'd'),
            ([-1, -2], 0.15, 'e'),
            ([-3, -4, -5], 0.5, 'f')]
        # The context manager closes (and deletes) the temp file when the
        # test ends, even if an assertion fails, instead of leaking the handle.
        with tempfile.NamedTemporaryFile() as f:
            r = utils.RouletteWheel(unique_mode=True, save_file=f.name)
            self.assertTrue(r.is_empty())
            for i in range(0, len(entries), 2):
                r.add(*entries[i])
                r.add(*entries[i + 1])
                r.incremental_save()

                # A fresh wheel on the same file should see everything saved.
                r2 = utils.RouletteWheel(unique_mode=True, save_file=f.name)
                self.assertEqual(i + 2, len(r2))
                count = 0
                for j, (obj, weight) in enumerate(r2):
                    expected_obj, expected_weight, key = entries[j]
                    self.assertEqual(expected_obj, obj)
                    self.assertEqual(expected_weight, weight)
                    self.assertEqual(weight, r2.get_weight(key))
                    count += 1
                self.assertEqual(i + 2, count)
if __name__ == '__main__':
    # Script entry point: hand control to TensorFlow's test runner.
    tf.test.main()